CapsuleNetworkの論文読んだ

CapsuleNetworkの論文読んだので復習も兼ねてブログにまとめます。
分かってない部分も多々あるかもですが、参考にする場合はご了承ください。
計算方法とかに関してはあんまり強くないのでそこまで触れていません・・・。

1.CapsuleNetworkとは

大体半年前ぐらいに論文で発表された技術で、CNNとはまた異なる手法を使った技術。CNNはニューロンを使用しているが、CapsuleNetwork(CapsNet)ではカプセルを用いているのが特徴。
CNNでは画像内の最も特徴の強いピクセルを圧縮し続けて抽象化することである程度のロバスト性を獲得するが、CapsNetでは特徴の強い部分同士の空間的な相互関係を学習するものになっている。
これにより、CNNでは物体の変化に対応するためにある程度訓練時に様々な角度の画像が必要であったが、CapsNetでは特徴間を空間的に把握するため、訓練時に様々な角度の画像が必要なくなるといったことが言える。

MNISTではあまり精度の向上はうかがえないが、一応エラー率は従来のCNNよりも高い0.25%となっている。またこの論文で独自?にしようされたMultiMNISTと呼ばれる数字が2つ重なったデータセットを用いた結果、数字が重なっているためにCNNではほとんど判別不能だったが、CapsNetでは2つの数字を判別できるといった結果が出ている。

簡単に言えば、CapsNetは様々な角度の画像なしでもロバスト性があり、物体が重なっていても判別できるといったことが言えるかもしれない。

2.カプセルって何?

カプセルっていうのはニューロンを複数合わせたものを1つにしたもの?だと思っていたけど、ニューロンがスカラを使うのに対して、カプセルはベクトルを用いるというのが最大の特徴かも?
自分はベクトルを入出力としたニューロンということで認識している()
ベクトルを使うので、学習時に空間的な情報も扱えるから最強だよね!っていうことだと思うことにしました。(適当ですみません)
複数のニューロンの相互関係(向き)と特徴の強さ(値の大きさ)を使えるっていうことなのかな。


