Skip to content

ClasCat® AI Research

クラスキャット – 生成 AI, AI エージェント, MCP

Menu
  • ホーム
    • ClassCat® AI Research ホーム
    • クラスキャット・ホーム
  • OpenAI API
    • OpenAI Python ライブラリ 1.x : 概要
    • OpenAI ブログ
      • GPT の紹介
      • GPT ストアの紹介
      • ChatGPT Team の紹介
    • OpenAI platform 1.x
      • Get Started : イントロダクション
      • Get Started : クイックスタート (Python)
      • Get Started : クイックスタート (Node.js)
      • Get Started : モデル
      • 機能 : 埋め込み
      • 機能 : 埋め込み (ユースケース)
      • ChatGPT : アクション – イントロダクション
      • ChatGPT : アクション – Getting started
      • ChatGPT : アクション – アクション認証
    • OpenAI ヘルプ : ChatGPT
      • ChatGPTとは何ですか?
      • ChatGPT は真実を語っていますか?
      • GPT の作成
      • GPT FAQ
      • GPT vs アシスタント
      • GPT ビルダー
    • OpenAI ヘルプ : ChatGPT > メモリ
      • FAQ
    • OpenAI ヘルプ : GPT ストア
      • 貴方の GPT をフィーチャーする
    • OpenAI Python ライブラリ 0.27 : 概要
    • OpenAI platform
      • Get Started : イントロダクション
      • Get Started : クイックスタート
      • Get Started : モデル
      • ガイド : GPT モデル
      • ガイド : 画像生成 (DALL·E)
      • ガイド : GPT-3.5 Turbo 対応 微調整
      • ガイド : 微調整 1.イントロダクション
      • ガイド : 微調整 2. データセットの準備 / ケーススタディ
      • ガイド : 埋め込み
      • ガイド : 音声テキスト変換
      • ガイド : モデレーション
      • ChatGPT プラグイン : イントロダクション
    • OpenAI Cookbook
      • 概要
      • API 使用方法 : レート制限の操作
      • API 使用方法 : tiktoken でトークンを数える方法
      • GPT : ChatGPT モデルへの入力をフォーマットする方法
      • GPT : 補完をストリームする方法
      • GPT : 大規模言語モデルを扱う方法
      • 埋め込み : 埋め込みの取得
      • GPT-3 の微調整 : 分類サンプルの微調整
      • DALL-E : DALL·E で 画像を生成して編集する方法
      • DALL·E と Segment Anything で動的マスクを作成する方法
      • Whisper プロンプティング・ガイド
  • Gemini API
    • Tutorials : クイックスタート with Python (1) テキスト-to-テキスト生成
    • (2) マルチモーダル入力 / 日本語チャット
    • (3) 埋め込みの使用
    • (4) 高度なユースケース
    • クイックスタート with Node.js
    • クイックスタート with Dart or Flutter (1) 日本語動作確認
    • Gemma
      • 概要 (README)
      • Tutorials : サンプリング
      • Tutorials : KerasNLP による Getting Started
  • Keras 3
    • 新しいマルチバックエンド Keras
    • Keras 3 について
    • Getting Started : エンジニアのための Keras 入門
    • Google Colab 上のインストールと Stable Diffusion デモ
    • コンピュータビジョン – ゼロからの画像分類
    • コンピュータビジョン – 単純な MNIST convnet
    • コンピュータビジョン – EfficientNet を使用した微調整による画像分類
    • コンピュータビジョン – Vision Transformer による画像分類
    • コンピュータビジョン – 最新の MLPモデルによる画像分類
    • コンピュータビジョン – コンパクトな畳込み Transformer
    • Keras Core
      • Keras Core 0.1
        • 新しいマルチバックエンド Keras (README)
        • Keras for TensorFlow, JAX, & PyTorch
        • 開発者ガイド : Getting started with Keras Core
        • 開発者ガイド : 関数型 API
        • 開発者ガイド : シーケンシャル・モデル
        • 開発者ガイド : サブクラス化で新しい層とモデルを作成する
        • 開発者ガイド : 独自のコールバックを書く
      • Keras Core 0.1.1 & 0.1.2 : リリースノート
      • 開発者ガイド
      • Code examples
      • Keras Stable Diffusion
        • 概要
        • 基本的な使い方 (テキスト-to-画像 / 画像-to-画像変換)
        • 混合精度のパフォーマンス
        • インペインティングの簡易アプリケーション
        • (参考) KerasCV – Stable Diffusion を使用した高性能画像生成
  • TensorFlow
    • TF 2 : 初級チュートリアル
    • TF 2 : 上級チュートリアル
    • TF 2 : ガイド
    • TF 1 : チュートリアル
    • TF 1 : ガイド
  • その他
    • 🦜️🔗 LangChain ドキュメント / ユースケース
    • Stable Diffusion WebUI
      • Google Colab で Stable Diffusion WebUI 入門
      • HuggingFace モデル / VAE の導入
      • LoRA の利用
    • Diffusion Models / 拡散モデル
  • クラスキャット
    • 会社案内
    • お問合せ
    • Facebook
    • ClassCat® Blog
