Chainerのソースを解析。学習ループは Trainer に

2016年、夏、chainer のサンプルがごっそり別物みたいに変わっていました。
MNIST のサンプルはこんな感じです。

train

新クラスのオンパレード。
Trainer、Updater、Extension、IntervalTrigger、Iterator ・・・
目まいがしそうな変わりよう。

そして、学習ループは忽然と姿を消しています。

学習ループはどこへ

学習ループはどこかにあるはずです。気になって仕方ないので探しました。

run

ありました。Trainer の run メソッドの中です。

traier-loop

run の中でいろいろやっているようなので見ていきます。

run → Updater

上のソースで updater とあるのは Updater の サブクラス StandardUpdater です。そのインスタンスの update メソッドが学習ループの中で呼び出されています。後でソースを確認しますが、update メソッド の中で Optimizer → 順伝播、誤差逆伝播 が実行されます。

train_blk

データの管理は Iterator(のサブクラス)が行っており、バッチサイズ毎にデータを切り出して渡してくれます。
では、その辺の流れを StandardUpdater のソースで確認します。

updater

Trainer から update が呼び出され、update_core の中でバッチ 1 回分の学習を行っています。繰り返しを管理するのは Trainer 側です。

Trainer と Extension

ログ出力や進捗バー表示といった拡張機能が Extension(のサブクラス) として部品化されています。
Extension は

  • Trainer に登録され
  • Trainer から起動され

ます。

ex_reg

起動条件となる トリガは登録時に設定されます。

Extension の登録、起動部分のソースを見てみます。

traier-trigger

extend メソッドで登録、run の学習ループの中で起動しています。

評価機能も Extension

テストデータによる評価も Extension 化されており、クラス名は Evaluator となっています。
他の Extension と同じく Trainer から起動され、起動後は順伝播の処理を行います。

train_blk_eval

ログ出力

学習 時

ログ情報は Trainer のインスタンス変数 observation(辞書型)に一旦書き込まれ、その内容が LogReport クラスによって集計されファイルに出力されます。

traier-report

ログ情報の中身は 損失(loss)と精度(accuracy)です。この情報を observation へ実際に書き込むのは Classifier です。

評価 時

テストデータを使った評価の際も、同様に observation 経由でファイルに出力されます。
ただし、Evaluator の observation は いったん Trainer の observation に統合され、その後、まとめてファイルに出力されます。

eval-report

まとめ

以上を図にしてみます。
煩雑になるので、Extension は一部だけにとどめます。

train_all

学習ループやデータ管理は、以前はアプリケーション側の実装になっていました。それが、抽象化されてフレームワーク内部に隠蔽された格好です。

これで単純な実装ミスは減ると思います。でも、なんか難しくなった気がしないでもない、です。

コメントはお気軽に