Chainerのソースを解析。LSTM クラスと時系列処理

Chainer に LSTM というクラスがあります。名前が示すとおり、RNN(再帰型ニューラルネット)の LSTM(長・短期記憶)を実現するためのクラスで、自然言語処理の公式サンプル ptb などで使用されています。

ソースを追ってみます。

- 目次 -

スポンサーリンク

2つの LSTM クラス

LSTM と名前のつくクラスは Chainer に 2 つ存在します。1 つは Chain のサブクラス、もう 1 つは Function のサブクラスです。ここでは便宜上、前者を LSTM (C)、後者を LSTM (F) と書いて区別します。

両クラスは下図の関係にあり、LSTM (C) が LSTM (F) を生成します。順伝播や誤差逆伝播の実際の計算は LSTM (F) で行います。

lstm_c_f

LSTM (C) の初期状態

LSTM (C) は内部に Linear インスタンスを 2 つ持っており、変数名はそれぞれ “lateral” と “upward” となっています。
Linear クラスについて詳しくはこちら

lstm_ini
Linear クラスも Chainer に 2 つ存在するので、Link のサブクラスを Linear (L)、Function のサブクラスを Linear (F) と表記します。

lateral と upward は LSTM (C) の初期化のタイミングで生成されます。
ソースは次のようになっています。

lstm_base

lateral と upward の役割

LSTM は時系列処理で使われます。時系列のデータを扱う場合、データの流れていく方向は下図のように 2 種類あります。1 つは前レイヤーから次レイヤーへの流れ(縦方向)、もう 1 つは時間に沿った t-1 から t+1 への流れ(横方向)です。

lstm_seq

データが流れていく際に重みやバイアスを反映しますが、前レイヤーから受け取るデータには upward が、t-1 から受け取るデータには lateral がそれぞれ反映を行います。

lstm_in

LSTM (F) の生成

モデル生成直後、 LSTM (F) のインスタンスはまだ存在していません。Linear (F) のインスタンスも同様に存在していません。では、いつ生成されるかというと、学習ループが始まり順伝播が動き出してからです。

生成後の LSTM (F) と Linear (F) は下図の関係にあり、おおまかに言うと、Linear (F) で重み・バイアスを反映し、 LSTM (F) が LSTM 独自 の処理を行います。

LSTM 独自 の処理 ?

メモリセルの管理や入力、出力、忘却の 3 ゲートの制御です。
詳細については 「深層学習 (機械学習プロフェッショナルシリーズ)」 に詳しく解説されています。
lstm_t

時系列

RNN は時系列データを扱うため、時刻 t の結果を 時刻 t+1 に引き継ぎます。t と t+1 を図示すると

lstm_t1

データが横につながっていきます。
ソースで確認すると

lstm_call

self.c はメモリセル値、self.h は出力値です。両方とも次回の__call__で使用されるので、t+1、t+2 ・・・ へとデータは引き継がれていきます。

LSTM (F) の内部

LSTM (F) の内部をのぞいてみます。

順伝播を図解

下図に描いたのは順伝播の処理ですが、メモリセルを中心に 3 つのゲート(入力ゲート出力ゲート忘却ゲート)が働いています。図では i ゲート、o ゲート、f ゲートと記載しています。

lstm_detail

a、i、f、o の各値は、入力データや t-1 の出力から upward と lateral によってつくられます。
実際にデータを当てはめてみます。極端な例ですが、入力サイズが 2 、出力サイズが 3 の場合、下図のようになります(バッチサイズは 1)。

lstm_detail_dat

ptb の場合

Chainer の公式サンプル ptb の場合、出力は 650、バッチサイズは 20 なので、メモリセルのサイズはこうなります。

lstm_ptb


計算

LSTM (F) の役割は 順伝播と誤差逆伝播の計算です。該当のソースを確認しておきたいと思います。

順伝播

forward メソッドです。

lstm_fwd

のぞき穴(peephole)は実装されていないようです。

誤差逆伝播

誤差逆伝播は、順伝播で生成されたインスタンスを逆方向にたどりながら計算します。詳しくはこちら
LSTM (F) の backward メソッドは、t+1 と次レイヤー から δ(デルタ)を受け取り、δ を計算します。

δ(デルタ) ?

δ(デルタ)は、書籍「深層学習」の用語をそのまま用いています。
lstm_bwd

Variable の加算

LSTM に関わる Function として、ここまでに LSTM (F) と Linear (F) の 2 つを取り上げましたが、実はもう 1 つ Function が関わっています。Variable に関する Function です。

Variable には __add__ や __sub__の特殊メソッドが定義されており、加算時は Add というクラスが生成される仕組みになっています。この Add が Function を継承しています。

var_add

(上の図にはこっそり書き込んであります。lateral と upward の出力を加算する場所)

Variable の加算が動くタイミングはここ、LSTM (C) の__call__メソッドです。

このタイミングで Add クラスの forward が動き、誤差逆伝播で backward が動きます。

スポンサーリンク
その他の記事

コメントはお気軽に