TensorFlow 2.0 Alpha : Advanced Tutorials : Distributed Training :- Distributed training in TensorFlow (Translation/Commentary)
Translation: ClassCat Co., Ltd. Sales Information
Created: 04/12/2019
* This page is a translation of the following page from the official TensorFlow site, TF 2.0 Alpha – Advanced Tutorials – Distributed training, with supplementary explanations added where appropriate:
* The sample code has been verified to run, with additions and modifications made where necessary.
* Feel free to link to this page, but we would appreciate a note to sales-info@classcat.com.
Distributed Training :- Distributed training in TensorFlow
Overview
The tf.distribute.Strategy API provides an abstraction for distributing your training across multiple processing units. The goal is to let users enable distributed training using existing models and training code, with minimal changes.
This tutorial uses tf.distribute.MirroredStrategy, which does in-graph replication with synchronous training on many GPUs on one machine. Essentially, it copies all of the model's variables to each processor. Then it uses all-reduce to combine the gradients from all processors, and applies the combined value to all copies of the model.
MirroredStrategy is one of several distribution strategies available in TensorFlow core. You can read about more strategies in the distribution strategy guide.
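As a brief illustration (an addition to the original text, not part of the notebook), MirroredStrategy can also be given an explicit device list instead of auto-detecting all local GPUs; the device names below assume a machine with at least two GPUs:

# A minimal sketch, assuming two GPUs are available: mirror the model
# across an explicit list of devices rather than all visible GPUs.
two_gpu_strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])
print(two_gpu_strategy.num_replicas_in_sync)  # 2 on such a machine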
Keras API
This example uses the tf.keras API to build the model and training loop. For custom training loops, see this tutorial.
Import dependencies
from __future__ import absolute_import, division, print_function, unicode_literals
# Import TensorFlow
!pip install -q tensorflow==2.0.0-alpha0
import tensorflow_datasets as tfds
import tensorflow as tf

import os
Download the dataset
Download the MNIST dataset and load it from TensorFlow Datasets. This returns a dataset in tf.data format.
Setting with_info to True includes the metadata for the entire dataset, which is saved here to ds_info. Among other things, this metadata includes the number of train and test examples.
datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Downloading / extracting dataset mnist (11.06 MiB) to /root/tensorflow_datasets/mnist/1.0.0...
Dl Completed...: 100%|██████████| 4/4 [00:01<00:00, 2.51 url/s]
Dl Size...: 100%|██████████| 10/10 [00:01<00:00, 5.24 MiB/s]
Extraction completed...: 100%|██████████| 4/4 [00:01<00:00, 1.94 file/s]
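The ds_info object returned above can be inspected directly; a small sketch (not in the original notebook) of the DatasetInfo attributes used later:

# Inspect the metadata that tfds.load returned because of with_info=True.
print(ds_info.name)                          # 'mnist'
print(ds_info.splits['train'].num_examples)  # 60000
print(ds_info.splits['test'].num_examples)   # 10000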
Define the distribution strategy
Create a MirroredStrategy object. This handles distribution, and provides a context manager (tf.distribute.MirroredStrategy.scope) to build your model inside.
strategy = tf.distribute.MirroredStrategy()
W0405 15:23:20.099184 140384515561216 cross_device_ops.py:1111] Not all devices in `tf.distribute.Strategy` are visible to TensorFlow.
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1
Set up the input pipeline
If a model is trained on multiple GPUs, the batch size should be increased accordingly so that the extra computing power is used effectively. Moreover, the learning rate should be tuned accordingly (see the sketch after the next cell).
# You can also do ds_info.splits.total_num_examples to get the total
# number of examples in the dataset.
num_train_examples = ds_info.splits['train'].num_examples
num_test_examples = ds_info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
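The cell above scales the batch size with the replica count; the learning rate can be scaled the same way. A minimal sketch, assuming simple linear scaling (the constants here are hypothetical, not prescribed by the original tutorial):

# Linear learning-rate scaling: a common heuristic for synchronous
# data parallelism when the global batch size grows with replicas.
BASE_LEARNING_RATE = 1e-3
LEARNING_RATE = BASE_LEARNING_RATE * strategy.num_replicas_in_sync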
Pixel values, which are 0-255, have to be normalized to the 0-1 range. Define this scale in a function.
def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label
Apply this function to the training and test data, shuffle the training data, and batch it for training.
train_dataset = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
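As an optional sanity check (an addition here, not in the original notebook), you can pull one batch to confirm that scaling and batching worked as expected:

# Take a single batch and verify shapes and the value range after scaling.
for images, labels in train_dataset.take(1):
  print(images.shape)                          # e.g. (64, 28, 28, 1) with one replica
  print(float(tf.reduce_max(images)) <= 1.0)   # True: pixels are in [0, 1]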
Create the model
Create and compile the Keras model in the context of strategy.scope.
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')
  ])

  model.compile(loss='sparse_categorical_crossentropy',
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
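A model built inside the scope behaves like any other Keras model; for instance, printing the summary (an optional addition) confirms the layer shapes:

# Optional: inspect the architecture that was just compiled.
model.summary()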
Define the callbacks
The callbacks used here are:
- TensorBoard: This callback writes a log for TensorBoard, which allows you to visualize the graphs.
- Model Checkpoint: This callback saves the model after every epoch.
- Learning Rate Scheduler: Using this callback, you can schedule the learning rate to change after every epoch/batch.
For illustrative purposes, this notebook adds a print callback to display the learning rate.
# Define the checkpoint directory to store the checkpoints
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print ('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                       model.optimizer.lr.numpy()))
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]
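The piecewise decay above is only one choice. As a hedged alternative (not from the original tutorial), LearningRateScheduler accepts any function of the epoch index, for example a smooth exponential decay:

import math

# An alternative schedule: exponential decay from a hypothetical base rate.
def exp_decay(epoch, base_lr=1e-3, k=0.3):
  return base_lr * math.exp(-k * epoch)

# Swap it in with: tf.keras.callbacks.LearningRateScheduler(exp_decay)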
Train and evaluate
Now, train the model in the usual way: call fit on the model and pass in the dataset created at the beginning of the tutorial. This step is the same whether or not you are distributing the training.
model.fit(train_dataset, epochs=10, callbacks=callbacks)
Epoch 1/10
938/938 [==============================] - 9s 9ms/step - loss: 0.1977 - accuracy: 0.9434
Learning rate for epoch 1 is 0.0010000000474974513
Epoch 2/10
938/938 [==============================] - 7s 7ms/step - loss: 0.0678 - accuracy: 0.9791
Learning rate for epoch 2 is 0.0010000000474974513
Epoch 3/10
938/938 [==============================] - 7s 8ms/step - loss: 0.0464 - accuracy: 0.9861
Learning rate for epoch 3 is 0.0010000000474974513
Epoch 4/10
938/938 [==============================] - 7s 8ms/step - loss: 0.0255 - accuracy: 0.9927
Learning rate for epoch 4 is 9.999999747378752e-05
Epoch 5/10
938/938 [==============================] - 7s 8ms/step - loss: 0.0220 - accuracy: 0.9941
Learning rate for epoch 5 is 9.999999747378752e-05
Epoch 6/10
938/938 [==============================] - 7s 7ms/step - loss: 0.0201 - accuracy: 0.9947
Learning rate for epoch 6 is 9.999999747378752e-05
Epoch 7/10
938/938 [==============================] - 7s 7ms/step - loss: 0.0186 - accuracy: 0.9952
Learning rate for epoch 7 is 9.999999747378752e-05
Epoch 8/10
938/938 [==============================] - 7s 8ms/step - loss: 0.0161 - accuracy: 0.9963
Learning rate for epoch 8 is 9.999999747378752e-06
Epoch 9/10
938/938 [==============================] - 7s 8ms/step - loss: 0.0158 - accuracy: 0.9964
Learning rate for epoch 9 is 9.999999747378752e-06
Epoch 10/10
938/938 [==============================] - 7s 7ms/step - loss: 0.0156 - accuracy: 0.9965
Learning rate for epoch 10 is 9.999999747378752e-06
As you can see below, the checkpoints are being saved.
# check the checkpoint directory
!ls {checkpoint_dir}
checkpoint
ckpt_1.data-00000-of-00001   ckpt_1.index
ckpt_2.data-00000-of-00001   ckpt_2.index
ckpt_3.data-00000-of-00001   ckpt_3.index
ckpt_4.data-00000-of-00001   ckpt_4.index
ckpt_5.data-00000-of-00001   ckpt_5.index
ckpt_6.data-00000-of-00001   ckpt_6.index
ckpt_7.data-00000-of-00001   ckpt_7.index
ckpt_8.data-00000-of-00001   ckpt_8.index
ckpt_9.data-00000-of-00001   ckpt_9.index
ckpt_10.data-00000-of-00001  ckpt_10.index
To see how the model performs, load the latest checkpoint and call evaluate on the test data.
Call evaluate as before, using the appropriate dataset.
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)
print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/Unknown - 2s 10ms/step - loss: 0.0388 - accuracy: 0.9872
Eval loss: 0.03881146577201713, Eval Accuracy: 0.9872000217437744
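Because ModelCheckpoint saved weights-only checkpoints under the checkpoint_prefix defined earlier, a specific epoch can also be restored. A sketch (the epoch number here is arbitrary):

# Load the weights saved at the end of epoch 4 instead of the latest ones.
model.load_weights(os.path.join(checkpoint_dir, 'ckpt_4'))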
To see the output, you can download and view the TensorBoard logs in a terminal.
$ tensorboard --logdir=path/to/log-directory
!ls -sh ./logs
total 12K
4.0K plugins  4.0K train  4.0K validation
Export to SavedModel
If you want to export the graph and variables, SavedModel is the best way to do this. The model can be loaded back with or without the scope. Moreover, SavedModel is platform agnostic.
path = 'saved_model/'
tf.keras.experimental.export_saved_model(model, path)
W0405 15:25:10.681121 140384515561216 deprecation.py:323] From /usr/local/lib/python3.5/dist-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:253: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
W0405 15:25:11.132856 140384515561216 tf_logging.py:161] Export includes no default signature!
Load the model without strategy.scope.
unreplicated_model = tf.keras.experimental.load_from_saved_model(path)

unreplicated_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)
print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/Unknown - 2s 10ms/step - loss: 0.0388 - accuracy: 0.9872
Eval loss: 0.03881146577201713, Eval Accuracy: 0.9872000217437744
Load the model with strategy.scope.
with strategy.scope():
  replicated_model = tf.keras.experimental.load_from_saved_model(path)
  replicated_model.compile(loss='sparse_categorical_crossentropy',
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
157/Unknown - 1s 9ms/step - loss: 0.0388 - accuracy: 0.9872
Eval loss: 0.03881146577201713, Eval Accuracy: 0.9872000217437744
That's all.