ホーム » TensorFlow 2.0 » TensorFlow 2.0 : 上級 Tutorials : 分散訓練 :- 分散ストラテジーを使用してモデルをセーブとロードする

TensorFlow 2.0 : 上級 Tutorials : 分散訓練 :- 分散ストラテジーを使用してモデルをセーブとロードする

TensorFlow 2.0 : 上級 Tutorials : 分散訓練 :- 分散ストラテジーを使用してモデルをセーブとロードする (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 11/02/2019

* 本ページは、TensorFlow org サイトの TF 2.0 – Advanced Tutorials – Distributed training の以下のページを翻訳した上で
適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料セミナー開催中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/
Facebook: https://www.facebook.com/ClassCatJP/

 

分散訓練 :- 分散ストラテジーを使用してモデルをセーブとロードする

概要

訓練の間にモデルをセーブしてロードすることは一般的です。keras モデルをセーブしてロードするために API の 2 つのセットがあります: 高位 API、そして低位 API です。このチュートリアルは tf.distribute.Strategy を使用するとき SavedModel API をどのように使用できるかを実演します。一般に SavedModel とシリアライゼーションについて学習するためには、saved model ガイド、そして Keras モデル・シリアライゼーション・ガイド を読んでください。単純なサンプルで始めましょう。

依存性をインポートします :

from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow_datasets as tfds

import tensorflow as tf
tfds.disable_progress_bar()

tf.distribute.Strategy を使用してデータとモデルを準備します :

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])
    return model

モデルを訓練します :

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Downloading and preparing dataset mnist (11.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/1.0.0...

/home/kbuilder/.local/lib/python3.5/site-packages/urllib3/connectionpool.py:1004: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings
  InsecureRequestWarning,
/home/kbuilder/.local/lib/python3.5/site-packages/urllib3/connectionpool.py:1004: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings
  InsecureRequestWarning,
/home/kbuilder/.local/lib/python3.5/site-packages/urllib3/connectionpool.py:1004: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings
  InsecureRequestWarning,
/home/kbuilder/.local/lib/python3.5/site-packages/urllib3/connectionpool.py:1004: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings
  InsecureRequestWarning,

WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.5/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`

WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.5/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`

Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/1.0.0. Subsequent calls will reuse this data.
Epoch 1/2
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

938/938 [==============================] - 14s 14ms/step - loss: 0.1997 - accuracy: 0.9430
Epoch 2/2
938/938 [==============================] - 2s 3ms/step - loss: 0.0674 - accuracy: 0.9801

<tensorflow.python.keras.callbacks.History at 0x7f1450dc75c0>

 

モデルをセーブしてロードする

作業するための単純なモデルを持った今、セーブ/ロード API を見てみまよう。利用可能な API の 2 つのセットがあります :

 

Keras API

ここに Keras API でモデルをセーブしてロードするサンプルがあります :

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)  # save() should be called out of strategy scope
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1781: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1781: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Assets written to: /tmp/keras_save/assets

INFO:tensorflow:Assets written to: /tmp/keras_save/assets

tf.distribute.Strategy なしでモデルを復旧します :

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 10s 11ms/step - loss: 0.0491 - accuracy: 0.9851
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0339 - accuracy: 0.9900

<tensorflow.python.keras.callbacks.History at 0x7f14502d00f0>

モデルを復旧した後、その上で訓練を継続できます、compile() を再度呼び出す必要さえなく、何故ならばそれはセーブ前に既にコンパイルされているからです。モデルは TensorFlow の標準 SavedModel proto 形式でセーブされます。より多くの情報については、guide to saved_model format を参照してください。

tf.distribute.strategy のスコープの外から model.save() メソッドを呼び出すことだけは重要です。スコープ内でそれを呼び出すことはサポートされません。

今はモデルをロードしてそれを訓練するために tf.distribute.Strategy を使用します :

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 15s 16ms/step - loss: 0.0486 - accuracy: 0.9852
Epoch 2/2
938/938 [==============================] - 11s 12ms/step - loss: 0.0345 - accuracy: 0.9897

見れるように、ロードは tf.distribute.Strategy とともに期待されたように動作します。 ここで使用されるストラテジーはセービングの前に使用されたのと同じストラテジーでなくてもかまいません。

 

tf.saved_model API

