ホーム » TensorFlow 2.0 » TensorFlow 2.0 Alpha : ガイド : 訓練チェックポイント

TensorFlow 2.0 Alpha : ガイド : 訓練チェックポイント

TensorFlow 2.0 Alpha : ガイド : 訓練チェックポイント (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 03/24/2019

* 本ページは、TensorFlow の本家サイトの TF 2.0 Alpha の以下のページを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

ガイド : 訓練チェックポイント

慣用句「TensorFlow モデルをセーブする」は典型的には 2 つのことの 1 つを意味しています : (1) チェックポイント, OR (2) SavedModel です。

チェックポイントはモデルにより使用される総てのパラメータの正確な値を捕捉します。チェックポイントはモデルで定義された計算のどのような記述も含みません、そのため典型的にはセーブされたパラメータ値を使用するソースコードが利用可能であるときに限り有用です。

他方、SavedModel フォーマットはパラメータ値 (チェックポイント) に加えてモデルで定義された計算のシリアライズされた記述を含みます。このフォーマットのモデルはこのモデルを作成したソースコードからは独立です。そのためそれらは TensorFlow Serving や他のプログラミング言語 (C, C++, Java, Go, Rust, C# etc. TensorFlow API) のプログラムを通した配備に適しています。

このガイドはチェックポイントを書いて読むための API をカバーします。

 

tf.keras 訓練 API からセーブする

tf.keras guide on saving and restoring を見てください。

tf.keras.Model.save_weights はオプションで TensorFlow チェックポイント・フォーマットでセーブします。このガイドはフォーマットを詳細に説明してカスタム訓練ループでチェックポイントを管理するための API を紹介します。

 

チェックポイントを手動で書く

TensorFlow モデルの永続的な状態は tf.Variable オブジェクトにストアされます。これらは直接に構築できますが、しばしば tf.keras.layers のような高位 API を通して作成されます。

変数を管理する最も容易な方法はそれらを Python オブジェクトにアタッチして、それらのオブジェクトを参照することです。tf.train.Checkpoint のサブクラス、tf.keras.layers.Layer と tf.keras.Model はそれらの属性に割り当てられた変数を自動的に追跡します。次のサンプルは単純な線形モデルを構築してから、モデルの総ての変数のための値を含むチェックポイントを書きます。

from __future__ import absolute_import, division, print_function
!pip install -q tensorflow==2.0.0-alpha0
import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)

このガイドの焦点ではありませんが、実行可能であるためにはサンプルはデータと最適化ステップが必要です。モデルは in-メモリのデータセットのスライス上で訓練されます。

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat(10).batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

次の訓練ループはモデルと optimizer のインスタンスを作成して、それらを tf.train.Checkpoint オブジェクトに集めます。それはデータの各バッチ上のループで訓練ステップを呼び出し、定期的にチェックポイントをディスクに書きます。

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
  print("Restored from {}".format(manager.latest_checkpoint))
else:
  print("Initializing from scratch.")

for example in toy_dataset():
  loss = train_step(net, example, opt)
  ckpt.step.assign_add(1)
  if int(ckpt.step) % 10 == 0:
    save_path = manager.save()
    print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
    print("loss {:1.2f}".format(loss.numpy()))
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 25.84
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 19.26
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 12.70
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 6.26
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 2.32

上述のスニペットはそれが最初に実行されるときモデル変数をランダムに初期化します。最初の実行後それは訓練をやめたところから訓練を再開します :

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
  print("Restored from {}".format(manager.latest_checkpoint))
else:
  print("Initializing from scratch.")

for example in toy_dataset():
  loss = train_step(net, example, opt)
  ckpt.step.assign_add(1)
  if int(ckpt.step) % 10 == 0:
    save_path = manager.save()
    print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
    print("loss {:1.2f}".format(loss.numpy()))
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.39
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 1.51
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.72
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.81
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.46

tf.train.CheckpointManager オブジェクトは古いチェックポイントを削除します。上で 3 つの最近のチェックポイントだけを保持するように configure されています。

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

これらのパス, e.g. ‘./tf_ckpts/ckpt-10’ はディスク上のファイルではありません。代わりにそれらはインデックスファイルのための prefix で変数値を含む一つまたはそれ以上のデータファイルがあります。これらの prerfix はまとめて単一のチェックポイントファイル (‘./tf_ckpts/checkpoint’) にグループ化されます、そこでは CheckpointManager がその状態をセーブします。

!ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

 

ローディング機構

TensorFlow は変数を、ロードされたオブジェクトから始めて、名前付けられたエッジを持つ有向グラフを辿ることによりチェックポイントされた値に合わせます。エッジ名は典型的にはオブジェクトの属性名に由来します、例えば self.l1 = tf.keras.layers.Dense(5) の “l1” です。tf.train.Checkpoint はそのキーワード引数名を例えば tf.train.Checkpoint(step=…) の “step” のように使用します。

上のサンプルからの依存グラフはこのようなものです :


サンプル訓練ループのための依存グラフの可視化

optimizer は赤色、通常変数は青色そして optimizer スロット変数はオレンジ色です。他のノード、例えば tf.train.Checkpoint を表わすのは黒色です。

スロット変数は optimizer の状態の一部ですが、特定の変数のために作成されます。例えば上の ‘m’ エッジは momentum に対応します、これは Adam optimizer が各変数のために追跡します。スロット変数は変数と optimizer の両者がセーブされる場合に限りチェックポイントにセーブされますので、破線のエッジです。

tf.train.Checkpoint オブジェクト上の restore() 呼び出しは要求された復元 (= restoration) をキューに入れて、Checkpoint オブジェクトから一致するパスがあればすぐに変数値を復元します。例えば上で定義したモデルから単にカーネルをネットワークと層を通してそれへのパスを再構築することによりロードできます。

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # We get the restored value now
[0. 0. 0. 0. 0.]
[3.406768  2.0089386 2.3660657 2.9850907 4.004283 ]

これらの新しいオブジェクトのための依存グラフは上で書いたより大きなチェックポイントの遥かに小さいサブグラフです。それはバイアスと tf.train.Checkpoint がチェックポイントに番号付けするために使用した save_counter だけを含みます。


バイアス変数のためのサブグラフの可視化

restore() は状態オブジェクトを返します、これはオプションのアサーションを持ちます。新しい Checkpoint で作成した総てのオブジェクトは復元されますので、status.assert_existing_objects_matched() が通ります。

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fce04957c18>

チェックポイントには、層のカーネルと optimizer の変数を含む、マッチしていない多くのオブジェクトがあります。status.assert_consumed() はチェックポイントとプログラムが正確にマッチする場合に限りパスし、そしてここでは例外を投げます。

 

遅延復元 (= Delayed restorations)

TensorFlow の Layer オブジェクトは変数の作成を最初の呼び出しまで遅らせるかもしれません、入力 shape が利用可能であるときです。例えば Dense 層のカーネルの shape は層の入力と出力 shape の両者に依拠しますので、コンストラクタ引数として必要な出力 shape は単独で変数を作成するために十分な情報ではありません。層の呼び出しはまた変数の値を読みますので、restore は変数の作成とその最初の使用の間で起きる必要があります。

この用法をサポートするために、tf.train.Checkpoint はマッチした変数をまだ持たない restore をキューに入れます。

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.4939117 4.8438153 4.976974  4.979875  4.999077 ]]

 