Menu

TensorFlow 2.0 Alpha : 上級 Tutorials : 分散訓練 :- TensorFlow の分散訓練

Posted on 04/12/2019 by Sales Information

TensorFlow 2.0 Alpha : 上級 Tutorials : 分散訓練 :- TensorFlow の分散訓練 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 04/12/2019

* 本ページは、TensorFlow の本家サイトの TF 2.0 Alpha – Advanced Tutorials – Distributed training の以下のページを翻訳した上で適宜、補足説明したものです:

  • Distributed training in TensorFlow

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

分散訓練 :- TensorFlow の分散訓練

概要

tf.distribute.Strategy API は複数の処理ユニットに渡り貴方の訓練を分散するための抽象を提供します。目標はユーザに既存のモデルと訓練コードを (最小限の変更で) 使用して分散訓練を可能にすることです。

このチュートリアルは tf.distribute.MirroredStrategy を使用します、これは一つのマシン上の多くの GPU 上で同期訓練を伴う in-graph リプリケーションを行ないます。本質的には、それはモデルの変数の総てを各プロセッサにコピーします。それから、それは総てのプロセッサからの勾配を結合するために all-reduce を使用して結合された値をモデルの総てのコピーに適用します。

MirroredStategy は TensorFlow コアで利用可能な幾つかの分散ストラテジーの一つです。より多くのストラテジーについて 分散ストラテジー・ガイド で読むことができます。

 

Keras API

このサンプルはモデルと訓練ループを構築するために tf.kera API を使用します。カスタム訓練ループについては、このチュートリアル を見てください。

 

Import 依存性

from __future__ import absolute_import, division, print_function, unicode_literals
# Import TensorFlow
!pip install -q tensorflow==2.0.0-alpha0 
import tensorflow_datasets as tfds
import tensorflow as tf

import os

 

データセットをダウンロードする

MNIST データセットをダウンロードしてそれを TensorFlow Datasets からロードします。これは tf.data フォーマットの dataset を返します。

with_info を True に設定するとデータセット全体に対するメタデータを含みます、これはここでは ds_info にセーブされます。他のものの中で、このメタデータは訓練とテストサンプルの数を含みます。

datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]

Dl Completed...:   0%|          | 0/1 [00:00<?, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]

Dl Completed...:   0%|          | 0/2 [00:00<?, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]

Dl Completed...:   0%|          | 0/3 [00:00<?, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]

Dl Completed...:   0%|          | 0/4 [00:00<?, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]

Dl Completed...:   0%|          | 0/4 [00:00<?, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]
Downloading / extracting dataset mnist (11.06 MiB) to /root/tensorflow_datasets/mnist/1.0.0...

Dl Completed...:  25%|██▌       | 1/4 [00:00<00:00,  5.02 url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]

