ホーム » TensorFlow 2.0 » TensorFlow 2.4 : ガイド : モデルのセーブ :- 訓練チェックポイント

TensorFlow 2.4 : ガイド : モデルのセーブ :- 訓練チェックポイント

TensorFlow 2.4 : ガイド : モデルのセーブ :- 訓練チェックポイント (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 01/26/2021

* 本ページは、TensorFlow org サイトの Guide – Save a model の以下のページを翻訳した上で
適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料セミナー実施中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
クラスキャットは人工知能・テレワークに関する各種サービスを提供しております :

人工知能研究開発支援 人工知能研修サービス テレワーク & オンライン授業を支援
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/
Facebook: https://www.facebook.com/ClassCatJP/

 

 

ガイド : モデルのセーブ :- 訓練チェックポイント

フレーズ「TensorFlow モデルをセーブする」は典型的には 2 つのことの一つを意味します :

  • チェックポイント, OR
  • SavedModel。

チェックポイントはモデルにより使用される総てのパラメータ (tf.Variable オブジェクト) の正確な値を捕捉します。チェックポイントはモデルにより定義された計算のどのような記述も含みません、そのため典型的には (セーブされたパラメータ値を利用する) ソースコードが利用可能であるときにだけ有用です。

他方、SavedModel 形式はパラメータ値 (チェックポイント) に加えてモデルにより定義された計算のシリアライズされた記述を含みます。この形式のモデルはモデルを作成したソースコードから独立です。そしてそれらは TensorFlow Serving, TensorFlow Lite, TensorFlow.js あるいは他のプログラミング言語 (C, C++, Java, Go, Rust, C# 等。TensorFlow API) のプログラムを通して配備のために適します。

このガイドはチェックポイントを書いて読むための API をカバーします。

 

セットアップ

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

 

tf.keras 訓練 API からセーブする

セーブとリストアについては tf.keras ガイドを見てください。

tf.keras.Model.save_weights は TensorFlow チェックポイントをセーブします。

net.save_weights('easy_checkpoint')

 

チェックポイントを書く

TensorFlow モデルの永続的な状態は tf.Variable オブジェクトにストアされます。これらは直接構築できますが、しばしば tf.keras.layers or tf.keras.Model のような高位 API を通して作成されます。

変数を管理する最も容易な方法はそれらを Python オブジェクトに装着してから、それらのオブジェクトを参照することです。

tf.train.Checkpoint, tf.keras.layers.Layer, and tf.keras.Model のサブクラスはそれらの属性に割当てられた変数を自動的に追跡します。以下のサンプルは単純な線形モデルを構築してから、チェックポイントを書きます、これはモデルの変数の総てのための値を含みます。

Model.save_weights でモデル-チェックポイントを容易にセーブできます。

 

手動チェックポインティング

セットアップ

tf.train.Checkpoint の総ての特徴を実演する手助けをするために、toy データセットと最適化ステップを定義します :

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

 

チェックポイント・オブジェクトを作成する

チェックポイントを手動で作成するために tf.train.Checkpoint オブジェクトを使用します、そこではチェックポイントすることを望むオブジェクトはオブジェクト上の属性として設定されます。

tf.train.CheckpointManager はまた複数のチェックポイントを管理するために役立つことができます。

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

 

モデルを訓練してチェックポイントする

次の訓練ループはモデルと optimizer のインスタンスを作成してから、それらを tf.train.Checkpoint オブジェクトに集めます。

それはデータの各バッチ上でループの訓練ステップを呼び出し、そして定期的にチェックポイントをディスクに書きます。

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 26.85
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 20.27
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 13.72
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 7.30
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 1.72

 

リストアして訓練を続ける

最初の訓練サイクルの後、新しいモデルとマネージャを渡して、しかし正確に貴方がやめたところで訓練を選択できます :

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 0.91
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.90
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.57
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.49
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.39

tf.train.CheckpointManager オブジェクトは古いチェックポイントを削除します。上ではそれは 3 つの最も最近のチェックポイントだけを保持するように configure されています。

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

これらのパス, e.g. ‘./tf_ckpts/ckpt-10’, はディスク上のファイルではありません。代わりにそれらはインデックスファイルと、変数値を含む一つまたはそれ以上のデータファイルのためのプレフィクスです。これらのプレフィクスは単一のチェックポイント・ファイル (‘./tf_ckpts/checkpoint’) 内で一緒にグループ分けされます、そこで CheckpointManager はその状態をセーブします。

ls ./tf_ckpts
checkpoint                   ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index                ckpt-9.data-00000-of-00001

 

ローディング機構

TensorFlow はロードされているオブジェクトから始めて、名前付けられたエッジを持つ有向グラフを辿ることにより変数をチェックポイントされた値に合わせます。エッジ名は典型的にはオブジェクトの属性名に由来します、例えば self.l1 = tf.keras.layers.Dense(5) 内の “l1” です。tf.train.Checkpoint は tf.train.Checkpoint(step=…) 内の “step” のように、そのキーワード引数名を使用します。

上のサンプルからの依存性グラフはこのようなものです :


optimizer は赤色、通常の変数は青色、そして optimizer スロット変数はオレンジ色にあります。他のノード — 例えば、tf.train.Checkpoint を表す — は黒色です。

スロット変数は optimizer の状態の一部ですが、特定の変数のために作成されます。例えば上の ‘m’ エッジはモメンタムに対応します、これは各変数のために Adam optimizer が追跡します。変数と optimizer の両者がセーブされる場合にスロット変数はチェックポイントにセーブされるだけですので、破線のエッジです。

tf.train.Checkpoint オブジェクト上で restore を呼び出すと要求された復元 (= restorations) をキューに入れて、チェックポイント・オブジェクトから一致するパスがあれば変数値をリストアします。例えば、ネットワークと層を通してそれへの一つのパスを再構築することにより上で定義したモデルから単にバイアスをロードできます。

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[-0.00426486  0.98928887  1.9369034   2.996623    3.9505417 ]

これらの新しいオブジェクトのための依存性グラフは貴方が上で書いた大きいチェックポイントの遥かに小さい部分グラフです。それはバイアスとチェックポイントに番号付けるために tf.train.Checkpoint が使用する save カウンターだけを含みます。

restore はオプションの assertion を持つ、status オブジェクトを返します。新しいチェックポイントで作成されたオブジェクトの総てがリストアされて、従って status.assert_existing_objects_matched はパスします。

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f525dea5b38>

チェックポイントには層のカーネルと optimizer の変数を含む、一致しない多くのオブジェクトがありますstatus.assert_consumed はチェックポイントとプログラムが正確に一致する場合にだけパスし、そしてここでは例外を上げます。

 

遅延復元 (= Delayed restorations)

TensorFlow の Layer オブジェクトは変数の作成をそれらの最初の呼び出し (入力 shape が利用可能なとき) まで遅延させるかもしれません。例えば、Dense 層のカーネルの shape は層の入力と出力 shape の両者に依拠し、そのためコンストラクタ引数として必要な出力 shape はそれ自身の上の変数を作成するために十分な情報ではありません。Layer の呼び出しはまた変数値を読みますので、restore は変数の作成とその最初の使用の間に発生しなければなりません。

この作法 (= idiom) をサポートするために、tf.train.Checkpoint は一致する変数をまだ持たない restore をキューイングします。

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.7385683 4.7148175 4.7513504 4.7783995 5.0485835]]

 

