TensorFlow 2.0 : 上級 Tutorials : 分散訓練 :- 分散ストラテジーを使用してモデルをセーブとロードする (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 11/02/2019
* 本ページは、TensorFlow org サイトの TF 2.0 – Advanced Tutorials – Distributed training の以下のページを翻訳した上で
適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
Facebook: https://www.facebook.com/ClassCatJP/ |
分散訓練 :- 分散ストラテジーを使用してモデルをセーブとロードする
概要
訓練の間にモデルをセーブしてロードすることは一般的です。keras モデルをセーブしてロードするために API の 2 つのセットがあります: 高位 API、そして低位 API です。このチュートリアルは tf.distribute.Strategy を使用するとき SavedModel API をどのように使用できるかを実演します。一般に SavedModel とシリアライゼーションについて学習するためには、saved model ガイド、そして Keras モデル・シリアライゼーション・ガイド を読んでください。単純なサンプルで始めましょう。
依存性をインポートします :
from __future__ import absolute_import, division, print_function, unicode_literals import tensorflow_datasets as tfds import tensorflow as tf tfds.disable_progress_bar()
tf.distribute.Strategy を使用してデータとモデルを準備します :
mirrored_strategy = tf.distribute.MirroredStrategy() def get_data(): datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True) mnist_train, mnist_test = datasets['train'], datasets['test'] BUFFER_SIZE = 10000 BATCH_SIZE_PER_REPLICA = 64 BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync def scale(image, label): image = tf.cast(image, tf.float32) image /= 255 return image, label train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE) eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE) return train_dataset, eval_dataset def get_model(): with mirrored_strategy.scope(): model = tf.keras.Sequential([ tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)), tf.keras.layers.MaxPooling2D(), tf.keras.layers.Flatten(), tf.keras.layers.Dense(64, activation='relu'), tf.keras.layers.Dense(10, activation='softmax') ]) model.compile(loss='sparse_categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(), metrics=['accuracy']) return model
モデルを訓練します :
model = get_model() train_dataset, eval_dataset = get_data() model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). Downloading and preparing dataset mnist (11.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/1.0.0... /home/kbuilder/.local/lib/python3.5/site-packages/urllib3/connectionpool.py:1004: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings InsecureRequestWarning, /home/kbuilder/.local/lib/python3.5/site-packages/urllib3/connectionpool.py:1004: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings InsecureRequestWarning, /home/kbuilder/.local/lib/python3.5/site-packages/urllib3/connectionpool.py:1004: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings InsecureRequestWarning, /home/kbuilder/.local/lib/python3.5/site-packages/urllib3/connectionpool.py:1004: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings InsecureRequestWarning, WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.5/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version. Instructions for updating: Use eager execution and: `tf.data.TFRecordDataset(path)` WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.5/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version. Instructions for updating: Use eager execution and: `tf.data.TFRecordDataset(path)` Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/1.0.0. Subsequent calls will reuse this data. Epoch 1/2 INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). 938/938 [==============================] - 14s 14ms/step - loss: 0.1997 - accuracy: 0.9430 Epoch 2/2 938/938 [==============================] - 2s 3ms/step - loss: 0.0674 - accuracy: 0.9801 <tensorflow.python.keras.callbacks.History at 0x7f1450dc75c0>
モデルをセーブしてロードする
作業するための単純なモデルを持った今、セーブ/ロード API を見てみまよう。利用可能な API の 2 つのセットがあります :
- 高位 keras model.save と tf.keras.models.load_model
- 低位 tf.saved_model.save と tf.saved_model.load
Keras API
ここに Keras API でモデルをセーブしてロードするサンプルがあります :
keras_model_path = "/tmp/keras_save" model.save(keras_model_path) # save() should be called out of strategy scope
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1781: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version. Instructions for updating: If using Keras pass *_constraint arguments to layers. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1781: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version. Instructions for updating: If using Keras pass *_constraint arguments to layers. INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Assets written to: /tmp/keras_save/assets INFO:tensorflow:Assets written to: /tmp/keras_save/assets
tf.distribute.Strategy なしでモデルを復旧します :
restored_keras_model = tf.keras.models.load_model(keras_model_path) restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2 938/938 [==============================] - 10s 11ms/step - loss: 0.0491 - accuracy: 0.9851 Epoch 2/2 938/938 [==============================] - 3s 3ms/step - loss: 0.0339 - accuracy: 0.9900 <tensorflow.python.keras.callbacks.History at 0x7f14502d00f0>
モデルを復旧した後、その上で訓練を継続できます、compile() を再度呼び出す必要さえなく、何故ならばそれはセーブ前に既にコンパイルされているからです。モデルは TensorFlow の標準 SavedModel proto 形式でセーブされます。より多くの情報については、guide to saved_model format を参照してください。
tf.distribute.strategy のスコープの外から model.save() メソッドを呼び出すことだけは重要です。スコープ内でそれを呼び出すことはサポートされません。
今はモデルをロードしてそれを訓練するために tf.distribute.Strategy を使用します :
another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0") with another_strategy.scope(): restored_keras_model_ds = tf.keras.models.load_model(keras_model_path) restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2 938/938 [==============================] - 15s 16ms/step - loss: 0.0486 - accuracy: 0.9852 Epoch 2/2 938/938 [==============================] - 11s 12ms/step - loss: 0.0345 - accuracy: 0.9897
見れるように、ロードは tf.distribute.Strategy とともに期待されたように動作します。 ここで使用されるストラテジーはセービングの前に使用されたのと同じストラテジーでなくてもかまいません。
tf.saved_model API
今は低位 API を見てみましょう。モデルのセービングは keras API に類似しています :
model = get_model() # get a fresh model saved_model_path = "/tmp/tf_save" tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Assets written to: /tmp/tf_save/assets
ロードは tf.saved_model.load() で行なうことができます。けれども、それは低位上の API です (そしてそれ故により広範囲のユースケースを持ちます) から、それは Keras モデルを返しません。代わりに、それは推論を行なうために使用できる関数を含むオブジェクトを返します。例えば :
DEFAULT_FUNCTION_KEY = "serving_default" loaded = tf.saved_model.load(saved_model_path) inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
ロードされたオブジェクトは複数の関数を含むかもしれません、各々はキーと関連付けられています。”serving_default” はセーブされた Keras モデルによる推論関数のためのデフォルトキーです。この関数で推論を行なうためには :
predict_dataset = eval_dataset.map(lambda image, label: image) for batch in predict_dataset.take(1): print(inference_func(batch))
{'dense_3': <tf.Tensor: id=163719, shape=(64, 10), dtype=float32, numpy= array([[0.10728385, 0.10047228, 0.10261484, 0.11404606, 0.10439669, 0.09272724, 0.09213927, 0.10215139, 0.11169203, 0.07247637], [0.10643671, 0.09905838, 0.09364076, 0.11701185, 0.10924707, 0.09127147, 0.09204151, 0.10397962, 0.10052412, 0.08678856], [0.10438395, 0.10305102, 0.09159523, 0.12274157, 0.1077928 , 0.09475194, 0.0791743 , 0.11037502, 0.09112864, 0.09500553], [0.10716499, 0.09490513, 0.10168986, 0.11599538, 0.10473887, 0.08412197, 0.09422952, 0.10134438, 0.10640062, 0.08940925], [0.10903514, 0.10466785, 0.09625442, 0.09959863, 0.10513363, 0.08701214, 0.09148837, 0.10629265, 0.10480279, 0.09571434], [0.10791405, 0.0994103 , 0.08700535, 0.11203778, 0.10917093, 0.09215206, 0.09492003, 0.10987404, 0.10913727, 0.07837823], [0.10487697, 0.1086623 , 0.09380413, 0.11928336, 0.11513475, 0.08585147, 0.0855743 , 0.10447548, 0.09436136, 0.08797587], [0.11953841, 0.10592035, 0.09351776, 0.10883537, 0.11183759, 0.08616813, 0.08055402, 0.1016137 , 0.10553744, 0.08647724], [0.1001956 , 0.09496056, 0.10120489, 0.11922734, 0.09466279, 0.08841368, 0.09372198, 0.10862154, 0.10443942, 0.0945522 ], [0.10173271, 0.10577305, 0.10043491, 0.10811324, 0.09887545, 0.10345043, 0.09138881, 0.10311095, 0.10145498, 0.08566544], [0.10915802, 0.09801247, 0.09486966, 0.11884604, 0.10890727, 0.09173197, 0.08757571, 0.1097801 , 0.09345251, 0.08766618], [0.12305965, 0.10477802, 0.10122419, 0.10920276, 0.10642693, 0.0803097 , 0.089819 , 0.10291311, 0.09739247, 0.08487416], [0.11381779, 0.10931753, 0.09873011, 0.11637557, 0.11114289, 0.08807714, 0.07799057, 0.09821263, 0.10721567, 0.07912013], [0.11134263, 0.09600631, 0.09837915, 0.12407107, 0.10555959, 0.09594846, 0.08290008, 0.10883261, 0.10144335, 0.07551681], [0.11251914, 0.10554384, 0.09913144, 0.10999293, 0.10561096, 0.07938422, 0.08732757, 0.10422582, 0.1096524 , 0.08661176], [0.1071265 , 0.10107235, 0.10104455, 0.12165232, 0.09664022, 0.09286968, 0.0891385 , 0.10249172, 0.10644917, 0.08151497], [0.10819855, 0.10904767, 0.10180869, 0.11292984, 0.09951544, 0.08959039, 0.09180361, 0.10376687, 0.0935899 , 0.089749 ], [0.10038829, 0.09874914, 0.09890129, 0.11915599, 0.10857335, 0.09699985, 0.08979851, 0.10805573, 0.09952758, 0.07985032], [0.10805429, 0.10442685, 0.099279 , 0.11294735, 0.1079711 , 0.08618448, 0.08894631, 0.10854797, 0.11077981, 0.07286283], [0.11217733, 0.10181618, 0.09626949, 0.1101595 , 0.09953736, 0.08308611, 0.1039778 , 0.11402612, 0.09844285, 0.08050729], [0.1175424 , 0.10847426, 0.09683861, 0.10149902, 0.10666198, 0.08579018, 0.09005664, 0.10190959, 0.10221969, 0.08900762], [0.11500932, 0.10343985, 0.09347771, 0.12577702, 0.10971518, 0.08523733, 0.08214136, 0.10528794, 0.09671546, 0.08319884], [0.11175612, 0.10435012, 0.0984369 , 0.1191108 , 0.09581577, 0.08925212, 0.09045488, 0.10521886, 0.09986348, 0.08574096], [0.11444057, 0.10850222, 0.09963862, 0.10142051, 0.10838775, 0.08676049, 0.08990512, 0.10423999, 0.09729694, 0.0894078 ], [0.10881206, 0.11097988, 0.09718787, 0.11925952, 0.09623734, 0.08606921, 0.09177926, 0.09912794, 0.09836774, 0.09217919], [0.10792447, 0.11186771, 0.10025145, 0.1140421 , 0.10892291, 0.08941509, 0.09012905, 0.09400386, 0.10047787, 0.08296548], [0.10844032, 0.09806747, 0.0961417 , 0.11372101, 0.09871072, 0.09228362, 0.09338211, 0.10678563, 0.10351523, 0.08895217], [0.10387842, 0.09575935, 0.10545119, 0.10064226, 0.1034956 , 0.09042715, 0.09294212, 0.11890968, 0.11319113, 0.07530304], [0.1227584 , 0.10466544, 0.0841779 , 0.09771729, 0.11111001, 0.08448302, 0.0934371 , 0.11048996, 0.10305537, 0.08810547], [0.12329397, 0.09496878, 0.103743 , 0.11139977, 0.10066491, 0.08561404, 0.09533388, 0.11039741, 0.09291557, 0.08166874], [0.12045516, 0.10910697, 0.10111618, 0.09754261, 0.10887689, 0.08471666, 0.09063336, 0.10426611, 0.09986193, 0.08342412], [0.11602842, 0.11101867, 0.10563238, 0.10994542, 0.10256457, 0.0871765 , 0.0918168 , 0.10208733, 0.09566812, 0.0780618 ], [0.10475811, 0.10498275, 0.09951681, 0.10977667, 0.10362052, 0.09762244, 0.0898955 , 0.10377637, 0.09997776, 0.08607301], [0.10318095, 0.10635176, 0.10179733, 0.11117157, 0.105158 , 0.09139618, 0.09156478, 0.09806094, 0.10203766, 0.08928081], [0.09828327, 0.10664824, 0.09477657, 0.12164301, 0.10451195, 0.0932444 , 0.08256079, 0.10863042, 0.09929656, 0.09040477], [0.10853811, 0.11033206, 0.0940258 , 0.1113859 , 0.10713692, 0.08983784, 0.0831388 , 0.10531119, 0.1047819 , 0.08551157], [0.11990581, 0.12092119, 0.1031604 , 0.11084578, 0.10152177, 0.08251272, 0.08635441, 0.10541882, 0.09153688, 0.07782226], [0.10128036, 0.10670947, 0.10337903, 0.11031315, 0.1019455 , 0.09146291, 0.09584799, 0.09672082, 0.10148021, 0.09086055], [0.10947275, 0.10200524, 0.09993119, 0.10882212, 0.10303847, 0.08627772, 0.09159538, 0.09889076, 0.10446595, 0.0955004 ], [0.10939393, 0.09627976, 0.10602371, 0.1142818 , 0.10102987, 0.09605702, 0.08749614, 0.10808735, 0.10765684, 0.07369358], [0.10504688, 0.10167556, 0.10281529, 0.1192755 , 0.10568851, 0.08670694, 0.08155579, 0.11723237, 0.09987677, 0.08012642], [0.10824965, 0.09176606, 0.09443861, 0.11798903, 0.11201838, 0.09821461, 0.08609 , 0.11264105, 0.09485514, 0.0837375 ], [0.09964753, 0.10516388, 0.09635878, 0.1263353 , 0.10689379, 0.09252935, 0.08307378, 0.10495035, 0.10201847, 0.08302879], [0.10292235, 0.10445141, 0.10405432, 0.11601374, 0.09692912, 0.09204514, 0.08684668, 0.10418826, 0.1005675 , 0.09198151], [0.10444795, 0.09730533, 0.10376438, 0.12015738, 0.09772504, 0.09089442, 0.08984255, 0.10278826, 0.11849745, 0.07457713], [0.10438851, 0.10686149, 0.10246357, 0.10493152, 0.10348819, 0.08775381, 0.09083097, 0.09938617, 0.11093659, 0.08895922], [0.10544944, 0.09818077, 0.09329084, 0.11733873, 0.11078161, 0.09766933, 0.09083114, 0.10782209, 0.0947513 , 0.08388472], [0.10703585, 0.10091925, 0.09355933, 0.11673613, 0.10551108, 0.0878339 , 0.08623412, 0.10723548, 0.09952866, 0.09540622], [0.10752242, 0.10790183, 0.10093197, 0.11127086, 0.09987544, 0.08419652, 0.09178086, 0.09874151, 0.10740087, 0.09037771], [0.10816085, 0.09434848, 0.09469503, 0.11608687, 0.10236199, 0.0936012 , 0.09178096, 0.11500423, 0.10858335, 0.0753771 ], [0.10754462, 0.09610054, 0.10465137, 0.11755015, 0.09690968, 0.0879105 , 0.09588192, 0.11404539, 0.10285615, 0.07654973], [0.11241148, 0.08882284, 0.09405247, 0.10950889, 0.10673883, 0.08781228, 0.09985832, 0.1141803 , 0.10457499, 0.08203958], [0.1091413 , 0.10412721, 0.10478941, 0.10395776, 0.09535929, 0.08805153, 0.09860662, 0.10201051, 0.10118678, 0.09276962], [0.10605051, 0.09701546, 0.0949271 , 0.11232355, 0.1087898 , 0.1004519 , 0.08760285, 0.11227378, 0.10104988, 0.07951519], [0.10381093, 0.09920877, 0.08820166, 0.12324253, 0.11016691, 0.09974337, 0.08212494, 0.11035147, 0.106262 , 0.07688744], [0.11215791, 0.10010476, 0.10120094, 0.11671812, 0.10982783, 0.08855668, 0.08435183, 0.11731754, 0.08742444, 0.08233997], [0.10245907, 0.10029367, 0.10221576, 0.10540797, 0.09739272, 0.09511559, 0.09776582, 0.10032315, 0.101548 , 0.09747829], [0.10604411, 0.1033968 , 0.10131188, 0.10618415, 0.09859092, 0.08605155, 0.09868422, 0.10696511, 0.10429069, 0.08848058], [0.10748281, 0.11348563, 0.09736608, 0.1156299 , 0.10006159, 0.09052482, 0.08917169, 0.10554566, 0.0933612 , 0.08737059], [0.1029908 , 0.09412522, 0.09733023, 0.12130862, 0.10755724, 0.0986068 , 0.08982269, 0.1031662 , 0.10753588, 0.0775563 ], [0.1047086 , 0.09500396, 0.09726457, 0.11584704, 0.10728658, 0.0957251 , 0.09239414, 0.11417031, 0.09799711, 0.07960259], [0.11915238, 0.09845953, 0.09652468, 0.1032242 , 0.11258449, 0.08078464, 0.09128478, 0.10983384, 0.10177512, 0.08637629], [0.11451644, 0.09615005, 0.09332057, 0.12316085, 0.11431079, 0.08255224, 0.08722318, 0.11545283, 0.09139941, 0.08191368], [0.11015461, 0.10917749, 0.0990641 , 0.1019919 , 0.1069501 , 0.0880907 , 0.08974147, 0.1048572 , 0.10007656, 0.08989589]], dtype=float32)>}
分散マナーでロードして推論を行なうこともできます :
another_strategy = tf.distribute.MirroredStrategy() with another_strategy.scope(): loaded = tf.saved_model.load(saved_model_path) inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY] dist_predict_dataset = another_strategy.experimental_distribute_dataset( predict_dataset) # Calling the function in a distributed manner for batch in dist_predict_dataset: another_strategy.experimental_run_v2(inference_func, args=(batch,))
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.
復旧された関数の呼び出しはセーブされたモデル上の forward パス (予測) です。ロードされた関数で訓練を続けることを望む場合はどうでしょう?あるいはロードされた関数をより大きなモデルに埋め込には?一般的な実践はこれを成すためにロードされたオブジェクトを Keras 層にラップすることです。幸い、ここで示されるように、TF Hub はこの目的のために hub.KerasLayer を持ちます :
import tensorflow_hub as hub def build_model(loaded): x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x') # Wrap what's loaded to a KerasLayer keras_layer = hub.KerasLayer(loaded, trainable=True)(x) model = tf.keras.Model(x, keras_layer) return model another_strategy = tf.distribute.MirroredStrategy() with another_strategy.scope(): loaded = tf.saved_model.load(saved_model_path) model = build_model(loaded) model.compile(loss='sparse_categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(), metrics=['accuracy']) model.fit(train_dataset, epochs=2)
Epoch 1/2 938/938 [==============================] - 10s 10ms/step - loss: 0.1881 - accuracy: 0.9451 Epoch 2/2 938/938 [==============================] - 3s 3ms/step - loss: 0.0638 - accuracy: 0.9812
見れるように、hub.KerasLayer は tf.saved_model.load() でロードし戻された結果を (もう一つのモデルを構築するために使用できる) Keras 層にラップします。これは転移学習のために非常に有用です。
どの API を使用するべきでしょう?
セーブについては、keras モデルで作業している場合、Keras の model.save() API を使用することが殆ど常に推奨されます。貴方がセーブしているものが Keras モデルでないのであれば、低位レベルが貴方の唯一の選択肢です。
ロードについては、貴方がどの API を使用するかはロードする API から何を得ることを望むかに依拠します。Keras モデルを得ることができない (あるいは望まない) 場合にはtf.saved_model.load() を使用します。そうでないなら、tf.keras.models.load_model() を使用します。Keras モデルでセーブした場合に限り Keras モデルを戻して得ることができることに注意してください。
API を混在させて適合させることは可能です。model.save で Keras モデルをセーブできて、低位 API, tf.saved_model.load で非 Keras モデルをロードできます。
model = get_model() # Saving the model using Keras's save() API model.save(keras_model_path) another_strategy = tf.distribute.MirroredStrategy() # Loading the model using lower level API with another_strategy.scope(): loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets INFO:tensorflow:Assets written to: /tmp/keras_save/assets
Caveats (注意事項)
特別なケースは well-defined な入力を持たない Keras モデルを持つときです。例えば、Sequential モデルは任意の入力 shape なしに作成できます (Sequential([Dense(3), …])。Subclassed モデルもまた初期化後 well-defined な入力を持ちません。この場合、セーブとロードの両方で低位 API に固執するべきです。
貴方のモデルが well-defined な入力を持つかどうか確認するには、単に model.inputs が None であるか確認します。それが None でないならば、総て問題ありません。モデルが .fit, .evaluate, .predict やモデルを呼び出す (model(inputs)) とき入力 shape は自動的に定義されます。
ここのサンプルがあります :
class SubclassedModel(tf.keras.Model): output_name = 'output_layer' def __init__(self): super(SubclassedModel, self).__init__() self._dense_layer = tf.keras.layers.Dense( 5, dtype=tf.dtypes.float32, name=self.output_name) def call(self, inputs): return self._dense_layer(inputs) my_model = SubclassedModel() # my_model.save(keras_model_path) # ERROR! tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras model <__main__.SubclassedModel object at 0x7f15017bfa90>, because its inputs are not defined. WARNING:tensorflow:Skipping full serialization of Keras model <__main__.SubclassedModel object at 0x7f15017bfa90>, because its inputs are not defined. WARNING:tensorflow:Skipping full serialization of Keras layer, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer , because it is not built. INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Assets written to: /tmp/tf_save/assets
以上