Chainerのソースを解析。学習ループは Trainer に

MNIST のサンプルの 大体 の流れを追ってみます。(詳しい処理はこちらで)
バージョンは 1.14.0 、2016年9月5日時点のソースです。

全体の流れはこうなっており

  1. ① モデルの生成
  2. ② オプティマイザの生成
  3. ③ 学習データのダウンロード
  4. ④ trainer、updater の生成
  5. ⑤ Extension の登録
  6. ⑥ 学習ループ

ソースはこんな感じです。

では見ていきます。

- 目次 -

スポンサーリンク

学習ループ

run

学習ループはTrainerrunメソッドの中にあります。

traier-loop

run の中でいろいろやっているようなので見ていきます。

run → Updater

上のソースで updater とあるのは Updater の サブクラスStandardUpdaterです。そのインスタンスの update メソッドが学習ループの中で呼び出されています。後でソースを確認しますが、update メソッド の中でOptimizer → 順伝播、誤差逆伝播が実行されます。

train_blk

データの管理は Iterator(のサブクラス)が行っており、バッチサイズ毎にデータを切り出して渡してくれます。
では、その辺の流れを StandardUpdater のソースで確認します。

updater

Trainer から update が呼び出され、update_core の中でバッチ 1 回分の学習を行っています。繰り返しを管理するのは Trainer 側です。

Trainer と Extension

ログ出力や進捗バー表示といった拡張機能がExtension(のサブクラス) として部品化されています。
Extension は

  • Trainer に登録され
  • Trainer から起動され

ます。

ex_reg

起動条件となる トリガは登録時に設定されます。

Extension の登録、起動部分のソースを見てみます。

traier-trigger

extend メソッドで登録、run の学習ループの中で起動しています。

評価機能も Extension

テストデータによる評価も Extension 化されており、クラス名は Evaluator となっています。
他の Extension と同じく Trainer から起動され、起動後は順伝播の処理を行います。

train_blk_eval

ログ出力

学習 時

ログ情報は Trainer のインスタンス変数 observation(辞書型)に一旦書き込まれ、その内容が LogReport クラスによって集計されファイルに出力されます。

traier-report

ログ情報の中身は損失(loss)精度(accuracy)です。この情報を observation へ実際に書き込むのは Classifier です。

評価 時

テストデータを使った評価の際も、同様に observation 経由でファイルに出力されます。
ただし、Evaluator の observation は いったん Trainer の observation に統合され、その後、まとめてファイルに出力されます。

eval-report

まとめ

以上を図にしてみます。
煩雑になるので、Extension は一部だけにとどめます。

train_all

学習ループやデータ管理は、以前はアプリケーション側の実装になっていました。それが、抽象化されてフレームワーク内部に隠蔽された格好です。

これで単純な実装ミスは減ると思います。でも、なんか難しくなった気がしないでもない、です。

スポンサーリンク
その他の記事

コメントはお気軽に