TF-Agents 0.4 Tutorials : ドライバー (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 04/20/2020 (0.4)
* 本ページは、TF Agents の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
ドライバー
イントロダクション
強化学習の一般的なパターンはステップかエピソードの指定数のために環境でポリシーを実行します。これは例えば、データ収集、評価そしてエージェントのビデオ生成する間に発生します。
これは Python で書くことは比較的簡単である一方、TensorFlow で書いてデバッグすることは遥かにより複雑です、何故ならばそれは tf.while loops, tf.cond と tf.control_dependencies を伴うからです。そのためこの実行ループの概念をドライバーと呼ばれるクラスに抽象して Python と TensorFlow の両者で良くテストされた実装を提供します。
追加で、各ステップでドライバーに遭遇したデータは Trajectory と呼ばれる名前付きタプルにセーブされて再生バッファとメトリクスのような観測者のセットにブロードキャストされます。このデータは環境からの観測、ポリシーにより勧められるアクション、得られた報酬、現在と次のステップのタイプ、等。
セットアップ
tf-agents か gym をまだインストールしていないのであれば、以下を実行します :
!pip install --upgrade tensorflow-probability !pip install tf-agents !pip install gym
from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow as tf from tf_agents.environments import suite_gym from tf_agents.environments import tf_py_environment from tf_agents.policies import random_py_policy from tf_agents.policies import random_tf_policy from tf_agents.metrics import py_metrics from tf_agents.metrics import tf_metrics from tf_agents.drivers import py_driver from tf_agents.drivers import dynamic_episode_driver tf.compat.v1.enable_v2_behavior()
Python ドライバー
PyDriver クラスは各ステップで更新するために python 環境、python ポリシーそして観測者のリストを取ります。主要なメソッドは run() で、これはポリシーからのアクションを使用して少なくとも以下の停止基準の一つに遭遇するまで環境に踏み入ります : ステップ数が max_steps に到達するかエピソード数が max_episodes に達する。
実装はおおよそ次のようなものです :
class PyDriver(object): def __init__(self, env, policy, observers, max_steps=1, max_episodes=1): self._env = env self._policy = policy self._observers = observers or [] self._max_steps = max_steps or np.inf self._max_episodes = max_episodes or np.inf def run(self, time_step, policy_state=()): num_steps = 0 num_episodes = 0 while num_steps < self._max_steps and num_episodes < self._max_episodes: # Compute an action using the policy for the given time_step action_step = self._policy.action(time_step, policy_state) # Apply the action to the environment and get the next step next_time_step = self._env.step(action_step.action) # Package information into a trajectory traj = trajectory.Trajectory( time_step.step_type, time_step.observation, action_step.action, action_step.info, next_time_step.step_type, next_time_step.reward, next_time_step.discount) for observer in self._observers: observer(traj) # Update statistics to check termination num_episodes += np.sum(traj.is_last()) num_steps += np.sum(~traj.is_boundary()) time_step = next_time_step policy_state = action_step.state return time_step, policy_state
今は、カートポール環境でランダムポリシーを実行するサンプルを通して実行しましょう、結果を再生バッファにセーブして幾つかのメトリクスを計算します。
env = suite_gym.load('CartPole-v0') policy = random_py_policy.RandomPyPolicy(time_step_spec=env.time_step_spec(), action_spec=env.action_spec()) replay_buffer = [] metric = py_metrics.AverageReturnMetric() observers = [replay_buffer.append, metric] driver = py_driver.PyDriver( env, policy, observers, max_steps=20, max_episodes=1) initial_time_step = env.reset() final_time_step, _ = driver.run(initial_time_step) print('Replay Buffer:') for traj in replay_buffer: print(traj) print('Average Return: ', metric.result())
TensorFlow ドライバー
私達はまた TensorFlow のドライバーも持ちます、これは機能的には Python ドライバーに類似していますが、TF 環境, TF ポリシー, TF 観測者等を利用します。現在は 2 つの TensorFlow ドライバーを持ちます : DynamicStepDriver, これは (正当な) 環境ステップの与えられた数の後停止します、そして DynamicEpisodeDriver, これはエピソードの与えられた数の後停止します。アクションの DynamicEpisode のサンプルを見ましょう。
env = suite_gym.load('CartPole-v0') tf_env = tf_py_environment.TFPyEnvironment(env) tf_policy = random_tf_policy.RandomTFPolicy(action_spec=tf_env.action_spec(), time_step_spec=tf_env.time_step_spec()) num_episodes = tf_metrics.NumberOfEpisodes() env_steps = tf_metrics.EnvironmentSteps() observers = [num_episodes, env_steps] driver = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, tf_policy, observers, num_episodes=2) # Initial driver.run will reset the environment and initialize the policy. final_time_step, policy_state = driver.run() print('final_time_step', final_time_step) print('Number of Steps: ', env_steps.result().numpy()) print('Number of Episodes: ', num_episodes.result().numpy())
# Continue running from previous state final_time_step, _ = driver.run(final_time_step, policy_state) print('final_time_step', final_time_step) print('Number of Steps: ', env_steps.result().numpy()) print('Number of Episodes: ', num_episodes.result().numpy())
以上