fast.aiでTensorFlowをバックエンドとして使う

fast.ai はKerasのバックエンドとしてTheanoを前提としていますがTensorFlowも使うことができます。

~/.keras/keras.json

{
    "image_dim_ordering": "tf",
    "epsilon": 1e-07,
    "floatx": "float32",
    "backend": "tensorflow"
}

とすればTensorFlowになり、特にエラーもなくJupyter Notebookを実行できます。

が、

これだけではうまくいきません。。

なぜか、学習したときの精度が低くなっちゃう.. ><

Lesson1のThe punchline: state of the art custom model in 7 lines of codeを実行すると、

Theano TensorFlow
val_acc: 0.9825 val_acc: 0.9190

となります。。

原因

VGG16モデルの転移学習用の重みデータ vgg16.h5vgg16_bn.h5 が Theano用になっているからです。

KerasのWikiページ Converting convolution kernels from Theano to TensorFlow and vice versa · fchollet/keras Wiki · GitHub に重みファイルの変換の方法が記載されていますが

どうやらKeras v2用のコードのようでfast.aiで使用するKeras v1だとエラーになります。

Keras v1での重み変換

ここのコードを参考にしてうまく変換できました。:-)

実装したものはこちら gist.github.com

このコードで作った vgg16.tf.h5vgg16.h5 の代わりに使うようにして TensorFlow バックエンドで学習を実行すると、

Theano TensorFlow(重み変換前) TensorFlow(重み変換後)
val_acc: 0.9825 val_acc: 0.9190 val_acc: 0.9840

期待通りの精度が出るようになりました。( •ᴗ•)

utils.py の修正とバックエンド切り替えスクリプト

このままでは Theano と TensorFlow のバックエンドを頻繁に変えたいときに面倒なので utils.py を少し修正します。

== 46行目あたり ==

if K.backend() == 'theano':
    import theano
    from theano import shared, tensor as T
    from theano.tensor.nnet import conv2d, nnet
    from theano.tensor.signal import pool

のように修正し、バックエンドが Theano のときだけ Theanoのライブラリを読み込むようにします。

さらに、次のようなシェルスクリプトを作ると

switch_keras_backend.sh

#!/bin/bash

if [ $# -ne 1 ]; then
  echo "ERROR: Wrong arguments"
  echo "Usage: switch_keras_backend.sh th|tf"
  exit 1
fi

if [ $1 != 'th' ] && [ $1 != 'tf' ]; then
  echo "ERROR: Wrong parameter. Parameter must be 'th' or 'tf'."
  echo "Usage: switch_keras_backend.sh th|tf"
  exit 1
fi

backend=$1
if [ $backend = 'th' ]; then
  echo "Use theano as backend"
elif [ $backend = 'tf' ]; then
  echo "Use tensorflow as backend"
fi

cd ~/.keras
ln -nsf keras.$backend.json keras.json

cd ~/.keras/models
ln -nsf vgg16.$backend.h5 vgg16.h5
ln -nsf vgg16_bn.$backend.h5 vgg16_bn.h5

echo 'Done!'

重みファイルも簡単に切り替えができます。

vgg16.h5keras.json はそれぞれのバックエンド毎に用意したファイルへのシンボリックリンクにしておく必要あり)