今は低位 API を見てみましょう。モデルのセービングは keras API に類似しています :

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Assets written to: /tmp/tf_save/assets

INFO:tensorflow:Assets written to: /tmp/tf_save/assets

ロードは tf.saved_model.load() で行なうことができます。けれども、それは低位上の API です (そしてそれ故により広範囲のユースケースを持ちます) から、それは Keras モデルを返しません。代わりに、それは推論を行なうために使用できる関数を含むオブジェクトを返します。例えば :

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

ロードされたオブジェクトは複数の関数を含むかもしれません、各々はキーと関連付けられています。”serving_default” はセーブされた Keras モデルによる推論関数のためのデフォルトキーです。この関数で推論を行なうためには :

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: id=163719, shape=(64, 10), dtype=float32, numpy=
array([[0.10728385, 0.10047228, 0.10261484, 0.11404606, 0.10439669,
        0.09272724, 0.09213927, 0.10215139, 0.11169203, 0.07247637],
       [0.10643671, 0.09905838, 0.09364076, 0.11701185, 0.10924707,
        0.09127147, 0.09204151, 0.10397962, 0.10052412, 0.08678856],
       [0.10438395, 0.10305102, 0.09159523, 0.12274157, 0.1077928 ,
        0.09475194, 0.0791743 , 0.11037502, 0.09112864, 0.09500553],
       [0.10716499, 0.09490513, 0.10168986, 0.11599538, 0.10473887,
        0.08412197, 0.09422952, 0.10134438, 0.10640062, 0.08940925],
       [0.10903514, 0.10466785, 0.09625442, 0.09959863, 0.10513363,
        0.08701214, 0.09148837, 0.10629265, 0.10480279, 0.09571434],
       [0.10791405, 0.0994103 , 0.08700535, 0.11203778, 0.10917093,
        0.09215206, 0.09492003, 0.10987404, 0.10913727, 0.07837823],
       [0.10487697, 0.1086623 , 0.09380413, 0.11928336, 0.11513475,
        0.08585147, 0.0855743 , 0.10447548, 0.09436136, 0.08797587],
       [0.11953841, 0.10592035, 0.09351776, 0.10883537, 0.11183759,
        0.08616813, 0.08055402, 0.1016137 , 0.10553744, 0.08647724],
       [0.1001956 , 0.09496056, 0.10120489, 0.11922734, 0.09466279,
        0.08841368, 0.09372198, 0.10862154, 0.10443942, 0.0945522 ],
       [0.10173271, 0.10577305, 0.10043491, 0.10811324, 0.09887545,
        0.10345043, 0.09138881, 0.10311095, 0.10145498, 0.08566544],
       [0.10915802, 0.09801247, 0.09486966, 0.11884604, 0.10890727,
        0.09173197, 0.08757571, 0.1097801 , 0.09345251, 0.08766618],
       [0.12305965, 0.10477802, 0.10122419, 0.10920276, 0.10642693,
        0.0803097 , 0.089819  , 0.10291311, 0.09739247, 0.08487416],
       [0.11381779, 0.10931753, 0.09873011, 0.11637557, 0.11114289,
        0.08807714, 0.07799057, 0.09821263, 0.10721567, 0.07912013],
       [0.11134263, 0.09600631, 0.09837915, 0.12407107, 0.10555959,
        0.09594846, 0.08290008, 0.10883261, 0.10144335, 0.07551681],
       [0.11251914, 0.10554384, 0.09913144, 0.10999293, 0.10561096,
        0.07938422, 0.08732757, 0.10422582, 0.1096524 , 0.08661176],
       [0.1071265 , 0.10107235, 0.10104455, 0.12165232, 0.09664022,
        0.09286968, 0.0891385 , 0.10249172, 0.10644917, 0.08151497],
       [0.10819855, 0.10904767, 0.10180869, 0.11292984, 0.09951544,
        0.08959039, 0.09180361, 0.10376687, 0.0935899 , 0.089749  ],
       [0.10038829, 0.09874914, 0.09890129, 0.11915599, 0.10857335,
        0.09699985, 0.08979851, 0.10805573, 0.09952758, 0.07985032],
       [0.10805429, 0.10442685, 0.099279  , 0.11294735, 0.1079711 ,
        0.08618448, 0.08894631, 0.10854797, 0.11077981, 0.07286283],
       [0.11217733, 0.10181618, 0.09626949, 0.1101595 , 0.09953736,
        0.08308611, 0.1039778 , 0.11402612, 0.09844285, 0.08050729],
       [0.1175424 , 0.10847426, 0.09683861, 0.10149902, 0.10666198,
        0.08579018, 0.09005664, 0.10190959, 0.10221969, 0.08900762],
       [0.11500932, 0.10343985, 0.09347771, 0.12577702, 0.10971518,
        0.08523733, 0.08214136, 0.10528794, 0.09671546, 0.08319884],
       [0.11175612, 0.10435012, 0.0984369 , 0.1191108 , 0.09581577,
        0.08925212, 0.09045488, 0.10521886, 0.09986348, 0.08574096],
       [0.11444057, 0.10850222, 0.09963862, 0.10142051, 0.10838775,
        0.08676049, 0.08990512, 0.10423999, 0.09729694, 0.0894078 ],
       [0.10881206, 0.11097988, 0.09718787, 0.11925952, 0.09623734,
        0.08606921, 0.09177926, 0.09912794, 0.09836774, 0.09217919],
       [0.10792447, 0.11186771, 0.10025145, 0.1140421 , 0.10892291,
        0.08941509, 0.09012905, 0.09400386, 0.10047787, 0.08296548],
       [0.10844032, 0.09806747, 0.0961417 , 0.11372101, 0.09871072,
        0.09228362, 0.09338211, 0.10678563, 0.10351523, 0.08895217],
       [0.10387842, 0.09575935, 0.10545119, 0.10064226, 0.1034956 ,
        0.09042715, 0.09294212, 0.11890968, 0.11319113, 0.07530304],
       [0.1227584 , 0.10466544, 0.0841779 , 0.09771729, 0.11111001,
        0.08448302, 0.0934371 , 0.11048996, 0.10305537, 0.08810547],
       [0.12329397, 0.09496878, 0.103743  , 0.11139977, 0.10066491,
        0.08561404, 0.09533388, 0.11039741, 0.09291557, 0.08166874],
       [0.12045516, 0.10910697, 0.10111618, 0.09754261, 0.10887689,
        0.08471666, 0.09063336, 0.10426611, 0.09986193, 0.08342412],
       [0.11602842, 0.11101867, 0.10563238, 0.10994542, 0.10256457,
        0.0871765 , 0.0918168 , 0.10208733, 0.09566812, 0.0780618 ],
       [0.10475811, 0.10498275, 0.09951681, 0.10977667, 0.10362052,
        0.09762244, 0.0898955 , 0.10377637, 0.09997776, 0.08607301],
       [0.10318095, 0.10635176, 0.10179733, 0.11117157, 0.105158  ,
        0.09139618, 0.09156478, 0.09806094, 0.10203766, 0.08928081],
       [0.09828327, 0.10664824, 0.09477657, 0.12164301, 0.10451195,
        0.0932444 , 0.08256079, 0.10863042, 0.09929656, 0.09040477],
       [0.10853811, 0.11033206, 0.0940258 , 0.1113859 , 0.10713692,
        0.08983784, 0.0831388 , 0.10531119, 0.1047819 , 0.08551157],
       [0.11990581, 0.12092119, 0.1031604 , 0.11084578, 0.10152177,
        0.08251272, 0.08635441, 0.10541882, 0.09153688, 0.07782226],
       [0.10128036, 0.10670947, 0.10337903, 0.11031315, 0.1019455 ,
        0.09146291, 0.09584799, 0.09672082, 0.10148021, 0.09086055],
       [0.10947275, 0.10200524, 0.09993119, 0.10882212, 0.10303847,
        0.08627772, 0.09159538, 0.09889076, 0.10446595, 0.0955004 ],
       [0.10939393, 0.09627976, 0.10602371, 0.1142818 , 0.10102987,
        0.09605702, 0.08749614, 0.10808735, 0.10765684, 0.07369358],
       [0.10504688, 0.10167556, 0.10281529, 0.1192755 , 0.10568851,
        0.08670694, 0.08155579, 0.11723237, 0.09987677, 0.08012642],
       [0.10824965, 0.09176606, 0.09443861, 0.11798903, 0.11201838,
        0.09821461, 0.08609   , 0.11264105, 0.09485514, 0.0837375 ],
       [0.09964753, 0.10516388, 0.09635878, 0.1263353 , 0.10689379,
        0.09252935, 0.08307378, 0.10495035, 0.10201847, 0.08302879],
       [0.10292235, 0.10445141, 0.10405432, 0.11601374, 0.09692912,
        0.09204514, 0.08684668, 0.10418826, 0.1005675 , 0.09198151],
       [0.10444795, 0.09730533, 0.10376438, 0.12015738, 0.09772504,
        0.09089442, 0.08984255, 0.10278826, 0.11849745, 0.07457713],
       [0.10438851, 0.10686149, 0.10246357, 0.10493152, 0.10348819,
        0.08775381, 0.09083097, 0.09938617, 0.11093659, 0.08895922],
       [0.10544944, 0.09818077, 0.09329084, 0.11733873, 0.11078161,
        0.09766933, 0.09083114, 0.10782209, 0.0947513 , 0.08388472],
       [0.10703585, 0.10091925, 0.09355933, 0.11673613, 0.10551108,
        0.0878339 , 0.08623412, 0.10723548, 0.09952866, 0.09540622],
       [0.10752242, 0.10790183, 0.10093197, 0.11127086, 0.09987544,
        0.08419652, 0.09178086, 0.09874151, 0.10740087, 0.09037771],
       [0.10816085, 0.09434848, 0.09469503, 0.11608687, 0.10236199,
        0.0936012 , 0.09178096, 0.11500423, 0.10858335, 0.0753771 ],
       [0.10754462, 0.09610054, 0.10465137, 0.11755015, 0.09690968,
        0.0879105 , 0.09588192, 0.11404539, 0.10285615, 0.07654973],
       [0.11241148, 0.08882284, 0.09405247, 0.10950889, 0.10673883,
        0.08781228, 0.09985832, 0.1141803 , 0.10457499, 0.08203958],
       [0.1091413 , 0.10412721, 0.10478941, 0.10395776, 0.09535929,
        0.08805153, 0.09860662, 0.10201051, 0.10118678, 0.09276962],
       [0.10605051, 0.09701546, 0.0949271 , 0.11232355, 0.1087898 ,
        0.1004519 , 0.08760285, 0.11227378, 0.10104988, 0.07951519],
       [0.10381093, 0.09920877, 0.08820166, 0.12324253, 0.11016691,
        0.09974337, 0.08212494, 0.11035147, 0.106262  , 0.07688744],
       [0.11215791, 0.10010476, 0.10120094, 0.11671812, 0.10982783,
        0.08855668, 0.08435183, 0.11731754, 0.08742444, 0.08233997],
       [0.10245907, 0.10029367, 0.10221576, 0.10540797, 0.09739272,
        0.09511559, 0.09776582, 0.10032315, 0.101548  , 0.09747829],
       [0.10604411, 0.1033968 , 0.10131188, 0.10618415, 0.09859092,
        0.08605155, 0.09868422, 0.10696511, 0.10429069, 0.08848058],
       [0.10748281, 0.11348563, 0.09736608, 0.1156299 , 0.10006159,
        0.09052482, 0.08917169, 0.10554566, 0.0933612 , 0.08737059],
       [0.1029908 , 0.09412522, 0.09733023, 0.12130862, 0.10755724,
        0.0986068 , 0.08982269, 0.1031662 , 0.10753588, 0.0775563 ],
       [0.1047086 , 0.09500396, 0.09726457, 0.11584704, 0.10728658,
        0.0957251 , 0.09239414, 0.11417031, 0.09799711, 0.07960259],
       [0.11915238, 0.09845953, 0.09652468, 0.1032242 , 0.11258449,
        0.08078464, 0.09128478, 0.10983384, 0.10177512, 0.08637629],
       [0.11451644, 0.09615005, 0.09332057, 0.12316085, 0.11431079,
        0.08255224, 0.08722318, 0.11545283, 0.09139941, 0.08191368],
       [0.11015461, 0.10917749, 0.0990641 , 0.1019919 , 0.1069501 ,
        0.0880907 , 0.08974147, 0.1048572 , 0.10007656, 0.08989589]],
      dtype=float32)>}

