TF-Agents 0.4 Tutorials : Checkpointer と PolicySaver (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 04/23/2020 (0.4)
* 本ページは、TF Agents の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
Checkpointer と PolicySaver
イントロダクション
tf_agents.utils.common.Checkpointer は訓練状態、ポリシー状態と reply_buffer 状態をローカルストレージに/からセーブ/ロードするユティリティです。
tf_agents.policies.policy_saver.PolicySaver はポリシーだけをセーブ/ロードするツールで Checkpointer よりも軽いです。ポリシーを作成したコードのどのような知識なしでもモデルを配備するために PolicySaver を利用できます。
このチュートリアルでは、モデルを訓練するために DQN を使用てから、状態とモデルを対話的にどのようにストアしてロードできるかを示すために Checkpointer と PolicySaver を使用します。PolicySaver のために TF 2.0 の新しい saved_model ツールと形式を使用することに注意してください。
セットアップ
以下の依存性をインストールしていないのであれば、以下を実行します :
#@test {"skip": true}
!sudo apt-get install -y xvfb ffmpeg
!pip install 'gym==0.10.11'
!pip install 'imageio==2.4.0'
!pip install 'pyglet==1.3.2'
!pip install 'xvfbwrapper==0.2.9'
!pip install --upgrade tensorflow-probability
!pip install tf-agents
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import imageio
import io
import matplotlib
import matplotlib.pyplot as plt
import os
import shutil
import tempfile
import tensorflow as tf
import zipfile
import IPython
try:
from google.colab import files
except ImportError:
files = None
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import policy_saver
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
tf.compat.v1.enable_v2_behavior()
tempdir = os.getenv("TEST_TMPDIR", tempfile.gettempdir())
#@test {"skip": true}
# Set up a virtual display for rendering OpenAI gym environments.
import xvfbwrapper
xvfbwrapper.Xvfb(1400, 900, 24).start()
DQN エージェント
ちょうど前の colab のように、DQN エージェントをセットアップしていきます。詳細はデフォルトでは隠されます、何故ならばこの colab の中核パートではないからです、しかし詳細を見るために ‘SHOW CODE’ 上をクリックできます。
ハイパーパラメータ
env_name = "CartPole-v1" collect_steps_per_iteration = 100 replay_buffer_capacity = 100000 fc_layer_params = (100,) batch_size = 64 learning_rate = 1e-3 log_interval = 5 num_eval_episodes = 10 eval_interval = 1000
環境
train_py_env = suite_gym.load(env_name) eval_py_env = suite_gym.load(env_name) train_env = tf_py_environment.TFPyEnvironment(train_py_env) eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
エージェント
#@title
q_net = q_network.QNetwork(
train_env.observation_spec(),
train_env.action_spec(),
fc_layer_params=fc_layer_params)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
global_step = tf.compat.v1.train.get_or_create_global_step()
agent = dqn_agent.DqnAgent(
train_env.time_step_spec(),
train_env.action_spec(),
q_network=q_net,
optimizer=optimizer,
td_errors_loss_fn=common.element_wise_squared_loss,
train_step_counter=global_step)
agent.initialize()
データ収集
#@title
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
data_spec=agent.collect_data_spec,
batch_size=train_env.batch_size,
max_length=replay_buffer_capacity)
collect_driver = dynamic_step_driver.DynamicStepDriver(
train_env,
agent.collect_policy,
observers=[replay_buffer.add_batch],
num_steps=collect_steps_per_iteration)
# Initial data collection
collect_driver.run()
# Dataset generates trajectories with shape [BxTx...] where
# T = n_step_update + 1.
dataset = replay_buffer.as_dataset(
num_parallel_calls=3, sample_batch_size=batch_size,
num_steps=2).prefetch(3)
iterator = iter(dataset)
エージェントを訓練する
#@title
# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)
def train_one_iteration():
# Collect a few steps using collect_policy and save to the replay buffer.
for _ in range(collect_steps_per_iteration):
collect_driver.run()
# Sample a batch of data from the buffer and update the agent's network.
experience, unused_info = next(iterator)
train_loss = agent.train(experience)
iteration = agent.train_step_counter.numpy()
print ('iteration: {0} loss: {1}'.format(iteration, train_loss.loss))
ビデオ生成
#@title def embed_gif(gif_buffer): """Embeds a gif file in the notebook.""" tag = ''.format(base64.b64encode(gif_buffer).decode()) return IPython.display.HTML(tag) def run_episodes_and_create_video(policy, eval_tf_env, eval_py_env): num_episodes = 3 frames = [] for _ in range(num_episodes): time_step = eval_tf_env.reset() frames.append(eval_py_env.render()) while not time_step.is_last(): action_step = policy.action(time_step) time_step = eval_tf_env.step(action_step.action) frames.append(eval_py_env.render()) gif_file = io.BytesIO() imageio.mimsave(gif_file, frames, format='gif', fps=60) IPython.display.display(embed_gif(gif_file.getvalue()))
ビデオを生成する
ビデオを生成することによりポリシーのパフォーマンスを確認する。
print ('global_step:')
print (global_step)
run_episodes_and_create_video(agent.policy, eval_env, eval_py_env)
Checkpointer と PolicySaver をセットアップする
今は Checkpointer と PolicySaver を使用する準備ができました。
Checkpointer
checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
ckpt_dir=checkpoint_dir,
max_to_keep=1,
agent=agent,
policy=agent.policy,
replay_buffer=replay_buffer,
global_step=global_step
)
Policy Saver
policy_dir = os.path.join(tempdir, 'policy') tf_policy_saver = policy_saver.PolicySaver(agent.policy)
1 反復訓練する
#@test {"skip": true}
print('Training one iteration....')
train_one_iteration()
チェックポイントにセーブする
train_checkpointer.save(global_step)
チェックポイントをリストアする
これが動作するするためには、オブジェクトのセット全体がチェックポイントが作成されたときと同じ方法で再作成されるべきです。
train_checkpointer.initialize_or_restore() global_step = tf.compat.v1.train.get_global_step()
ポリシーもセーブして位置にエクスポートします。
tf_policy_saver.save(policy_dir)
ポリシーは、それを作成するために何のエージェントやネットワークが使用されたかのどのような知識も持つことなくロードできます。これはポリシーの配備を遥かに容易にします。
セーブされたポリシーをロードしてそれがどのように遂行するか確認します。
saved_policy = tf.compat.v2.saved_model.load(policy_dir) run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)
Export と import
colab の残りは、後の時点で訓練をで続けてそして再度訓練しなければならないことなくモデルを配備できるように checkpointer とポリシー・ディレクトリを export / import するに役立つでしょう。
今は「1 反復を訓練する」に戻って後で差異を理解できるように更に 2, 3 回訓練することができます。僅かばかりより良い結果をひとたび見始めるのであれば、以下を続けてください。
#@title Create zip file and upload zip file (double-click to see the code)
def create_zip_file(dirname, base_filename):
return shutil.make_archive(base_filename, 'zip', dirname)
def upload_and_unzip_file_to(dirname):
if files is None:
return
uploaded = files.upload()
for fn in uploaded.keys():
print('User uploaded file "{name}" with length {length} bytes'.format(
name=fn, length=len(uploaded[fn])))
shutil.rmtree(dirname)
zip_files = zipfile.ZipFile(io.BytesIO(uploaded[fn]), 'r')
zip_files.extractall(dirname)
zip_files.close()
チェックポイント・ディレクトリから zip されたファイルを作成します。
train_checkpointer.save(global_step) checkpoint_zip_filename = create_zip_file(checkpoint_dir, os.path.join(tempdir, 'exported_cp'))
zip ファイルをダウンロードします。
#@test {"skip": true}
if files is not None:
files.download(checkpoint_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469
ある程度の時間 (10-15 回) の間訓練した後、チェックポイント zip ファイルをダウンロードして、そして訓練をリセットするために “Runtime > Restart and run all” に行き、そしてこのセルに戻ります。今はダウンロードされた zip ファイルをアップロードして訓練を続けることができます。
#@test {"skip": true}
upload_and_unzip_file_to(checkpoint_dir)
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()
ひとたびチェックポイント・ディレクトリをアップロードしたら、訓練を続けるために「1 反復を訓練する」に戻るかロードされたポリシーのパフォーマンスを確認するために「ビデオを生成する」に戻ります。
代わりに、ポリシー (モデル) をセーブしてそれをリストアできます。checkpointer とは違い、訓練を続けることはできませんが、依然としてモデルを配備できます。ダウンロードされたファイルは checkpointer のそれよりも遥かに小さいことに注意してください。
tf_policy_saver.save(policy_dir) policy_zip_filename = create_zip_file(policy_dir, os.path.join(tempdir, 'exported_policy'))
#@test {"skip": true}
if files is not None:
files.download(policy_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469
ダウンロードされたポリシー・ディレクトリ (exported_policy.zip) をアップロードしてセーブされたポリシーがどのように遂行するかを確認します。
#@test {"skip": true}
upload_and_unzip_file_to(policy_dir)
saved_policy = tf.compat.v2.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)
SavedModelPyTFEagerPolicy
TF ポリシーを使用することを望まない場合、py_tf_eager_policy.SavedModelPyTFEagerPolicy の利用を通して Python env で saved_model を直接使用することもできます。
これは eager mode が有効であるときに限り動作するだけであることに注意してください。
eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())
# Note that we're passing eval_py_env not eval_env.
run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)
以上