簡単な構造の画像は以下になります。
f:id:sizuruna-193:20180914194353p:plain
(引用:https://qiita.com/hiyoko9t/items/f426cba38b6ca1a7aa2b
左がカプセル、右がニューロン

3.学習方法(アルゴリズム

CapsNetではDynamic Routingという手法を使う。これが一番重要・・・。
抽象的に言うと、カプセル同士の繋がり(重み)を学習する。
この学習方法のおかげでCapsuleNetworkの最大の利点である特徴間の空間的な相互関係が学習できる
Dynamicなのは、学習過程でカプセル間の繋がりを強くしたり、弱くしたりするために、構造が動的に変化するからだと思われる。(CNNでも同じことしてるじゃない?なんて一瞬思った)

自分なりに翻訳したアルゴリズムは以下になります。(画像の数字とその下の数字は関連してないので注意してください。)

f:id:sizuruna-193:20180914193845p:plain

(1)まずl層目のi個のカプセルと(l+1)層目のj個のカプセルとの間の変数bijを0で初期化する
(2)l層目のi個のカプセルより、SoftMax(bij)で「繋がり度合いcij」を計算する
(3)(l+1)層目のj個のカプセルより、cij*uij(uiはl層目のi番目の出力ベクトル)の総和であるsjを計算する
(4)(l+1)層目のj個のカプセルより、繋がりが強い場合は1、そうでない場合は0に近づける活性化関数Squash(sj)を計算する
(5)カプセル間の変数bijを更新する bij←bij + uji*vj
(6)2~5をルーティングしたい回数繰り返す

これを実際の顔画像で例えるなら
l層目の下位カプセルは目、鼻、口を示していて、l+1層目は顔や花、建物といった上位カプセルと考える。
入力画像に顔があったとき、目があって顔が存在し、鼻があって顔が存在するが、目が花や建物には存在しない。
このとき目と花といったカプセルは繋がりが弱いと判断し、目と顔のカプセル間は繋がりを強くするという学習を行う。
(この繋がりの強さを1~0と考えるとして、Squash関数を用いて弱い場合は0に近づけ、強い場合は1に近づける)

f:id:sizuruna-193:20180914194956p:plain

4.CapsNetの構造

f:id:sizuruna-193:20180914200528p:plain

左から順に
・入力画像(MNIST)
・Conv1
・PrimaryCaps
・DigitCaps
・一番右は再構成ネットワーク?だと思われる(後で解説)

CapsNetでも畳み込み層は使用しているそうです。(CNNはプーリング層が情報をそぎ落とす欠点だった)
Conv1で画像の特徴を抽出して、
PrimaryCapsとDigitCapsの間で先ほどのDynamic Routingのアルゴリズムを使用して画像の空間的関係を学習する。
パラメータに関してはいじってみたのですが正直まだわかっていないのでスルーします・・・。扱う入力画像によって変わると思うのでモノクロやカラーだとどうなるのか調査してみたいところ。

5.損失関数

もちろん、CapsNetでも損失関数はあります。この論文ではMarginLoss関数と呼んでました。

f:id:sizuruna-193:20180914201234p:plain

・Tkはデータがクラスk(存在すれば)1、そうでなければ0
・vkはカプセルの出力
・λは0.5 m+は0.9 m-は0.1
・kは0,1,2,3...9のどれか(MNISTの10クラス分)

つまり、クラスkが存在した場合は左側の計算、存在しない場合は右側の計算を行うといった損失関数になっています。
λとかmの値に関しては論文で既に設定されていました。これが一番適切らしい。

6.再構成損失

先ほどのネットワーク構造の画像の一番右について解説します。
この技術はぶっちゃけ多分CapsNetとは関係ないと思われるのですが、使ってみたらうまくいっちゃった!みたいな感じで紹介されていたので多分そうだと思う()

f:id:sizuruna-193:20180914201643p:plain

カプセルの出力から画像を再構成することで、入力画像と再構成した画像の差異を損失として先ほどのMarginLossに加えるというもの。
要するに、このネットワークが思い出して入力画像を作ることができるかどうかを確かめているということだと思っています。うまく作れれば損失は少なくなるし、そうでなければ損失が高くなる(まだ学習できてないからもっと学習することになる)

もちろんこの技術はMarginLossに加えることができるといったので、学習する段階で再構成するかしないかを選択することは可能です。

ちなみに計算方法は訓練データ(入力)と再構成した画像とのズレを最小二乗法で計算してます。

これを行うことで過学習の防止、性能の向上が見込めるので使わないに越したことはないでしょう。
ちなみにこれを使うか使わないかでどのくらい学習時間が変わるのか調べてみたのですが、MNISTぐらいの画像だと対して変わりませんでした()カラー画像だともしかしたら結構違いが出るかもしれない。

7.所感

カプセルネットワークはかなり人間の脳に近い動きをするネットワークということはなんとなくわかりました。CNNだと訓練時に多くの角度のデータが必要になるけれど、人間の目は正面で見た物体は斜めからみたところで大体同じものだと認識できるので、かなりカプセルネットワークは人間の脳に近いことがわかりました。
論文だとモノクロ画像のMNISTしか使ってなかったのですがそれでもCNNと比較するとかなり物体認識に安定感があるなあと思いました。
しかし、実際に動かしてみたのですが、学習データが少なくなっても学習時間はCNNと比べると4倍ぐらい掛かってしまいました・・・。この部分に関してはCNNが研究されまくってきたからというのもありそうなので、今後に期待ですね・・・。
カプセルネットワークはグーグルみたいな大企業で研究するところよりも、個人で研究している人向けかもしれませんね。学習データが少なくなるのはかなりうれしいと思いました。

参考文献など
カプセルネットワークはニューラルネットワークを超えるか。
Dynamic Routing Between Capsulesを読む - データサイエンティスト(仮)
https://arxiv.org/pdf/1710.09829.pdf

ソースコード
github.com