ホーム » TensorFlow 2.0 » TensorFlow 2.0 Alpha : 上級 Tutorials : 分散訓練 :- TensorFlow の分散訓練

TensorFlow 2.0 Alpha : 上級 Tutorials : 分散訓練 :- TensorFlow の分散訓練

TensorFlow 2.0 Alpha : 上級 Tutorials : 分散訓練 :- TensorFlow の分散訓練 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 04/12/2019

* 本ページは、TensorFlow の本家サイトの TF 2.0 Alpha – Advanced Tutorials – Distributed training の以下のページを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

分散訓練 :- TensorFlow の分散訓練

概要

tf.distribute.Strategy API は複数の処理ユニットに渡り貴方の訓練を分散するための抽象を提供します。目標はユーザに既存のモデルと訓練コードを (最小限の変更で) 使用して分散訓練を可能にすることです。

このチュートリアルは tf.distribute.MirroredStrategy を使用します、これは一つのマシン上の多くの GPU 上で同期訓練を伴う in-graph リプリケーションを行ないます。本質的には、それはモデルの変数の総てを各プロセッサにコピーします。それから、それは総てのプロセッサからの勾配を結合するために all-reduce を使用して結合された値をモデルの総てのコピーに適用します。

MirroredStategy は TensorFlow コアで利用可能な幾つかの分散ストラテジーの一つです。より多くのストラテジーについて 分散ストラテジー・ガイド で読むことができます。

 

Keras API

このサンプルはモデルと訓練ループを構築するために tf.kera API を使用します。カスタム訓練ループについては、このチュートリアル を見てください。

 

Import 依存性

from __future__ import absolute_import, division, print_function, unicode_literals
# Import TensorFlow
!pip install -q tensorflow==2.0.0-alpha0 
import tensorflow_datasets as tfds
import tensorflow as tf

import os

 

データセットをダウンロードする

MNIST データセットをダウンロードしてそれを TensorFlow Datasets からロードします。これは tf.data フォーマットの dataset を返します。

with_info を True に設定するとデータセット全体に対するメタデータを含みます、これはここでは ds_info にセーブされます。他のものの中で、このメタデータは訓練とテストサンプルの数を含みます。

datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]

Dl Completed...:   0%|          | 0/1 [00:00<?, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]

Dl Completed...:   0%|          | 0/2 [00:00<?, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]

Dl Completed...:   0%|          | 0/3 [00:00<?, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]

Dl Completed...:   0%|          | 0/4 [00:00<?, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]

Dl Completed...:   0%|          | 0/4 [00:00<?, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]
Downloading / extracting dataset mnist (11.06 MiB) to /root/tensorflow_datasets/mnist/1.0.0...

Dl Completed...:  25%|██▌       | 1/4 [00:00<00:00,  5.02 url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]

Dl Completed...:  25%|██▌       | 1/4 [00:00<00:00,  5.02 url/s]
Dl Size...:   0%|          | 0/1 [00:00<?, ? MiB/s]

Dl Completed...:  25%|██▌       | 1/4 [00:00<00:00,  5.02 url/s]
Dl Size...:   0%|          | 0/10 [00:00<?, ? MiB/s]

Dl Completed...:  25%|██▌       | 1/4 [00:00<00:00,  5.02 url/s]
Dl Size...:   0%|          | 0/10 [00:00<?, ? MiB/s]

Extraction completed...:   0%|          | 0/1 [00:00<?, ? file/s]

Dl Completed...:  25%|██▌       | 1/4 [00:00<00:00,  5.02 url/s]
Dl Size...:   0%|          | 0/10 [00:00<?, ? MiB/s]

Dl Completed...:  25%|██▌       | 1/4 [00:00<00:00,  5.02 url/s]
Dl Size...:   0%|          | 0/10 [00:00<?, ? MiB/s]

Dl Completed...:  50%|█████     | 2/4 [00:00<00:00,  5.29 url/s]
Dl Size...:   0%|          | 0/10 [00:00<?, ? MiB/s]

Dl Completed...:  50%|█████     | 2/4 [00:00<00:00,  5.29 url/s]
Dl Size...:   0%|          | 0/10 [00:00<?, ? MiB/s]

Extraction completed...:  50%|█████     | 1/2 [00:00<00:00,  4.26 file/s]