Dl Completed...:  25%|██▌       | 1/4 [00:00<00:00,  5.02 url/s]
Dl Size...:   0%|          | 0/1 [00:00<?, ? MiB/s]

Dl Completed...:  25%|██▌       | 1/4 [00:00<00:00,  5.02 url/s]
Dl Size...:   0%|          | 0/10 [00:00<?, ? MiB/s]

Dl Completed...:  25%|██▌       | 1/4 [00:00<00:00,  5.02 url/s]
Dl Size...:   0%|          | 0/10 [00:00<?, ? MiB/s]

Extraction completed...:   0%|          | 0/1 [00:00<?, ? file/s]

Dl Completed...:  25%|██▌       | 1/4 [00:00<00:00,  5.02 url/s]
Dl Size...:   0%|          | 0/10 [00:00<?, ? MiB/s]

Dl Completed...:  25%|██▌       | 1/4 [00:00<00:00,  5.02 url/s]
Dl Size...:   0%|          | 0/10 [00:00<?, ? MiB/s]

Dl Completed...:  50%|█████     | 2/4 [00:00<00:00,  5.29 url/s]
Dl Size...:   0%|          | 0/10 [00:00<?, ? MiB/s]

Dl Completed...:  50%|█████     | 2/4 [00:00<00:00,  5.29 url/s]
Dl Size...:   0%|          | 0/10 [00:00<?, ? MiB/s]

Extraction completed...:  50%|█████     | 1/2 [00:00<00:00,  4.26 file/s]

Dl Completed...:  50%|█████     | 2/4 [00:00<00:00,  5.29 url/s]
Dl Size...:   0%|          | 0/10 [00:00<?, ? MiB/s]

Extraction completed...: 100%|██████████| 2/2 [00:00<00:00,  4.76 file/s]
Dl Completed...:  50%|█████     | 2/4 [00:00<00:00,  5.29 url/s]
Dl Size...:  10%|█         | 1/10 [00:00<00:05,  1.59 MiB/s]

Dl Completed...:  50%|█████     | 2/4 [00:00<00:00,  5.29 url/s]
Dl Size...:  20%|██        | 2/10 [00:00<00:05,  1.59 MiB/s]

Dl Completed...:  75%|███████▌  | 3/4 [00:00<00:00,  4.29 url/s]
Dl Size...:  20%|██        | 2/10 [00:00<00:05,  1.59 MiB/s]

Dl Completed...:  75%|███████▌  | 3/4 [00:00<00:00,  4.29 url/s]
Dl Size...:  20%|██        | 2/10 [00:00<00:05,  1.59 MiB/s]

Extraction completed...:  67%|██████▋   | 2/3 [00:00<00:00,  4.76 file/s]
Dl Completed...:  75%|███████▌  | 3/4 [00:00<00:00,  4.29 url/s]
Dl Size...:  30%|███       | 3/10 [00:00<00:03,  2.16 MiB/s]

Extraction completed...:  67%|██████▋   | 2/3 [00:00<00:00,  4.76 file/s]

Dl Completed...:  75%|███████▌  | 3/4 [00:00<00:00,  4.29 url/s]
Dl Size...:  30%|███       | 3/10 [00:00<00:03,  2.16 MiB/s]

Extraction completed...: 100%|██████████| 3/3 [00:00<00:00,  3.61 file/s]
Dl Completed...:  75%|███████▌  | 3/4 [00:00<00:00,  4.29 url/s]
Dl Size...:  40%|████      | 4/10 [00:00<00:02,  2.75 MiB/s]

Dl Completed...:  75%|███████▌  | 3/4 [00:00<00:00,  4.29 url/s]
Dl Size...:  50%|█████     | 5/10 [00:00<00:01,  2.75 MiB/s]

Extraction completed...: 100%|██████████| 3/3 [00:00<00:00,  3.61 file/s]
Dl Completed...:  75%|███████▌  | 3/4 [00:01<00:00,  4.29 url/s]
Dl Size...:  60%|██████    | 6/10 [00:01<00:01,  3.50 MiB/s]

