Chainerのソースを解析。Linear クラス

深層学習の本を読んでもなかなか理解できないので Chainer のソースを読んでみました。なにか判るかと期待しましたが、とりあえず python に目が慣れてきたのが成果です。ソースのバージョンは 1.14.0 、2016年9月5日時点のものです。

Linear クラスについて判ったことを忘れないように書いておきます。

- 目次 -

スポンサーリンク

全結合層は Linear クラス

Chainer では全結合層を Linear クラスで組み立てます。
(chainer/links/connection/linear.py)

全結合とは、隣接レイヤーのユニット同士が下図のようにすべてつながる形態のことです。
fully_coupled

全結合を構成する要素

全結合を構成する要素として

  • 重み
  • バイアス
  • 活性化関数

があります。この中で、Linear クラスが関わるのは重みとバイアスです。
活性化関数については関与しません。

Linear のインスタンスはレイヤーに相当

Linear の 1 インスタンスが 1 レイヤーに相当します。そして、レイヤー内の各ユニットが持つパラメータ(重み、バイアス)は Linear インスタンスの属性として管理されます。

補足
ユニットに相当するクラスは存在しません。したがって、Linear インスタンスが直接、重みとバイアスを保持します。

重み、バイアスの生成は Chainer まかせ

たとえば、図のような 3 ユニットのレイヤーがあるとします。

layer

このレイヤーに必要な重みは 3 × 2 桁、バイアスは 3 桁の配列ですが、
W_bias
Linear 生成時に Chainer がデフォルトで生成、管理してくれます。
必要な処理はサイズを指定するだけで、上記の例であれば、入力を 2、出力を 3 と指定します。

呼び出される Linear の__init__はこう定義されてます。

lenear_init

インスタンスのイメージ

生成される Linear インスタンスは 「重み」 と 「バイアス」 を Variable で保持します。

Chainerのlinear

W が重み、b がバイアスです。

Variable インスタンスには data と grad があり

  • data にはパラメータ (重み、バイアスそのもの)
  • grad には逆伝播で求める勾配

がセットされます。

data、grad はともに同サイズの配列で、イメージとしては図のようになります。

linear_w_b

重み、バイアスを指定する場合

重み、バイアスの初期値を指定したければ、引数 initialW、initial_bias で渡します。
linear_init3

順伝播、逆伝播のロジックはどこに?

Linear クラスの役割は重みとバイアスを保持することであり、計算のロジックは持っていません。では、ロジックはどこにあるかというと、別クラス LinearFunction(functions/connection/linear.py)に定義されており、forward メソッド、backward メソッドがそれぞれ順伝播、逆伝播に相当します。

forward メソッド

forward は入力 x を受け取り、下図の計算で出力 y を求めます。

foward_cal

backward メソッド

  • 重みとバイアスの勾配を計算し
  • 下位層へのデータ引き渡し

を行います。> くわしくは、こちらで説明

ソースを確認

linear_method

foward の呼び出し方

自分で foward を呼び出す必要はありません。Linear インスタンスに学習データを渡すと、__call__メソッド経由で forward が自動で呼ばれます。

backward の呼び出し方

順伝播が終わると、その出力が Variable で返ってきます。Variable にも backward メソッドが存在するのですが、そのメソッドを呼ぶと逆伝播が始まり、LinearFunction の backward が呼ばれる仕組みになっています。> くわしくは、こちらで説明

図にすると

foward_call

Function は抽象クラス、LinearFunction はその具象クラスです。

オプティマイザ

backward は勾配を計算してくれますが、その計算した値を重みとバイアスに反映することまではやってくれません。では、誰がやるのかというと、オプティマイザです。> こちらで説明

Linear インスタンスができるまで

見てもあまり役にたたないとは思いますが。

linear_init2
スポンサーリンク
その他の記事

コメントはお気軽に