猫になりたい

コンサルのデータ分析屋、計量経済とか機械学習をやっています。pyてょnは3.7を使ってマスコレルウィンストングリーン。

Deep Learning:ResNetの解説とpytorchによる実装

前回の記事(VGG16をkerasで実装した)の続きです。
今回はResNetについてまとめた上でpytorchを用いて実装します。

ResNetとは

性能

ResNetとはILSVRC & COCO 2015 Competitionsに於ける以下の主要な5タスク、 ImageNet Classification、ImageNet Detection、ImageNet Localization、COCO Detection、COCO Segmentation の全てで1位を取ったDeep learning / Deep neural net(以下DNN)のモデルです。 聞く限り、発表から4年経った2019年現在でもResNetは画像認識のモデルのベースラインとされているようでありその性能は折り紙付きと言えます。 ではResNetの何が新しかったのでしょうか。

新規性

Resnet の新規性は後に述べるresidual learningを導入したことにより、

  1. DNNのレイヤーの数が多くなってもきちんと学習ができ
  2. レイヤーを増やしたことにより精度が向上した

の2点です。

DNNが流行り始めてから、DNNのレイヤー数は多いほうが性能が良いことが知らてきていました。 レイヤーを増やすと勾配が消失し、学習が進まないという問題(勾配消失問題)はあったものの、normalized-initializationやintermediate normalization layersによりこの問題は概ね解決しました*1。 しかし勾配消失問題が解決していざ学習が進むように為ると、今度は深いDNNの方が浅いDNNより性能が悪くなるdegradation problem(劣化問題)が現れました。 この劣化問題をresidual learningで解決したのがResNetです。

ResNetのアイディア

ではresidual learningと何なのでしょうか。
通常DNNの各レイヤーでは入力 xと出力 yが与えられたときに、 $$ \begin{align} y = H(x) \end{align} $$ となる関数 H(x)を学習させます。  H(x)は単一のレイヤーである必要はなく、複数のレイヤーをまとめたものと捉えても良いです。

一方、residual learningでは上式を

$$ \begin{align} y &= H(x) \\ &= H(x) - x + x \\ \Leftrightarrow y - x &= H(x) - x \\ &:= F(x) \end{align} $$

と変形して

$$ \begin{align} F(x) = H(x) - x \end{align} $$

という関数 F(x)を学習させます。Residual learning(残差学習)という名前は F(x) = y - x が各レイヤーのインプットとアウトプットの差を表していることに由来しているようです(線形回帰の文脈で言う残差とは別物であることに注意)。

これを図で表すと以下の様になります(原論文Fig. 2を改変)。

f:id:shikiponn:20190524182517p:plain:w600
Figure 2
この図では2枚のconv. レイヤーがF(x)に相当します。 また、F(x)への入力 xF(x)をショートカットして最後に足し合わされることから、この処理を特にshortcut connectionと呼びます。 このF(x)とshortcut connectionを合わせて、building blockやresidual blockと呼び、ResNetはこのResidual blockを複数積み上げて構成していきます。

Residual learningが深いDNNの学習を可能にすると著者が考えた理屈は以下の通りです。

  1. あるDNNのモデルAとそれより深いモデルBがあった時、より深いモデルであるBの方が表現能力が高いのでBのtraining error はAのtraining error以下であるはずだが実際はそうはならない(劣化問題)
  2. AにはなくてBにのみあるレイヤーが全て恒等写像になっていれば少なくともAとBのtraining errorは同じに為るはずなのに、そうならないのは恒等写像を学習するのが難しいからだ
  3. ならば恒等写像を学習しやすいモデルを使用しよう

Residual learningであるレイヤーに於ける入力 xと出力 H(x)が恒等写像の関係になる様に学習をさせたければ、 H(x) = F(x) + x  F(x)が常に 0を出力するように学習すればいいだけです。複雑に非線形関数を重ねているDNNで一から恒等写像を学習することに比べるとそれよりは簡単に恒等写像が学習できそうです。 実際、著者達がCIFAR10を用いて F(x)、つまり畳み込み層の出力を見てみると通常のCNNと比べて小さい値を取ることがわかっています。

