Batch Normalizationの実験

Practical Deep Learning For Coders, Part 1の lesson3.ipynb

Batchnorm resolves this problem by normalizing each intermediate layer as well. The details of how it works are not terribly important (although I will outline them in a moment) - the important takeaway is that all modern networks should use batchnorm, or something equivalent. There are two reasons for this:

Adding batchnorm to a model can result in 10x or more improvements in training speed Because normalization greatly reduces the ability of a small number of outlying inputs to over-influence the training, it also tends to reduce overfitting.

と、Batch Normalizationを使うことが強く推奨されていますが MNISTでBatch Normalization(BN)を使ってみたところ、あまり変わりませんでした。

というか、BN使ったほうが微妙に劣っていました。。

lr=0.001で1epoch学習後にlr=0.0001でさらに8epoch学習

loss acc val_loss val_acc
BNなし 0.0133 0.9957 0.0205 0.9936
BNあり 0.0158 0.9948 0.0369 0.9889

というわけで、MNISTとCIFAR10のデータセットを使って少し実験してみました。

Batch Normalization

日本語で概要をつかむには以下の記事が分かりやすかったです。 deepage.net

Batch Normalizationでの逆伝搬の仕組みについてはいろいろな記事で Understanding the backward pass through Batch Normalization Layer がよく引用されていました。

MNIST

モデルは以下のように定義しています。

model = Sequential([
    Lambda(norm_input, input_shape=(1,28,28)),
    ZeroPadding2D((1, 1)),
    Convolution2D(32, 3, 3, activation='relu'),
    MaxPooling2D(),
    ZeroPadding2D((1, 1)),
    Convolution2D(64, 3, 3, activation='relu'),  
    MaxPooling2D(),
    Flatten(),
    Dense(512, activation='relu'),
    Dropout(0.5),
    Dense(10, activation='softmax')
    ])

Batch Normalizationを使うと大きい学習係数を使えると言われているので学習率を10倍にしてみました。

lr=0.01で1epoch学習後にlr=0.001でさらに8epoch学習 (学習率を10倍にする)

loss acc val_loss val_acc
BNなし 0.2219 0.9352 0.1436 0.9582
BNあり 0.0507 0.9886 0.0567 0.9877

BNなしだと 93.5% 程度で止まってしまったのに対して BNありだと 98.8% まで学習が進みました!

確かに効果が確認できました。

CIFAR10

続いてMNISTより難しい分類問題のCIFAR10です。

モデルは Kerasのチュートリアルにあるコードを参考に定義しました。

model = Sequential()
model.add(ZeroPadding2D((1, 1), input_shape=x_train.shape[1:]))
model.add(Convolution2D(32, 3, 3, activation='relu'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))
model.add(Dropout(0.25))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(64, 3, 3, activation='relu'))
model.add(Convolution2D(64, 3, 3, activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

50epoch学習した結果

loss acc val_loss val_acc
lr=0.0001、BNなし 0.8875 0.7005 0.8007 0.7398
lr=0.0001、BNあり 0.6417 0.7824 0.7124 0.7671
lr=0.001、BNなし 2.1267 0.2226 2.0676 0.2230
lr=0.001、BNあり 0.5938 0.8006 0.5055 0.8264

MNISTよりも差が顕著に出ました。 学習係数を大きくした lr = 0.001 だと BNなしでは発散してしまっています。

一方、BNありは学習係数を大きくしても学習が進んでいき、 学習係数の大きいlr=0.001 + BNありが最もよい結果になりました。

FC層のみをファインチューニングする場合

Practical Deep Learning For Coders, Part 1の中で幾度となく登場する あらかじめ学習したモデルで最終Conv層の出力を保存しておき、FC層のみをファインチューニングする方法 でもやってみました。

MNIST

FC層のみをlr=0.0001で8epoch学習(lr=0.001で1epoch学習したモデルでconv層出力を計算しておく)

loss acc val_loss val_acc
lr=0.0001, BNなし 0.0170 0.9949 0.0226 0.9918
lr=0.0001, BNあり 0.0144 0.9955 0.0237 0.9923

FC層のみをlr=0.001で8epoch学習(lr=0.01で1epoch学習したモデルでconv層出力を計算しておく)

loss acc val_loss val_acc
lr=0.001, BNなし 0.0664 0.9790 0.0663 0.9801
lr=0.001, BNあり 0.0765 0.9756 0.0704 0.9794

なんと、学習係数の大きい lr=0.001でも BNありと BNなし で差が見られなくなってしまいました。。

CIFAR10

FC層のみをlr=0.0001で25epoch学習(lr=0.0001で25epoch学習したモデルでconv層出力を計算しておく)

loss acc val_loss val_acc
lr=0.0001、BNなし 0.6768 0.7764 0.7716 0.7444
lr=0.0001、BNあり 0.4751 0.8418 0.7227 0.7636

BNありだと学習データのロスがより小さくなりましたが完全に過学習になってます。。

実験結果のまとめ

  • 複雑なモデル、学習率を高くした場合でBatch Normalizationによる安定化効果が効く
  • FC層のファインチューニングだけだとそのメリットが小さくなる。むしろ学習が進みすぎて過学習を招く危険性?!

BNを使うと学習が安定する = 学習が速く進む ということが確認できました。

FC層のファインチューニングだとありがたみが薄れたので学習初期ほどBatch Normalizationによる安定化が効果を発揮するということでしょうか

今回使ったコードはhttps://github.com/tanakatsu/bn-experiments にあります。

fast.aiの記事から

http://www.fast.ai/2017/09/08/introducing-pytorch-for-fastai/

With the increased productivity this enabled, we were able to try far more techniques, and in the process we discovered a number of current standard practices that are actually extremely poor approaches. For example, we found that the combination of batch normalisation (which nearly all modern CNN architectures use) and model pretraining and fine-tuning (which you should use in every project if possible) can result in a 500% decrease in accuracy using standard training approaches. (We will be discussing this issue in-depth in a future post.) The results of this research are being incorporated directly into our framework.

Conv層出力のprecompute + FC層のファインチューニングは良くないと書いてあります。

oh... やはりそうだったのか・・・ !

We will be discussing this issue in-depth in a future post とあるので記事を待ちたいと思います。