チェックポイントを手動で調べる

tf.train.load_checkpoint はチェックポイントの内容への低位アクセスを与える CheckpointReader を返します。それは各変数のキーからチェックポイントの各変数のための shape と dtype へのマッピングを含みます。変数のキーは上で表示されたグラフ内のような、そのオブジェクトパスです。

Note: チェックポイントへの高位構造はありません。それは変数のためのパスと値を知るだけで、モデル あるいはそれらがどのように接続されているかの概念は持ちません。

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

そして net.l1.kernel の値に関心があれば次のコードで値を得ることができます :

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

それはまた get_tensor メソッドも提供し、変数の値を調査することを可能にします :

reader.get_tensor(key)
array([[4.7385683, 4.7148175, 4.7513504, 4.7783995, 5.0485835]],
      dtype=float32)

 

リストと辞書追跡

self.l1 = tf.keras.layers.Dense(5) のような直接的な属性割当てと同様に、リストと辞書を属性に割り当てるとそれらの内容を追跡します。

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

リストと辞書のためのラッパー・オブジェクトに気付くかもしれません。これらのラッパーは基礎的なデータ構造のためのチェックポイント可能なバージョンです。ちょうど属性ベースのローディングのように、これらのラッパーはそれらがコンテナに追加されるとすぐに変数の値をリストアします。

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

同じ追跡が tf.keras.Model のサブクラスに自動的に適用され、そして層のリストを追跡するサンプルのために使用されるかもしれません。

 

まとめ

TensorFlow オブジェクトはそれらが使用する変数の値をセーブしてリストアするための容易な自動機構を提供しています。

 

以上



AI導入支援 #2 ウェビナー

スモールスタートを可能としたAI導入支援   Vol.2
[無料 WEB セミナー] [詳細]
「画像認識 AI PoC スターターパック」の紹介
既に AI 技術を実ビジネスで活用し、成果を上げている日本企業も多く存在しており、競争優位なビジネスを展開しております。
しかしながら AI を導入したくとも PoC (概念実証) だけでも高額な費用がかかり取組めていない企業も少なくないようです。A I導入時には欠かせない PoC を手軽にしかも短期間で認知度を確認可能とするサービの紹介と共に、AI 技術の特性と具体的な導入プロセスに加え運用時のポイントについても解説いたします。
日時:2021年10月13日(水)
会場:WEBセミナー
共催:クラスキャット、日本FLOW(株)
後援:働き方改革推進コンソーシアム
参加費: 無料 (事前登録制)
人工知能開発支援
◆ クラスキャットは 人工知能研究開発支援 サービスを提供しています :
  • テクニカルコンサルティングサービス
  • 実証実験 (プロトタイプ構築)
  • アプリケーションへの実装
  • 人工知能研修サービス
◆ お問合せ先 ◆
(株)クラスキャット
セールス・インフォメーション
E-Mail:sales-info@classcat.com