Dl Completed...:  75%|███████▌  | 3/4 [00:01<00:00,  4.29 url/s]
Dl Size...:  70%|███████   | 7/10 [00:01<00:00,  3.50 MiB/s]

Extraction completed...: 100%|██████████| 3/3 [00:01<00:00,  3.61 file/s]
Dl Completed...:  75%|███████▌  | 3/4 [00:01<00:00,  4.29 url/s]
Dl Size...:  80%|████████  | 8/10 [00:01<00:00,  4.47 MiB/s]

Extraction completed...: 100%|██████████| 3/3 [00:01<00:00,  3.61 file/s]
Dl Completed...:  75%|███████▌  | 3/4 [00:01<00:00,  4.29 url/s]
Dl Size...:  90%|█████████ | 9/10 [00:01<00:00,  5.31 MiB/s]

Dl Completed...:  75%|███████▌  | 3/4 [00:01<00:00,  4.29 url/s]
Dl Size...: 100%|██████████| 10/10 [00:01<00:00,  5.31 MiB/s]

Dl Completed...: 100%|██████████| 4/4 [00:01<00:00,  2.51 url/s]
Dl Size...: 100%|██████████| 10/10 [00:01<00:00,  5.31 MiB/s]

Dl Completed...: 100%|██████████| 4/4 [00:01<00:00,  2.51 url/s]
Dl Size...: 100%|██████████| 10/10 [00:01<00:00,  5.31 MiB/s]

Extraction completed...:  75%|███████▌  | 3/4 [00:01<00:00,  3.61 file/s]

Dl Completed...: 100%|██████████| 4/4 [00:01<00:00,  2.51 url/s]
Dl Size...: 100%|██████████| 10/10 [00:01<00:00,  5.31 MiB/s]

Extraction completed...: 100%|██████████| 4/4 [00:01<00:00,  1.94 file/s]
Dl Size...: 100%|██████████| 10/10 [00:01<00:00,  5.24 MiB/s]
1 examples [00:00,  6.16 examples/s]