ResNetはこのresidual learningを導入することにより劣化問題を回避し、以前より多くのレイヤーを持つモデルを学習できるようになりました。

Bottleneck Architectureによる更なる深化

Residual blockを実装するにあたって、筆者たちは最終的にFigure 2とは異なりF(x)をレイヤー3層で構成することで更にレイヤーの数を増やしました。 そしてこのレイヤー3層からなるresidual blockをbottleneck architectureと名付けました。 以下の図はその構成です(原論文Fig. 5より)。

f:id:shikiponn:20190603173844p:plain:w600
Figure 5

Bottleneck architectureを導入する目的は、一旦1x1 conv.でchannel方向の次元を削減してから畳み込みを行って、再度1x1で今度は次元を復元することで計算量を抑えつつレイヤー数を増やすことです。これにより計算量を殆ど変えずに residual blockあたり1レイヤー数を1枚増やすことに成功しています。 bootleneckという名前は、いったんchannel次元を削減してから復元していることに由来しています。 逆に言えば計算リソースが確保出来るのであればbottleneck構造を使う必要はないと思われます。

また1つのresidual blockに挟むレイヤー数は論文では3枚になっていますが、この数字自体には特に理由は無いようです。 但しレイヤー数が1枚だけの場合はresidual learningを使う効果はなかったと著者は書いていますし、あまり多すぎてもresidual learningの恩恵を受けられないでしょう。

Shortcut connectionの実装方法

Residual block内のレイヤーの出力 F(x)にshortcut connectionのxを加える操作はelement-wiseつまり、各チャンネルの画素 (w_i, h_i, C_i)毎に行われます。 しかし畳み込みの中でchannelの次元数cを増やして (w, h)を小さくすると F(x)xの次元が一致しないため F(x) + xという操作が行なえません。 このような場合に F(x)xの次元を揃えるために行うのがprojection shortcutで、実装上は1x1のconv.で行っています。

尚、このprojection shortcutはF(x)xの次元数が同じ箇所にも導入することも可能ですが、計算量が増える割には性能向上に寄与しなかったため原論文では入出力の次元が変わるresidual blockでのみ採用されています。


スポンサーリンク

実装と評価

ResNetの論文について概観したので実装と評価をしてみましょう。

原論文との差異

基本的には原論文のImageNetを使った実験を踏襲していますが、以下の様な差異があります。

  • 計算時間の都合からepoch数は40に設定(論文ではCIFAR10は182程、ImageNetは120 epoch程学習させている)。
  • データはImageNetの代わりにCIFAR10を用いた
    • CIFAR10のクラス数は10なのでfc層のサイズが変わっている
    • CIFAR10の解像度はImageNetより小さいので最初にアップサンプリング用のレイヤーを入れてアップスケールしている
    • CIFAR10の画像はサイズが揃っているので、croppingはしなかった
    • バッチサイズはCIFAR10を用いた実験に合わせて128としている
    • CIFAR10の実験に合わせてtrain dataを45k枚のtrainと5k枚のvalに分割している
  • Per-pixel mean subtractionは行っていない
    (行うと学習が不安定になり進まなかった)*2
  • 学習率は0.01から初めてvalidation lossのエラーが減少しなくなったら \times \frac{1}{10}する
    (論文通り0.1から始めると学習が不安定になり進まなかった)

実装

今回は比較的パラメータの少ないResNet50と呼ばれる、レイヤー数が50のResNetを実装しました。 Residual blockを構成するBottleneckクラスをResNet50クラスの中で積み上げる形で実装しています。 実装の際には必要なstrideの設定などの情報が論文からは抜けていたので、 その様な箇所は著者の実装[1]と、pytorchとkeras公式のresnet実装[2, 3] を参考にしました。