分散マナーでロードして推論を行なうこともできます :

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.experimental_run_v2(inference_func, 
                                         args=(batch,))
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

復旧された関数の呼び出しはセーブされたモデル上の forward パス (予測) です。ロードされた関数で訓練を続けることを望む場合はどうでしょう?あるいはロードされた関数をより大きなモデルに埋め込には?一般的な実践はこれを成すためにロードされたオブジェクトを Keras 層にラップすることです。幸い、ここで示されるように、TF Hub はこの目的のために hub.KerasLayer を持ちます :

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss='sparse_categorical_crossentropy',
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
  model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 10s 10ms/step - loss: 0.1881 - accuracy: 0.9451
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0638 - accuracy: 0.9812

見れるように、hub.KerasLayer は tf.saved_model.load() でロードし戻された結果を (もう一つのモデルを構築するために使用できる) Keras 層にラップします。これは転移学習のために非常に有用です。

 

どの API を使用するべきでしょう?

セーブについては、keras モデルで作業している場合、Keras の model.save() API を使用することが殆ど常に推奨されます。貴方がセーブしているものが Keras モデルでないのであれば、低位レベルが貴方の唯一の選択肢です。

ロードについては、貴方がどの API を使用するかはロードする API から何を得ることを望むかに依拠します。Keras モデルを得ることができない (あるいは望まない) 場合にはtf.saved_model.load() を使用します。そうでないなら、tf.keras.models.load_model() を使用します。Keras モデルでセーブした場合に限り Keras モデルを戻して得ることができることに注意してください。