Dl Completed...:  50%|█████     | 2/4 [00:00<00:00,  5.29 url/s]
Dl Size...:   0%|          | 0/10 [00:00<?, ? MiB/s]

Extraction completed...: 100%|██████████| 2/2 [00:00<00:00,  4.76 file/s]
Dl Completed...:  50%|█████     | 2/4 [00:00<00:00,  5.29 url/s]
Dl Size...:  10%|█         | 1/10 [00:00<00:05,  1.59 MiB/s]

Dl Completed...:  50%|█████     | 2/4 [00:00<00:00,  5.29 url/s]
Dl Size...:  20%|██        | 2/10 [00:00<00:05,  1.59 MiB/s]

Dl Completed...:  75%|███████▌  | 3/4 [00:00<00:00,  4.29 url/s]
Dl Size...:  20%|██        | 2/10 [00:00<00:05,  1.59 MiB/s]

Dl Completed...:  75%|███████▌  | 3/4 [00:00<00:00,  4.29 url/s]
Dl Size...:  20%|██        | 2/10 [00:00<00:05,  1.59 MiB/s]

Extraction completed...:  67%|██████▋   | 2/3 [00:00<00:00,  4.76 file/s]
Dl Completed...:  75%|███████▌  | 3/4 [00:00<00:00,  4.29 url/s]
Dl Size...:  30%|███       | 3/10 [00:00<00:03,  2.16 MiB/s]

Extraction completed...:  67%|██████▋   | 2/3 [00:00<00:00,  4.76 file/s]

Dl Completed...:  75%|███████▌  | 3/4 [00:00<00:00,  4.29 url/s]
Dl Size...:  30%|███       | 3/10 [00:00<00:03,  2.16 MiB/s]

Extraction completed...: 100%|██████████| 3/3 [00:00<00:00,  3.61 file/s]
Dl Completed...:  75%|███████▌  | 3/4 [00:00<00:00,  4.29 url/s]
Dl Size...:  40%|████      | 4/10 [00:00<00:02,  2.75 MiB/s]

Dl Completed...:  75%|███████▌  | 3/4 [00:00<00:00,  4.29 url/s]
Dl Size...:  50%|█████     | 5/10 [00:00<00:01,  2.75 MiB/s]

Extraction completed...: 100%|██████████| 3/3 [00:00<00:00,  3.61 file/s]
Dl Completed...:  75%|███████▌  | 3/4 [00:01<00:00,  4.29 url/s]
Dl Size...:  60%|██████    | 6/10 [00:01<00:01,  3.50 MiB/s]

Dl Completed...:  75%|███████▌  | 3/4 [00:01<00:00,  4.29 url/s]
Dl Size...:  70%|███████   | 7/10 [00:01<00:00,  3.50 MiB/s]

Extraction completed...: 100%|██████████| 3/3 [00:01<00:00,  3.61 file/s]
Dl Completed...:  75%|███████▌  | 3/4 [00:01<00:00,  4.29 url/s]
Dl Size...:  80%|████████  | 8/10 [00:01<00:00,  4.47 MiB/s]

Extraction completed...: 100%|██████████| 3/3 [00:01<00:00,  3.61 file/s]
Dl Completed...:  75%|███████▌  | 3/4 [00:01<00:00,  4.29 url/s]
Dl Size...:  90%|█████████ | 9/10 [00:01<00:00,  5.31 MiB/s]

Dl Completed...:  75%|███████▌  | 3/4 [00:01<00:00,  4.29 url/s]
Dl Size...: 100%|██████████| 10/10 [00:01<00:00,  5.31 MiB/s]

Dl Completed...: 100%|██████████| 4/4 [00:01<00:00,  2.51 url/s]
Dl Size...: 100%|██████████| 10/10 [00:01<00:00,  5.31 MiB/s]

Dl Completed...: 100%|██████████| 4/4 [00:01<00:00,  2.51 url/s]
Dl Size...: 100%|██████████| 10/10 [00:01<00:00,  5.31 MiB/s]

Extraction completed...:  75%|███████▌  | 3/4 [00:01<00:00,  3.61 file/s]

Dl Completed...: 100%|██████████| 4/4 [00:01<00:00,  2.51 url/s]
Dl Size...: 100%|██████████| 10/10 [00:01<00:00,  5.31 MiB/s]

