RNNで感情分析(Sentiment Analysis)

RNN(Recurrent Neural Network)を使って感情分析(Sentiment Analysis)をしてみます。 今回はchainerを使って実装していく中で理解に苦労した点を図にまとめます。

RNNとは何なのか? RNNの基本、については以下の記事が分かりやすいです。

ちなみに、3つ目の記事は おなじみのMNIST文字認識が CNN ではなくRNN で実装されていて興味深かったです。

Sentiment Analysis

感情分析を理解および実装していく上でポイントになるのは以下の2点です。

  1. CNN(画像分類)との違い
  2. ロスの計算

CNN(画像分類)との違い

CNNで画像分類をするときは次のようにバッチサイズ分のデータを1回で入力します。

f:id:sanshonoki:20170929051417p:plain:w450

一方、感情認識では 時間方向の依存関係、つまり文脈を学習するので時間方向に入力していく必要があります。

f:id:sanshonoki:20171002202054p:plain

つまり、バッチサイズ分の文章を取り出し、それを位置ごとに切って文章の長さの回数分ネットワークに入力します。

ちなみに、このように時間方向に展開していくことを unroll や unfold と言うようです。

ロスの計算

感情分析の場合、文章(単語列)に対して1つのラベルがつくので最後の単語の出力(y_n)に対してのみ教師ラベルとのロスを計算します。それ以外の出力は無視します。

f:id:sanshonoki:20171002215028p:plain:w400

chainer のコードだと

if j <  seqlen - 1:
    model.predictor(x)
else:
    accum_loss = model(x, t)
    accum_loss.backward()

あるいは

loss = model(x, t)
if j == seqlen - 1:
    accum_loss += loss 
    accum_loss.backward()

といった形になります。

また、TensorFlowの場合は時間方向にまとめて展開できる dynamic_rnnモジュールを使って以下のように表現できます。 (Embed、LSTM層とFC層が別々になっています。)

f:id:sanshonoki:20171002215824p:plain:w400

感情分析のチートシート的なまとめ図

記憶の定着のためにまとめてみました。( •ᴗ•)

f:id:sanshonoki:20171002215211p:plain

その他のRNN応用でのロスの計算

感情分析以外もRNNの応用はいくつかあり

  • 文章生成
  • 機械翻訳(seq2seq)
  • 画像のキャプショニング

などがあります。

f:id:sanshonoki:20171003210723p:plain:w400

(画像は http://karpathy.github.io/2015/05/21/rnn-effectiveness/ から)

入力と教師ラベルの与え方でよく混乱するのでこれも図にしました。

ポイントは

  • 推論(予測)フェーズと違って出力はロスの計算以外に使わない(推論フェーズでは出力が次の入力になる)
  • 入力(x)を1つずらして教師ラベル(t)を作る

かなと思います。

文章生成(character-level language model)

f:id:sanshonoki:20171002215425p:plain:w350

機械翻訳(seq2seq)

f:id:sanshonoki:20171003211707p:plain

<GO><EOS> は文章の始まりと終わりを表す記号です。

機械翻訳では入力単語列を反転して入れると結果が良いらしいです。 順番に入れると1番目の単語を予測する上で大事な最初の入力単語の情報がデコーダの時点でより失われるのだと勝手に理解しています。

chainerでの実装

データセット

UdacityのDeepLearning nanodegree講座のSentiment-RNNの課題で使ったデータセットを使います。

映画レビューの英語テキストを positive, negative の2クラスに分類するという課題で実装が正しければ数epochの学習の後、おおよそ80%程度の精度が出るはずです。

コード

chainerの昔のptbのサンプルを参考にしつつ実装してみました。

(trainerを使っても実装しようとしましたがなんかうまく動いてくれません。。><)

github.com

動かしたところ80%の精度が出たのでアルゴリズムは大丈夫のようです。ƪ(•◡•ƪ)"

次はこれを使って Slackのメッセージを分類してみようと思います。