API を混在させて適合させることは可能です。model.save で Keras モデルをセーブできて、低位 API, tf.saved_model.load で非 Keras モデルをロードできます。

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

INFO:tensorflow:Assets written to: /tmp/keras_save/assets

 

Caveats (注意事項)

特別なケースは well-defined な入力を持たない Keras モデルを持つときです。例えば、Sequential モデルは任意の入力 shape なしに作成できます (Sequential([Dense(3), …])。Subclassed モデルもまた初期化後 well-defined な入力を持ちません。この場合、セーブとロードの両方で低位 API に固執するべきです。

貴方のモデルが well-defined な入力を持つかどうか確認するには、単に model.inputs が None であるか確認します。それが None でないならば、総て問題ありません。モデルが .fit, .evaluate, .predict やモデルを呼び出す (model(inputs)) とき入力 shape は自動的に定義されます。

ここのサンプルがあります :

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras model <__main__.SubclassedModel object at 0x7f15017bfa90>, because its inputs are not defined.

WARNING:tensorflow:Skipping full serialization of Keras model <__main__.SubclassedModel object at 0x7f15017bfa90>, because its inputs are not defined.

WARNING:tensorflow:Skipping full serialization of Keras layer , because it is not built.

WARNING:tensorflow:Skipping full serialization of Keras layer , because it is not built.

INFO:tensorflow:Assets written to: /tmp/tf_save/assets

INFO:tensorflow:Assets written to: /tmp/tf_save/assets
 

以上






AI導入支援 #2 ウェビナー

スモールスタートを可能としたAI導入支援   Vol.2
[無料 WEB セミナー] [詳細]
「画像認識 AI PoC スターターパック」の紹介
既に AI 技術を実ビジネスで活用し、成果を上げている日本企業も多く存在しており、競争優位なビジネスを展開しております。
しかしながら AI を導入したくとも PoC (概念実証) だけでも高額な費用がかかり取組めていない企業も少なくないようです。A I導入時には欠かせない PoC を手軽にしかも短期間で認知度を確認可能とするサービの紹介と共に、AI 技術の特性と具体的な導入プロセスに加え運用時のポイントについても解説いたします。
日時:2021年10月13日(水)
会場:WEBセミナー
共催:クラスキャット、日本FLOW(株)
後援:働き方改革推進コンソーシアム
参加費: 無料 (事前登録制)
人工知能開発支援
◆ クラスキャットは 人工知能研究開発支援 サービスを提供しています :
  • テクニカルコンサルティングサービス
  • 実証実験 (プロトタイプ構築)
  • アプリケーションへの実装
  • 人工知能研修サービス
◆ お問合せ先 ◆
(株)クラスキャット
セールス・インフォメーション
E-Mail:sales-info@classcat.com