60000 examples [00:13, 4349.69 examples/s]
Shuffling...:   0%|          | 0/10 [00:00<?, ? shard/s]WARNING: Logging before flag parsing goes to stderr.
W0405 15:23:16.461484 140384515561216 deprecation.py:323] From /usr/local/lib/python3.5/dist-packages/tensorflow_datasets/core/file_format_adapter.py:249: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`

Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 260273.29 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Writing...: 100%|██████████| 6000/6000 [00:00<00:00, 156708.54 examples/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 271998.27 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Shuffling...:  20%|██        | 2/10 [00:00<00:00, 13.66 shard/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 294326.80 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Writing...: 100%|██████████| 6000/6000 [00:00<00:00, 144187.84 examples/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 252869.48 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Shuffling...:  40%|████      | 4/10 [00:00<00:00, 13.64 shard/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 280211.82 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Writing...: 100%|██████████| 6000/6000 [00:00<00:00, 150449.41 examples/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 284910.10 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Shuffling...:  60%|██████    | 6/10 [00:00<00:00, 13.86 shard/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 238819.31 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Writing...: 100%|██████████| 6000/6000 [00:00<00:00, 137965.23 examples/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 244280.96 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Shuffling...:  80%|████████  | 8/10 [00:00<00:00, 13.58 shard/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 292693.93 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Writing...: 100%|██████████| 6000/6000 [00:00<00:00, 148655.99 examples/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 6000 examples [00:00, 232564.68 examples/s]
Writing...:   0%|          | 0/6000 [00:00<?, ? examples/s]
Shuffling...: 100%|██████████| 10/10 [00:00<00:00, 13.61 shard/s]
10000 examples [00:02, 4444.35 examples/s]
Shuffling...:   0%|          | 0/1 [00:00<?, ? shard/s]
Reading...: 0 examples [00:00, ? examples/s]
Reading...: 10000 examples [00:00, 313138.63 examples/s]
Writing...:   0%|          | 0/10000 [00:00<?, ? examples/s]
Shuffling...: 100%|██████████| 1/1 [00:00<00:00,  9.25 shard/s]

 

分散ストラテジーを定義する

MirroredStrategy オブジェクトを作成します。これは分散を処理し、内側でモデルを構築するためのコンテキストマネージャ (tf.distribute.MirroredStrategy.scope) を提供します。

strategy = tf.distribute.MirroredStrategy()
W0405 15:23:20.099184 140384515561216 cross_device_ops.py:1111] Not all devices in `tf.distribute.Strategy` are visible to TensorFlow.
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

 

入力パイプラインをセットアップする

モデルがマルチ GPU 上で訓練されるのであれば、特別な計算パワーを効果的に利用するためにバッチサイズはそれに従って増やされるべきです。更に、学習率もそれに従って調整されるべきです。

# You can also do ds_info.splits.total_num_examples to get the total 
# number of examples in the dataset.

num_train_examples = ds_info.splits['train'].num_examples
num_test_examples = ds_info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

0-255 のピクセル値は 0-1 範囲に正規化されなければなりません。このスケールを関数で定義します。

def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255
  
  return image, label

この関数を訓練とテストデータに適用し、訓練データをシャッフルし、そして 訓練のためにそれをバッチ化します。

train_dataset = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

 

モデルを作成する

strategy.scope のコンテキストで Keras モデルを作成してコンパイルします。

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')
  ])
  
  model.compile(loss='sparse_categorical_crossentropy',
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])

 

コールバックを定義する

ここで使用されるコールバックは :

  • Tensorboard: このコールバックはグラフを可視化することを可能にする TensorBoard のためのログを書きます。
  • モデル・チェックポイント: このコールバックは総てのエポック後にモデルをセーブします。
  • 学習率スケジューラ: このコールバックを使用すると、総てのエポック/バッチ後に変更する学習率をスケジューリングできます。

説明目的で、このノートブックでは学習率を表示するための print コールバックを追加します。

# Define the checkpoint directory to store the checkpoints

checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print ('\nLearning rate for epoch {} is {}'.format(epoch + 1, 
                                                       model.optimizer.lr.numpy()))
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix, 
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]

 

訓練そして評価する

さて、通常の方法でモデルを訓練します、モデル上で fit を呼び出してチュートリアルの最初に作成されたデータセットを渡します。このステップは貴方が訓練を分散していてもそうでなくても同じです。

model.fit(train_dataset, epochs=10, callbacks=callbacks)
W0405 15:23:21.675539 140384515561216 distributed_training_utils.py:182] Your input callback is not one of the predefined Callbacks that supports DistributionStrategy. You might encounter an error if you access one of the model's attributes as part of the callback since these attributes are not set. You can access each of the individual distributed models using the `_grouped_model` attribute of your original model.
W0405 15:23:21.676841 140384515561216 distributed_training_utils.py:182] Your input callback is not one of the predefined Callbacks that supports DistributionStrategy. You might encounter an error if you access one of the model's attributes as part of the callback since these attributes are not set. You can access each of the individual distributed models using the `_grouped_model` attribute of your original model.
W0405 15:23:21.677886 140384515561216 distributed_training_utils.py:182] Your input callback is not one of the predefined Callbacks that supports DistributionStrategy. You might encounter an error if you access one of the model's attributes as part of the callback since these attributes are not set. You can access each of the individual distributed models using the `_grouped_model` attribute of your original model.
W0405 15:23:21.678794 140384515561216 distributed_training_utils.py:182] Your input callback is not one of the predefined Callbacks that supports DistributionStrategy. You might encounter an error if you access one of the model's attributes as part of the callback since these attributes are not set. You can access each of the individual distributed models using the `_grouped_model` attribute of your original model.

Epoch 1/10
    938/Unknown - 9s 9ms/step - loss: 0.1977 - accuracy: 0.9434
Learning rate for epoch 1 is 0.0010000000474974513
938/938 [==============================] - 9s 9ms/step - loss: 0.1977 - accuracy: 0.9434
Epoch 2/10
930/938 [============================>.] - ETA: 0s - loss: 0.0680 - accuracy: 0.9791
Learning rate for epoch 2 is 0.0010000000474974513
938/938 [==============================] - 7s 7ms/step - loss: 0.0678 - accuracy: 0.9791
Epoch 3/10
933/938 [============================>.] - ETA: 0s - loss: 0.0463 - accuracy: 0.9862
Learning rate for epoch 3 is 0.0010000000474974513
938/938 [==============================] - 7s 8ms/step - loss: 0.0464 - accuracy: 0.9861
Epoch 4/10
935/938 [============================>.] - ETA: 0s - loss: 0.0256 - accuracy: 0.9927
Learning rate for epoch 4 is 9.999999747378752e-05
938/938 [==============================] - 7s 8ms/step - loss: 0.0255 - accuracy: 0.9927
Epoch 5/10
934/938 [============================>.] - ETA: 0s - loss: 0.0221 - accuracy: 0.9941
Learning rate for epoch 5 is 9.999999747378752e-05
938/938 [==============================] - 7s 8ms/step - loss: 0.0220 - accuracy: 0.9941
Epoch 6/10
936/938 [============================>.] - ETA: 0s - loss: 0.0202 - accuracy: 0.9947
Learning rate for epoch 6 is 9.999999747378752e-05
938/938 [==============================] - 7s 7ms/step - loss: 0.0201 - accuracy: 0.9947
Epoch 7/10
932/938 [============================>.] - ETA: 0s - loss: 0.0187 - accuracy: 0.9952
Learning rate for epoch 7 is 9.999999747378752e-05
938/938 [==============================] - 7s 7ms/step - loss: 0.0186 - accuracy: 0.9952
Epoch 8/10
935/938 [============================>.] - ETA: 0s - loss: 0.0161 - accuracy: 0.9963
Learning rate for epoch 8 is 9.999999747378752e-06
938/938 [==============================] - 7s 8ms/step - loss: 0.0161 - accuracy: 0.9963
Epoch 9/10
932/938 [============================>.] - ETA: 0s - loss: 0.0158 - accuracy: 0.9964
Learning rate for epoch 9 is 9.999999747378752e-06
938/938 [==============================] - 7s 8ms/step - loss: 0.0158 - accuracy: 0.9964
Epoch 10/10
934/938 [============================>.] - ETA: 0s - loss: 0.0156 - accuracy: 0.9965
Learning rate for epoch 10 is 9.999999747378752e-06
938/938 [==============================] - 7s 7ms/step - loss: 0.0156 - accuracy: 0.9965


下で見れるように、チェックポイントはセーブされています。

# check the checkpoint directory
!ls {checkpoint_dir}
checkpoint           ckpt_5.data-00000-of-00001
ckpt_1.data-00000-of-00001   ckpt_5.index
ckpt_1.index             ckpt_6.data-00000-of-00001
ckpt_10.data-00000-of-00001  ckpt_6.index
ckpt_10.index            ckpt_7.data-00000-of-00001
ckpt_2.data-00000-of-00001   ckpt_7.index
ckpt_2.index             ckpt_8.data-00000-of-00001
ckpt_3.data-00000-of-00001   ckpt_8.index
ckpt_3.index             ckpt_9.data-00000-of-00001
ckpt_4.data-00000-of-00001   ckpt_9.index
ckpt_4.index

モデルがどのように遂行するかを見るために、最新のチェックポイントをロードしてテストデータ上で evaluate を呼び出します。

適切なデータセットを使用して前のように evaluate を呼び出します。

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)
print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
    157/Unknown - 2s 10ms/step - loss: 0.0388 - accuracy: 0.9872Eval loss: 0.03881146577201713, Eval Accuracy: 0.9872000217437744

出力を見るために、TensorBoard ログをダウンロードして端末で見ることができます。

$ tensorboard --logdir=path/to/log-directory
!ls -sh ./logs
total 12K
4.0K plugins  4.0K train  4.0K validation

 

SavedModel にエクスポートする

グラフと変数をエクスポートすることを望む場合、SavedModel はこれを行なうために最善の方法です。モデルはスコープとともに、あるいはスコープなしでロードし戻すことができます。更に、SavedModel はプラットフォーム不可知論者 (= agnostic) です。

path = 'saved_model/'
tf.keras.experimental.export_saved_model(model, path)
W0405 15:25:10.681121 140384515561216 deprecation.py:323] From /usr/local/lib/python3.5/dist-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:253: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
W0405 15:25:10.683329 140384515561216 tf_logging.py:161] Export includes no default signature!
W0405 15:25:11.132856 140384515561216 tf_logging.py:161] Export includes no default signature!

strategy.scope なしでモデルをロードします。

unreplicated_model = tf.keras.experimental.load_from_saved_model(path)

unreplicated_model.compile(
    loss='sparse_categorical_crossentropy', 
    optimizer=tf.keras.optimizers.Adam(), 
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)
print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
    157/Unknown - 2s 10ms/step - loss: 0.0388 - accuracy: 0.9872Eval loss: 0.03881146577201713, Eval Accuracy: 0.9872000217437744

strategy.scope とともにロードします。

with strategy.scope():
  replicated_model = tf.keras.experimental.load_from_saved_model(path)
  replicated_model.compile(loss='sparse_categorical_crossentropy',
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
    157/Unknown - 1s 9ms/step - loss: 0.0388 - accuracy: 0.9872Eval loss: 0.03881146577201713, Eval Accuracy: 0.9872000217437744
 

以上






クラスキャット

最近の投稿

  • LangGraph 0.5 : Get started : ローカルサーバの実行
  • LangGraph 0.5 on Colab : Get started : human-in-the-loop 制御の追加
  • LangGraph 0.5 on Colab : Get started : Tavily Web 検索ツールの追加
  • LangGraph 0.5 on Colab : Get started : カスタム・ワークフローの構築
  • LangGraph 0.5 on Colab : Get started : クイックスタート

タグ

AutoGen (13) ClassCat Press Release (20) ClassCat TF/ONNX Hub (11) DGL 0.5 (14) Eager Execution (7) Edward (17) FLUX.1 (16) Gemini (20) HuggingFace Transformers 4.5 (10) HuggingFace Transformers 4.6 (7) HuggingFace Transformers 4.29 (9) Keras 2 Examples (98) Keras 2 Guide (16) Keras 3 (10) Keras Release Note (17) Kubeflow 1.0 (10) LangChain (45) LangGraph (24) MediaPipe 0.8 (11) Model Context Protocol (16) NNI 1.5 (16) OpenAI Agents SDK (8) OpenAI Cookbook (13) OpenAI platform (10) OpenAI platform 1.x (10) OpenAI ヘルプ (8) TensorFlow 2.0 Advanced Tutorials (33) TensorFlow 2.0 Advanced Tutorials (Alpha) (15) TensorFlow 2.0 Advanced Tutorials (Beta) (16) TensorFlow 2.0 Guide (10) TensorFlow 2.0 Guide (Alpha) (16) TensorFlow 2.0 Guide (Beta) (9) TensorFlow 2.0 Release Note (12) TensorFlow 2.0 Tutorials (20) TensorFlow 2.0 Tutorials (Alpha) (14) TensorFlow 2.0 Tutorials (Beta) (12) TensorFlow 2.4 Guide (24) TensorFlow Deploy (8) TensorFlow Get Started (7) TensorFlow Graphics (7) TensorFlow Probability (9) TensorFlow Programmer's Guide (22) TensorFlow Release Note (18) TensorFlow Tutorials (33) TF-Agents 0.4 (11)
2019年4月
月 火 水 木 金 土 日
1234567
891011121314
15161718192021
22232425262728
2930  
« 3月   5月 »
© 2025 ClasCat® AI Research | Powered by Minimalist Blog WordPress Theme