Extraction completed...: 100%|██████████| 4/4 [00:01<00:00,  1.94 file/s]
Dl Size...: 100%|██████████| 10/10 [00:01<00:00,  5.24 MiB/s]
1 examples [00:00,  6.16 examples/s]




60000 examples [00:13, 4349.69 examples/s]
Shuffling...:   0%|          | 0/10 [00:00<?, ? shard/s]WARNING: Logging before flag parsing goes to stderr.
W0405 15:23:16.461484 140384515561216 deprecation.py:323] From /usr/local/lib/python3.5/dist-packages/tensorflow_datasets/core/file_format_adapter.py:249: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`

Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 260273.29 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Writing...: 100%|██████████| 6000/6000 [00:00<00:00, 156708.54 examples/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 271998.27 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Shuffling...:  20%|██        | 2/10 [00:00<00:00, 13.66 shard/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 294326.80 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Writing...: 100%|██████████| 6000/6000 [00:00<00:00, 144187.84 examples/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 252869.48 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Shuffling...:  40%|████      | 4/10 [00:00<00:00, 13.64 shard/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 280211.82 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Writing...: 100%|██████████| 6000/6000 [00:00<00:00, 150449.41 examples/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 284910.10 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Shuffling...:  60%|██████    | 6/10 [00:00<00:00, 13.86 shard/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 238819.31 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Writing...: 100%|██████████| 6000/6000 [00:00<00:00, 137965.23 examples/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 244280.96 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Shuffling...:  80%|████████  | 8/10 [00:00<00:00, 13.58 shard/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 292693.93 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Writing...: 100%|██████████| 6000/6000 [00:00<00:00, 148655.99 examples/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 232564.68 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Shuffling...: 100%|██████████| 10/10 [00:00<00:00, 13.61 shard/s]
10000 examples [00:02, 4444.35 examples/s]
Shuffling...:   0%|          | 0/1 [00:00<?, ? shard/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 10000 examples [00:00, 313138.63 examples/s]
Writing...:   0%|          | 0/10000 [00:00<?, ? examples/s]
Shuffling...: 100%|██████████| 1/1 [00:00<00:00,  9.25 shard/s]

 

分散ストラテジーを定義する

MirroredStrategy オブジェクトを作成します。これは分散を処理し、内側でモデルを構築するためのコンテキストマネージャ (tf.distribute.MirroredStrategy.scope) を提供します。

strategy = tf.distribute.MirroredStrategy()
W0405 15:23:20.099184 140384515561216 cross_device_ops.py:1111] Not all devices in `tf.distribute.Strategy` are visible to TensorFlow.
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

 

入力パイプラインをセットアップする

モデルがマルチ GPU 上で訓練されるのであれば、特別な計算パワーを効果的に利用するためにバッチサイズはそれに従って増やされるべきです。更に、学習率もそれに従って調整されるべきです。

# You can also do ds_info.splits.total_num_examples to get the total 
# number of examples in the dataset.

num_train_examples = ds_info.splits['train'].num_examples
num_test_examples = ds_info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

0-255 のピクセル値は 0-1 範囲に正規化されなければなりません。このスケールを関数で定義します。

def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255
  
  return image, label

この関数を訓練とテストデータに適用し、訓練データをシャッフルし、そして 訓練のためにそれをバッチ化します

train_dataset = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

 

モデルを作成する

strategy.scope のコンテキストで Keras モデルを作成してコンパイルします。

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')
  ])
  
  model.compile(loss='sparse_categorical_crossentropy',
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])

 

コールバックを定義する

ここで使用されるコールバックは :

  • Tensorboard: このコールバックはグラフを可視化することを可能にする TensorBoard のためのログを書きます。
  • モデル・チェックポイント: このコールバックは総てのエポック後にモデルをセーブします。
  • 学習率スケジューラ: このコールバックを使用すると、総てのエポック/バッチ後に変更する学習率をスケジューリングできます。

説明目的で、このノートブックでは学習率を表示するための print コールバックを追加します。

# Define the checkpoint directory to store the checkpoints

checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print ('\nLearning rate for epoch {} is {}'.format(epoch + 1, 
                                                       model.optimizer.lr.numpy()))
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix, 
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]

 

訓練そして評価する

さて、通常の方法でモデルを訓練します、モデル上で fit を呼び出してチュートリアルの最初に作成されたデータセットを渡します。このステップは貴方が訓練を分散していてもそうでなくても同じです。

model.fit(train_dataset, epochs=10, callbacks=callbacks)
W0405 15:23:21.675539 140384515561216 distributed_training_utils.py:182] Your input callback is not one of the predefined Callbacks that supports DistributionStrategy. You might encounter an error if you access one of the model's attributes as part of the callback since these attributes are not set. You can access each of the individual distributed models using the `_grouped_model` attribute of your original model.
W0405 15:23:21.676841 140384515561216 distributed_training_utils.py:182] Your input callback is not one of the predefined Callbacks that supports DistributionStrategy. You might encounter an error if you access one of the model's attributes as part of the callback since these attributes are not set. You can access each of the individual distributed models using the `_grouped_model` attribute of your original model.
W0405 15:23:21.677886 140384515561216 distributed_training_utils.py:182] Your input callback is not one of the predefined Callbacks that supports DistributionStrategy. You might encounter an error if you access one of the model's attributes as part of the callback since these attributes are not set. You can access each of the individual distributed models using the `_grouped_model` attribute of your original model.
W0405 15:23:21.678794 140384515561216 distributed_training_utils.py:182] Your input callback is not one of the predefined Callbacks that supports DistributionStrategy. You might encounter an error if you access one of the model's attributes as part of the callback since these attributes are not set. You can access each of the individual distributed models using the `_grouped_model` attribute of your original model.

Epoch 1/10
    938/Unknown - 9s 9ms/step - loss: 0.1977 - accuracy: 0.9434
Learning rate for epoch 1 is 0.0010000000474974513
938/938 [==============================] - 9s 9ms/step - loss: 0.1977 - accuracy: 0.9434
Epoch 2/10
930/938 [============================>.] - ETA: 0s - loss: 0.0680 - accuracy: 0.9791
Learning rate for epoch 2 is 0.0010000000474974513
938/938 [==============================] - 7s 7ms/step - loss: 0.0678 - accuracy: 0.9791
Epoch 3/10
933/938 [============================>.] - ETA: 0s - loss: 0.0463 - accuracy: 0.9862
Learning rate for epoch 3 is 0.0010000000474974513
938/938 [==============================] - 7s 8ms/step - loss: 0.0464 - accuracy: 0.9861
Epoch 4/10
935/938 [============================>.] - ETA: 0s - loss: 0.0256 - accuracy: 0.9927
Learning rate for epoch 4 is 9.999999747378752e-05
938/938 [==============================] - 7s 8ms/step - loss: 0.0255 - accuracy: 0.9927
Epoch 5/10
934/938 [============================>.] - ETA: 0s - loss: 0.0221 - accuracy: 0.9941
Learning rate for epoch 5 is 9.999999747378752e-05
938/938 [==============================] - 7s 8ms/step - loss: 0.0220 - accuracy: 0.9941
Epoch 6/10
936/938 [============================>.] - ETA: 0s - loss: 0.0202 - accuracy: 0.9947
Learning rate for epoch 6 is 9.999999747378752e-05
938/938 [==============================] - 7s 7ms/step - loss: 0.0201 - accuracy: 0.9947
Epoch 7/10
932/938 [============================>.] - ETA: 0s - loss: 0.0187 - accuracy: 0.9952
Learning rate for epoch 7 is 9.999999747378752e-05
938/938 [==============================] - 7s 7ms/step - loss: 0.0186 - accuracy: 0.9952
Epoch 8/10
935/938 [============================>.] - ETA: 0s - loss: 0.0161 - accuracy: 0.9963
Learning rate for epoch 8 is 9.999999747378752e-06
938/938 [==============================] - 7s 8ms/step - loss: 0.0161 - accuracy: 0.9963
Epoch 9/10
932/938 [============================>.] - ETA: 0s - loss: 0.0158 - accuracy: 0.9964
Learning rate for epoch 9 is 9.999999747378752e-06
938/938 [==============================] - 7s 8ms/step - loss: 0.0158 - accuracy: 0.9964
Epoch 10/10
934/938 [============================>.] - ETA: 0s - loss: 0.0156 - accuracy: 0.9965
Learning rate for epoch 10 is 9.999999747378752e-06
938/938 [==============================] - 7s 7ms/step - loss: 0.0156 - accuracy: 0.9965


下で見れるように、チェックポイントはセーブされています。

# check the checkpoint directory
!ls {checkpoint_dir}
checkpoint           ckpt_5.data-00000-of-00001
ckpt_1.data-00000-of-00001   ckpt_5.index
ckpt_1.index             ckpt_6.data-00000-of-00001
ckpt_10.data-00000-of-00001  ckpt_6.index
ckpt_10.index            ckpt_7.data-00000-of-00001
ckpt_2.data-00000-of-00001   ckpt_7.index
ckpt_2.index             ckpt_8.data-00000-of-00001
ckpt_3.data-00000-of-00001   ckpt_8.index
ckpt_3.index             ckpt_9.data-00000-of-00001
ckpt_4.data-00000-of-00001   ckpt_9.index
ckpt_4.index

モデルがどのように遂行するかを見るために、最新のチェックポイントをロードしてテストデータ上で evaluate を呼び出します。

適切なデータセットを使用して前のように evaluate を呼び出します。

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)
print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
    157/Unknown - 2s 10ms/step - loss: 0.0388 - accuracy: 0.9872Eval loss: 0.03881146577201713, Eval Accuracy: 0.9872000217437744

出力を見るために、TensorBoard ログをダウンロードして端末で見ることができます。

$ tensorboard --logdir=path/to/log-directory
!ls -sh ./logs
total 12K
4.0K plugins  4.0K train  4.0K validation

 

SavedModel にエクスポートする

グラフと変数をエクスポートすることを望む場合、SavedModel はこれを行なうために最善の方法です。モデルはスコープとともに、あるいはスコープなしでロードし戻すことができます。更に、SavedModel はプラットフォーム不可知論者 (= agnostic) です。

path = 'saved_model/'
tf.keras.experimental.export_saved_model(model, path)
W0405 15:25:10.681121 140384515561216 deprecation.py:323] From /usr/local/lib/python3.5/dist-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:253: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
W0405 15:25:10.683329 140384515561216 tf_logging.py:161] Export includes no default signature!
W0405 15:25:11.132856 140384515561216 tf_logging.py:161] Export includes no default signature!

strategy.scope なしでモデルをロードします。

unreplicated_model = tf.keras.experimental.load_from_saved_model(path)

unreplicated_model.compile(
    loss='sparse_categorical_crossentropy', 
    optimizer=tf.keras.optimizers.Adam(), 
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)
print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
    157/Unknown - 2s 10ms/step - loss: 0.0388 - accuracy: 0.9872Eval loss: 0.03881146577201713, Eval Accuracy: 0.9872000217437744

strategy.scope とともにロードします。

with strategy.scope():
  replicated_model = tf.keras.experimental.load_from_saved_model(path)
  replicated_model.compile(loss='sparse_categorical_crossentropy',
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
    157/Unknown - 1s 9ms/step - loss: 0.0388 - accuracy: 0.9872Eval loss: 0.03881146577201713, Eval Accuracy: 0.9872000217437744
 

以上






AI導入支援 #2 ウェビナー

スモールスタートを可能としたAI導入支援   Vol.2
[無料 WEB セミナー] [詳細]
「画像認識 AI PoC スターターパック」の紹介
既に AI 技術を実ビジネスで活用し、成果を上げている日本企業も多く存在しており、競争優位なビジネスを展開しております。
しかしながら AI を導入したくとも PoC (概念実証) だけでも高額な費用がかかり取組めていない企業も少なくないようです。A I導入時には欠かせない PoC を手軽にしかも短期間で認知度を確認可能とするサービの紹介と共に、AI 技術の特性と具体的な導入プロセスに加え運用時のポイントについても解説いたします。
日時:2021年10月13日(水)
会場:WEBセミナー
共催:クラスキャット、日本FLOW(株)
後援:働き方改革推進コンソーシアム
参加費: 無料 (事前登録制)
人工知能開発支援
◆ クラスキャットは 人工知能研究開発支援 サービスを提供しています :
  • テクニカルコンサルティングサービス
  • 実証実験 (プロトタイプ構築)
  • アプリケーションへの実装
  • 人工知能研修サービス
◆ お問合せ先 ◆
(株)クラスキャット
セールス・インフォメーション
E-Mail:sales-info@classcat.com