ちなみにResNetの実装はkerasとpytochの公式それぞれで微妙に異なっており、 keras公式の実装では畳み込み層にバイアス項がありますが、pytorch公式の実装にはバイアス項が無いという違いがあります。 原論文の実装ではバイアス項があるのでpytorchの実装が何故そうなっているかは不明です。*3 本実装ではバイアス項は入れてあります。


import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau


class Bottleneck(nn.Module):
    """
    Bottleneckを使用したresidual blockクラス
    """
    def __init__(self, indim, outdim, is_first_resblock=False):
        super(Bottleneck, self).__init__()
        self.is_dim_changed = (indim != outdim)
        # W, Hを小さくしてCを増やす際はstrideを2にする +
        # projection shortcutを使う様にセット
        if self.is_dim_changed:
            if is_first_resblock:
                # 最初のresblockは(W、 H)は変更しないのでstrideは1にする
                stride = 1
            else:
                stride = 2
            self.shortcut = nn.Conv2d(indim, outdim, 1, stride=stride)
        else:
            stride = 1
        
        dim_inter = int(outdim / 4)
        self.conv1 = nn.Conv2d(indim, dim_inter , 1)
        self.bn1 = nn.BatchNorm2d(dim_inter)
        self.conv2 = nn.Conv2d(dim_inter, dim_inter, 3,
                               stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(dim_inter)
        self.conv3 = nn.Conv2d(dim_inter, outdim, 1)
        self.bn3 = nn.BatchNorm2d(outdim)
        self.relu = nn.ReLU(inplace=True)
        

    def forward(self, x):
        shortcut = x
  
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        
        # Projection shortcutの場合
        if self.is_dim_changed:
            shortcut = self.shortcut(x)

        out += shortcut
        out = self.relu(out)

        return out


class ResNet50(nn.Module):
    
    def __init__(self): 
          
        super(ResNet50, self).__init__()
        
        # Due to memory limitation, images will be resized on-the-fly.
        self.upsampler = nn.Upsample(size=(224, 224))

        # Prior block
        self.layer_1 = nn.Conv2d(3, 64, 7, padding=3, stride=2)
        self.bn_1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(2, 2)
        
        # Residual blocks
        self.resblock1 = Bottleneck(64, 256, True)
        self.resblock2 = Bottleneck(256, 256)
        self.resblock3 = Bottleneck(256, 256)
        self.resblock4 = Bottleneck(256, 512)
        self.resblock5 = Bottleneck(512, 512)
        self.resblock6 = Bottleneck(512, 512)
        self.resblock7 = Bottleneck(512, 512)
        self.resblock8 = Bottleneck(512, 1024)
        self.resblock9 = Bottleneck(1024, 1024)
        self.resblock10 =Bottleneck(1024, 1024)
        self.resblock11 =Bottleneck(1024, 1024)
        self.resblock12 =Bottleneck(1024, 1024)
        self.resblock13 =Bottleneck(1024, 1024)
        self.resblock14 =Bottleneck(1024, 2048)
        self.resblock15 =Bottleneck(2048, 2048)
        self.resblock16 =Bottleneck(2048, 2048)
        
        # Postreior Block
        self.glob_avg_pool = nn.AdaptiveAvgPool2d((1, 1))        
        self.fc = nn.Linear(2048, 10)

    def forward(self, x):
        x = self.upsampler(x)
        
        # Prior block
        x = self.relu(self.bn_1(self.layer_1(x)))
        x = self.pool(x)
        
        # Residual blocks
        x = self.resblock1(x)
        x = self.resblock2(x)
        x = self.resblock3(x)
        x = self.resblock4(x)
        x = self.resblock5(x)
        x = self.resblock6(x)
        x = self.resblock7(x)
        x = self.resblock8(x)
        x = self.resblock9(x)
        x = self.resblock10(x)
        x = self.resblock11(x)
        x = self.resblock12(x)
        x = self.resblock13(x)
        x = self.resblock14(x)
        x = self.resblock15(x)
        x = self.resblock16(x)
        
        # Postreior Block
        x = self.glob_avg_pool(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

評価

環境

Google colabolatoryのGPUを使用しました(Tesla T4)。
Pythonとpytorchのバージョンは以下の通りです。

  • python: 3.6.7
  • pytorch: 1.1.0

データの用意

実験ではImageNetの代わりにCIFAR10を使用しました。
CIFAR10の公式からCIFAR-10 python versionをDLし、 Googleドライブ上の"./drive/My Drive/Colab Notebooks/dataset/に展開しました。

展開したらデータを読み込みましょう。

import os

from keras.utils import np_utils
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from PIL import Image
from tqdm import tqdm_notebook as tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms


def load_data(path):
    """
    Load CIFAR10 data
    Reference:
      https://www.kaggle.com/vassiliskrikonis/cifar-10-analysis-with-a-neural-network/data

    """
    def _load_batch_file(batch_filename):
        filepath = os.path.join(path, batch_filename)
        unpickled = _unpickle(filepath)
        return unpickled

    def _unpickle(file):
        import pickle
        with open(file, 'rb') as fo:
            dict = pickle.load(fo, encoding='latin')
        return dict

    train_batch_1 = _load_batch_file('data_batch_1')
    train_batch_2 = _load_batch_file('data_batch_2')
    train_batch_3 = _load_batch_file('data_batch_3')
    train_batch_4 = _load_batch_file('data_batch_4')
    train_batch_5 = _load_batch_file('data_batch_5')
    test_batch = _load_batch_file('test_batch')

    num_classes = 10
    batches = [train_batch_1['data'], train_batch_2['data'], train_batch_3['data'], train_batch_4['data'], train_batch_5['data']]
    train_x = np.concatenate(batches)
    train_x = train_x.astype('float32') # this is necessary for the division below

    train_y = np.concatenate([np_utils.to_categorical(labels, num_classes) for labels in [train_batch_1['labels'], train_batch_2['labels'], train_batch_3['labels'], train_batch_4['labels'], train_batch_5['labels']]])
    test_x = test_batch['data'].astype('float32') #/ 255
    test_y = np_utils.to_categorical(test_batch['labels'], num_classes)
    print(num_classes)
   
    img_rows, img_cols = 32, 32
    channels = 3
    print(train_x.shape)
    train_x = train_x.reshape(len(train_x), channels, img_rows, img_cols)
    test_x = test_x.reshape(len(test_x), channels, img_rows, img_cols)
    train_x = train_x.transpose((0, 2, 3, 1))
    test_x = test_x.transpose((0, 2, 3, 1))
    per_pixel_mean = (train_x).mean(0) # 計算はするが使用しない

    train_x = [Image.fromarray(img.astype(np.uint8)) for img in train_x]
    test_x = [Image.fromarray(img.astype(np.uint8)) for img in test_x]
    
    train = [(x,np.argmax(y)) for x, y in zip(train_x, train_y)]
    test = [(x,np.argmax(y)) for x, y in zip(test_x, test_y)]
    return train, test, per_pixel_mean


class ImageDataset(Dataset):
    """
    データにtransformsを適用するためのクラス
    """
    def __init__(self, data, transform=None):

        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, label = self.data[idx]
        
        if self.transform:
            img = self.transform(img)

        return img, label


# Googleドライブのマウント
from google.colab import drive
drive.mount('./drive')

BATCH_SIZE = 128
path = "./drive/My Drive/Colab Notebooks/dataset/cifar-10-batches-py/"
train, test = load_data(path)


# train dataの作成
train_transform = torchvision.transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Lambda(lambda img: np.array(img)),
    transforms.ToTensor(),
    transforms.Lambda(lambda img: img.float()),
])
train_dataset = ImageDataset(train[:45000], transform=train_transform)
trainloader = DataLoader(train_dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=True,
                         num_workers=0)


