CaffeでDeep Q-Networkを実装して深層強化学習してみた
Deep Q-Network
Deep Q-Network(以下DQN)は,2013年のNIPSのDeep Learning Workshopの"Playing Atari with Deep Reinforcement Learning"という論文で提案されたアルゴリズムで,行動価値関数Q(s,a)を深層ニューラルネットワークにより近似するという,近年の深層学習の研究成果を強化学習に活かしたものです.Atari 2600のゲームに適用され,既存手法を圧倒するとともに一部のゲームでは人間のエキスパートを上回るスコアを達成しています.論文の著者らは今年Googleに買収されたDeepMindの研究者です.
NIPS2013読み会で自分が紹介した際のスライドがこちらになります.
他の方が作成したスライドもあります.
必要なもの
- Caffe
- まだ本家にマージされていないpull requestを修正した上で使用しています.とりあえず動かしてみたい場合は自分のforkレポジトリのdqnブランチを使えば動くと思います.
- https://github.com/BVLC/caffe/pull/1228 (ソルバーのステップ実行に必要)
- https://github.com/BVLC/caffe/pull/1122 (AdaDelta)
- まだ本家にマージされていないpull requestを修正した上で使用しています.とりあえず動かしてみたい場合は自分のforkレポジトリのdqnブランチを使えば動くと思います.
- Arcade Learning Environment
- http://www.arcadelearningenvironment.org/ からダウンロードしてビルドします.ゲームスクリーンを表示するためにMakefileのUSE_SDLを1にセットします.libsdl,libsdl-gfx,libsdl-imageが必要になります.
ソースコード
GitHubで公開しています.DQN-in-the-Caffe
ネットワークの構成
ネットワークの構成は元論文の通り,
- 入力層:84x84x4(ラスト4フレームのダウンサンプリング&グレイスケール化)
- 隠れ層1:8x8のフィルタx8(ストライド4)による畳込み+ReLU
- 隠れ層2:4x4のフィルタx16(ストライド2)による畳込み+ReLU
- 隠れ層3:fully-connectedなノードx256+ReLU
- 出力層:fully-connectedなノードx18(18種類のアクションそれぞれの行動価値)
としました.このネットワークを逆伝播により学習するためには,複数ある出力のうち1つの出力のみに対して誤差を計算する必要があるのですが,それを可能にするためにCaffeのELTWISEレイヤーを使い,1つの要素のみ1で残りは0であるようなベクトルをネットワークの出力に掛け合わせることで望みの出力だけを取り出せるようにしています.Caffeのネットワーク表記でネットワーク全体を書くと下のようになりました.
パラメータの学習
パラメータの学習のためには,「状態で行動を選択したところ,報酬を獲得し,次の状態がであった」という状態遷移の経験をreplay memoryというメモリに保管していき,パラメータ更新の際にはそこからランダムサンプリングした一定数の遷移それぞれについて
となるように勾配を計算した上で,まとめて更新を行うミニバッチ学習を行います.
元論文ではここでRMSPropというパラメータ更新量の自動調節アルゴリズムを用いていますが,Caffeには今のところRMSPropは実装されておらず,その代わりAdaDeltaというRMSPropによく似たアルゴリズムをすでに実装してpull requestを投げている人がいたので,それを使いました.ただし,AdaDeltaをそのまま使用するとパラメータが発散してしまうことが多かったため,AdaDeltaによる更新量にさらに一定の係数(最初の100万イテレーションでは0.2,次の100万イテレーションでは0.02)を掛けて用いました(同じようなことをやっている?人).ミニバッチの大きさは元論文と同じ32,割引率は元論文では示されていませんが0.95としました.
元論文ではreplay memoryの容量は100万フレームでしたが,メモリの都合上,半分の50万で実験しました.
学習時間
実行環境は
- Intel Core i7-2600
- GeForce GTX 780 Ti
- メモリ16GB
- Ubuntu 14.04 64bit
です.CaffeはGPUモードで,さらにcuDNNを使いました.この構成でミニバッチ5万個の学習に45分ほどかかりましたが,元論文では5万個分を30分ほどで学習しているので,1.5倍ほど遅い結果となりました.