チェックポイントを手動で調べる

tf.train.list_variables はチェックポイント・キーとチェックポイントの変数の shape をリストします。チェックポイント・キーは上で示したグラフのパスです。

tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts/'))
[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('net/l1/.ATTRIBUTES/OBJECT_CONFIG_JSON', []),
 ('net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('optimizer/.ATTRIBUTES/OBJECT_CONFIG_JSON', []),
 ('optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/epsilon/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('step/.ATTRIBUTES/VARIABLE_VALUE', [])]

 

リストと辞書追跡

self.l1 = tf.keras.layers.Dense(5) のような直接の属性割り当てと同様に、リストと辞書の属性への割り当てはそれらの内容を追跡します。

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

貴方はリストと辞書のためのラッパー・オブジェクトに気がつくかもしれません。これらのラッパーは基礎的なデータ構造のチェックポイント可能なバージョンです。ちょうどローディングに基づく属性のように、これらのラッパーは変数の値をそれがコンテナに追加されるとすぐに復元します。

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

同じ追跡が tf.keras.Model のサブクラスに自動的に適用されて例えば層のリストを追跡するために使用されるかもしれません。

 

Estimator でオブジェクトベースのチェックポイントをセーブする

Estimator のガイド を見てください。

Estimator はデフォルトでチェックポイントを前のセクションで説明されたオブジェクトグラフではなく変数名でセーブします。tf.train.Checkpoint は名前ベースのチェックポイントを受け取りますが、モデルの一部を Estimator の model_fn の外側に移すときには変数名は変わるかもしれません。オブジェクトベースのチェックポイントのセーブは Estimator の内側でのモデルを訓練して外側でそれを使用することを容易にします。

import tensorflow.compat.v1 as tf_compat
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
WARNING: Logging before flag parsing goes to stderr.
W0307 18:07:17.775649 140524399343360 deprecation.py:323] From /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/training_util.py:238: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7fcddc691be0>

それから tf.train.Checkpoint は Estimator のチェックポイントをその model_dir からロードできます。

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

 

要約

TensorFlow オブジェクトはそれらが使用する変数の値をセーブして復元するための容易で自動的なメカニズムを提供します。

 

以上



AI導入支援 #2 ウェビナー

スモールスタートを可能としたAI導入支援   Vol.2
[無料 WEB セミナー] [詳細]
「画像認識 AI PoC スターターパック」の紹介
既に AI 技術を実ビジネスで活用し、成果を上げている日本企業も多く存在しており、競争優位なビジネスを展開しております。
しかしながら AI を導入したくとも PoC (概念実証) だけでも高額な費用がかかり取組めていない企業も少なくないようです。A I導入時には欠かせない PoC を手軽にしかも短期間で認知度を確認可能とするサービの紹介と共に、AI 技術の特性と具体的な導入プロセスに加え運用時のポイントについても解説いたします。
日時:2021年10月13日(水)
会場:WEBセミナー
共催:クラスキャット、日本FLOW(株)
後援:働き方改革推進コンソーシアム
参加費: 無料 (事前登録制)
人工知能開発支援
◆ クラスキャットは 人工知能研究開発支援 サービスを提供しています :
  • テクニカルコンサルティングサービス
  • 実証実験 (プロトタイプ構築)
  • アプリケーションへの実装
  • 人工知能研修サービス
◆ お問合せ先 ◆
(株)クラスキャット
セールス・インフォメーション
E-Mail:sales-info@classcat.com