# validation data, test dataの作成
valtest_transform = torchvision.transforms.Compose([
    torchvision.transforms.Lambda(lambda img: np.array(img)),
    transforms.ToTensor(),
    transforms.Lambda(lambda img: img.float()),    
    ])
valid_dataset = ImageDataset(train[45000:], transform=valtest_transform)
validloader = DataLoader(valid_dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=True,
                         num_workers=0)

test_dataset = ImageDataset(test, transform=valtest_transform)
testloader = DataLoader(test_dataset,
                        batch_size=BATCH_SIZE,
                        shuffle=True,
                        num_workers=0)

画像の確認

データセットがきちんとロードできたかサンプル画像を出力してみます。

def imshow(img):
    """
    functions to show an image
    """
    plt.imshow(np.transpose(img, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()
print(images.numpy().shape)
# show images
imshow(torchvision.utils.make_grid(images ))

以上のコードを実行して次の様に画像が出力されれば大丈夫です。

f:id:shikiponn:20190531131338p:plain:w300
sample images

学習

それでは学習を走らせてみましょう。 学習用のパラメータはepoch数を除いて論文の通りです。 但し、ReduceLROnPlateaupatienceの様に論文に書かれていないパラメータは適当に決めています。

def validate(net, validloader):
  """
  epoch毎に性能評価をするための関数
  """
  net.eval()
  correct = 0
  total = 0
  preds = torch.tensor([]).float().to(device)
  trues = torch.tensor([]).long().to(device)
  with torch.no_grad():
      for data in validloader:
          images, labels = data
          images, labels = images.to(device), labels.to(device)

          outputs = net(images)
          _, predicted = torch.max(outputs.data, 1)
          total += labels.size(0)
          correct += (predicted == labels).sum().item()
           
          preds = torch.cat((preds, outputs))
          trues = torch.cat((trues, labels))
      val_loss = criterion(preds, trues)
      err_rate = 100 * (1 - correct / total)
 
  return val_loss, err_rate


# 学習用に必要なインスタンスを作成
net = ResNet50().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01,
                      momentum=0.9, weight_decay=0.0001)


scheduler = ReduceLROnPlateau(
                              optimizer,
                              mode='min',
                              factor=0.1,
                              patience=10,
                              verbose=True
                             )


# ロギング用のリスト
log = {'train_loss':[],
       'val_loss': [],
       'train_err_rate': [],
       'val_err_rate': []}

N_EPOCH = 40


# 学習を実行
for epoch in tqdm(range(N_EPOCH)):
    net.train()
    for i, data in tqdm(enumerate(trainloader, 0)):
        # get the inputs
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        # epoch内でのlossを確認
        if i % 100 == 0:
            print(loss)
         
    else:
        # trainとvalに対する指標を計算
        train_loss, train_err_rate = validate(net, trainloader)
        val_loss, val_err_rate = validate(net, validloader)
        log['train_loss'].append(train_loss.item())
        log['val_loss'].append(val_loss.item())
        log['val_err_rate'].append(val_err_rate)
        log['train_err_rate'].append(train_err_rate)
        print(loss)
        print(f'train_err_rate:\t{train_err_rate:.1f}')
        print(f'val_err_rate:\t{val_err_rate:.1f}')
        scheduler.step(val_loss)
else:
    print('Finished Training')

結果と考察

Testセットに対する性能は以下の表のようにprecision、recall共に0.89となりました。
また40 epochの学習にかかった時間は7時間9分でした。

学習に於けるerror rateとlossの推移は以下の通りです。

f:id:shikiponn:20190604001543p:plain:w300
Error rate

f:id:shikiponn:20190604001558p:plain:w300
Loss

前回のVGGがBatchNormalizationを入れて74epoch回してようやく0.81だったことを考えると40epochでこの精度は圧倒的です。今回は計算時間の都合からepoch数を抑えましたが、epoch数を増やせばもっと精度は見込めると思われます。

また今回学習に7時間かかりましたが以前の調査通りであれば、kerasで同じ実験をしていればざっくり14時間は学習にかかっていたことになります。これだけ学習時間が変わると為るとやはり今後はpytorchによる実装の方がkerasのより良いのではないかと改めて思いました。

              precision    recall  f1-score   support

           0       0.89      0.91      0.90      1000
           1       0.94      0.95      0.94      1000
           2       0.85      0.83      0.84      1000
           3       0.80      0.77      0.79      1000
           4       0.86      0.89      0.87      1000
           5       0.83      0.83      0.83      1000
           6       0.90      0.93      0.91      1000
           7       0.93      0.91      0.92      1000
           8       0.95      0.94      0.94      1000
           9       0.94      0.92      0.93      1000

    accuracy                           0.89     10000
   macro avg       0.89      0.89      0.89     10000
weighted avg       0.89      0.89      0.89     10000

array([[912,   5,  21,   8,   9,   1,   2,   6,  23,  13],
       [  8, 949,   0,   2,   1,   0,   0,   2,   7,  31],
       [ 24,   1, 833,  23,  48,  27,  32,   6,   5,   1],
       [ 11,   2,  38, 772,  35,  93,  32,  12,   4,   1],
       [  6,   2,  26,  20, 885,  18,  19,  21,   2,   1],
       [  6,   3,  15,  96,  21, 828,  11,  18,   0,   2],
       [  4,   0,  26,  21,   9,   7, 928,   3,   2,   0],
       [  7,   0,  12,  16,  22,  27,   0, 914,   0,   2],
       [ 27,  10,   5,   2,   1,   0,   3,   1, 939,  12],
       [ 14,  39,   3,   4,   0,   0,   2,   5,  10, 923]])

スポンサーリンク

まとめ

以上ResNetの論文についてまとめた上で、ResNet50をpytorchで実装しました。
CIFAR10を用いた実験ではVGG16よりも少ないepoch数で高い精度を達成できることが確認できました。

一方で学習時間については、前回のkerasによるVGG16の学習時間が74 epochで1時間ほどだったのに比べて、pytorchによるResNet50は40 epochで7時間かかることが分かりました。 以前の実験から、pytorchはkerasの2倍程度高速であると見積もるとRexNet50の学習はVGG16の14倍程度かかると言えそうです。
ResNetを使うときは一から学習させるのではなく、fine tuningを行うのが良さそうです。

また今回の実験ではImageNet用のResNet50をCIFAR10に無理やり適用したためか、per-pixel mean subtractionを適用したり学習率を0.1にすると学習が上手くいかない現象も確認できました。 今後、画像の前処理の方法やlearning rateの決め方についてもう少し調べられればと思います。

参考文献

  1. ResNet原論文:[1409.1556] Very Deep Convolutional Networks for Large-Scale Image Recognition
  2. 原論文著者によるcaffeのprototxt: deep-residual-networks/ResNet-50-deploy.prototxt at master · KaimingHe/deep-residual-networks · GitHub
  3. pytorch公式のResNet実装:vision/resnet.py at master · pytorch/vision · GitHub
  4. keras公式のResNet実装: keras-applications/resnet50.py at master · keras-team/keras-applications · GitHub
  5. 著者等のILSVRC 2015に於ける発表: http://image-net.org/challenges/talks/ilsvrc2015_deep_residual_learning_kaiminghe.pdf

*1:論文には書かれていないがBatchnormalization, dropout, ReLUの貢献もあった

*2:なんでper-pixel mean subtractionを行うと学習が上手くいかない(lossがnanになったり、収束しなかったりする)かは調査したのですが分かりませんでした。

*3:正確に言うと原論文著者のResNet50の実装にはバイアス項はあるが、ResNet101とResNet152はバイアス項があったり無かったりする。