Skip to content

ClasCat® AI Research

クラスキャット – 生成 AI, AI エージェント, MCP

Menu
  • ホーム
    • ClassCat® AI Research ホーム
    • クラスキャット・ホーム
  • OpenAI API
    • OpenAI Python ライブラリ 1.x : 概要
    • OpenAI ブログ
      • GPT の紹介
      • GPT ストアの紹介
      • ChatGPT Team の紹介
    • OpenAI platform 1.x
      • Get Started : イントロダクション
      • Get Started : クイックスタート (Python)
      • Get Started : クイックスタート (Node.js)
      • Get Started : モデル
      • 機能 : 埋め込み
      • 機能 : 埋め込み (ユースケース)
      • ChatGPT : アクション – イントロダクション
      • ChatGPT : アクション – Getting started
      • ChatGPT : アクション – アクション認証
    • OpenAI ヘルプ : ChatGPT
      • ChatGPTとは何ですか?
      • ChatGPT は真実を語っていますか?
      • GPT の作成
      • GPT FAQ
      • GPT vs アシスタント
      • GPT ビルダー
    • OpenAI ヘルプ : ChatGPT > メモリ
      • FAQ
    • OpenAI ヘルプ : GPT ストア
      • 貴方の GPT をフィーチャーする
    • OpenAI Python ライブラリ 0.27 : 概要
    • OpenAI platform
      • Get Started : イントロダクション
      • Get Started : クイックスタート
      • Get Started : モデル
      • ガイド : GPT モデル
      • ガイド : 画像生成 (DALL·E)
      • ガイド : GPT-3.5 Turbo 対応 微調整
      • ガイド : 微調整 1.イントロダクション
      • ガイド : 微調整 2. データセットの準備 / ケーススタディ
      • ガイド : 埋め込み
      • ガイド : 音声テキスト変換
      • ガイド : モデレーション
      • ChatGPT プラグイン : イントロダクション
    • OpenAI Cookbook
      • 概要
      • API 使用方法 : レート制限の操作
      • API 使用方法 : tiktoken でトークンを数える方法
      • GPT : ChatGPT モデルへの入力をフォーマットする方法
      • GPT : 補完をストリームする方法
      • GPT : 大規模言語モデルを扱う方法
      • 埋め込み : 埋め込みの取得
      • GPT-3 の微調整 : 分類サンプルの微調整
      • DALL-E : DALL·E で 画像を生成して編集する方法
      • DALL·E と Segment Anything で動的マスクを作成する方法
      • Whisper プロンプティング・ガイド
  • Gemini API
    • Tutorials : クイックスタート with Python (1) テキスト-to-テキスト生成
    • (2) マルチモーダル入力 / 日本語チャット
    • (3) 埋め込みの使用
    • (4) 高度なユースケース
    • クイックスタート with Node.js
    • クイックスタート with Dart or Flutter (1) 日本語動作確認
    • Gemma
      • 概要 (README)
      • Tutorials : サンプリング
      • Tutorials : KerasNLP による Getting Started
  • Keras 3
    • 新しいマルチバックエンド Keras
    • Keras 3 について
    • Getting Started : エンジニアのための Keras 入門
    • Google Colab 上のインストールと Stable Diffusion デモ
    • コンピュータビジョン – ゼロからの画像分類
    • コンピュータビジョン – 単純な MNIST convnet
    • コンピュータビジョン – EfficientNet を使用した微調整による画像分類
    • コンピュータビジョン – Vision Transformer による画像分類
    • コンピュータビジョン – 最新の MLPモデルによる画像分類
    • コンピュータビジョン – コンパクトな畳込み Transformer
    • Keras Core
      • Keras Core 0.1
        • 新しいマルチバックエンド Keras (README)
        • Keras for TensorFlow, JAX, & PyTorch
        • 開発者ガイド : Getting started with Keras Core
        • 開発者ガイド : 関数型 API
        • 開発者ガイド : シーケンシャル・モデル
        • 開発者ガイド : サブクラス化で新しい層とモデルを作成する
        • 開発者ガイド : 独自のコールバックを書く
      • Keras Core 0.1.1 & 0.1.2 : リリースノート
      • 開発者ガイド
      • Code examples
      • Keras Stable Diffusion
        • 概要
        • 基本的な使い方 (テキスト-to-画像 / 画像-to-画像変換)
        • 混合精度のパフォーマンス
        • インペインティングの簡易アプリケーション
        • (参考) KerasCV – Stable Diffusion を使用した高性能画像生成
  • TensorFlow
    • TF 2 : 初級チュートリアル
    • TF 2 : 上級チュートリアル
    • TF 2 : ガイド
    • TF 1 : チュートリアル
    • TF 1 : ガイド
  • その他
    • 🦜️🔗 LangChain ドキュメント / ユースケース
    • Stable Diffusion WebUI
      • Google Colab で Stable Diffusion WebUI 入門
      • HuggingFace モデル / VAE の導入
      • LoRA の利用
    • Diffusion Models / 拡散モデル
  • クラスキャット
    • 会社案内
    • お問合せ
    • Facebook
    • ClassCat® Blog
Menu

TensorFlow 2.0 Beta : ガイド : Keras : TensorFlow Keras でモデルをセーブしてシリアライズする

Posted on 06/19/2019 by Sales Information

TensorFlow 2.0 Beta : ガイド : Keras : TensorFlow Keras でモデルをセーブしてシリアライズする (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 06/19/2019

* 本ページは、TensorFlow の本家サイトの TF 2.0 Beta の以下のページを翻訳した上で適宜、補足説明したものです:

  • Keras: Saving and Serializing Models with TensorFlow Keras

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

ガイド : Keras : TensorFlow Keras でモデルをセーブしてシリアライズする

このガイドの最初のパートは Sequential モデルと Functional API を使用して構築されたモデルのためのセーブとシリアライゼーションをカバーします。セーブとシリアライゼーション API はモデルのこれらのタイプの両者について正確に同じです。

モデルのカスタム・サブクラスのためのセーブはセクション「サブクラス化されたモデルをセーブする」でカバーされます。この場合の API は Sequential や Functional モデルのためのものとは僅かに異なります。

 

セットアップ

from __future__ import absolute_import, division, print_function, unicode_literals

!pip install -q tensorflow-gpu==2.0.0-beta1
import tensorflow as tf

tf.keras.backend.clear_session()  # For easy reset of notebook state.

 

Part I: Sequential モデルまたは Functional モデルをセーブする

次のモデルを考えましょう :

from tensorflow import keras
from tensorflow.keras import layers

inputs = keras.Input(shape=(784,), name='digits')
x = layers.Dense(64, activation='relu', name='dense_1')(inputs)
x = layers.Dense(64, activation='relu', name='dense_2')(x)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)

model = keras.Model(inputs=inputs, outputs=outputs, name='3_layer_mlp')
model.summary()
Model: "3_layer_mlp"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
digits (InputLayer)          [(None, 784)]             0         
_________________________________________________________________
dense_1 (Dense)              (None, 64)                50240     
_________________________________________________________________
dense_2 (Dense)              (None, 64)                4160      
_________________________________________________________________
predictions (Dense)          (None, 10)                650       
=================================================================
Total params: 55,050
Trainable params: 55,050
Non-trainable params: 0
_________________________________________________________________

オプションとして、このモデルを訓練してみましょう、そしてそれはセーブするための重み値と optimizer 状態を持ちます。もちろん、まだ訓練していないモデルもセーブできますが、明らかにそれは面白くないでしょう。

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

model.compile(loss='sparse_categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop())
history = model.fit(x_train, y_train,
                    batch_size=64,
                    epochs=1)
WARNING: Logging before flag parsing goes to stderr.
W0614 15:22:16.427572 140456761870080 deprecation.py:323] From /home/kbuilder/.local/lib/python3.5/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

Train on 60000 samples
60000/60000 [==============================] - 3s 48us/sample - loss: 0.3014
# Save predictions for future checks
predictions = model.predict(x_test)

 

モデル全体のセーブ

Functional API で構築されたモデルを単一のファイルにセーブできます。このファイルから同じモデルを後で再作成できます、モデルを作成したコードへのアクセスをもはや持たない場合でさえも。

このファイルは以下を含みます :

  • モデルのアーキテクチャ
  • モデルの重み値 (それは訓練の間に学習されました)
  • モデルの訓練 config (それは compile に渡したものです)、もしあれば
  • optimizer とその状態、もしあれば (これは貴方がやめたところで訓練を再開することを可能にします)
# Save the model
model.save('path_to_my_model.h5')

# Recreate the exact same model purely from the file
new_model = keras.models.load_model('path_to_my_model.h5')
import numpy as np

# Check that the state is preserved
new_predictions = new_model.predict(x_test)
np.testing.assert_allclose(predictions, new_predictions, atol=1e-6)

# Note that the optimizer state is preserved as well:
# you can resume training where you left off.

 

SavedModel にエクスポートする

TensorFlow SavedModel 形式にモデル全体をエクスポートすることも可能です。SavedModel は TensorFlow オブジェクトのためのスタンドアロン・シリアライゼーション形式で、TensorFlow serving と Python 以外の TensorFlow 実装によりサポートされます。

# Export the model to a SavedModel
keras.experimental.export_saved_model(model, 'path_to_saved_model')

# Recreate the exact same model
new_model = keras.experimental.load_from_saved_model('path_to_saved_model')

# Check that the state is preserved
new_predictions = new_model.predict(x_test)
np.testing.assert_allclose(predictions, new_predictions, atol=1e-6)

# Note that the optimizer state is preserved as well:
# you can resume training where you left off.
W0614 15:22:21.755359 140456761870080 deprecation.py:323] From /home/kbuilder/.local/lib/python3.5/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:253: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
W0614 15:22:21.757266 140456761870080 export_utils.py:182] Export includes no default signature!
W0614 15:22:22.025224 140456761870080 export_utils.py:182] Export includes no default signature!

作成された SavedModel は以下を含みます :

  • モデル重みを含む TensorFlow チェックポイント。
  • 基礎となる TensorFlow グラフを含む SavedModel proto。予測 (サービング)、訓練と評価のために個別のグラフがセーブされます。モデルが前にコンパイルされていない場合は、推論グラフだけがエクスポートされます。
  • モデルのアーキテクチャ config、もし利用可能であれば。

 

アーキテクチャ-only セービング

時に、貴方はモデルのアーキテクチャだけに興味があり、そして重み値や optimizer をセーブする必要がありません。この場合、get_config() メソッドを通してモデルの “config” を取得できます。config は Python 辞書で同じモデルを再作成することを可能にします — スクラッチから初期化され、訓練の間に以前に学習されたどのような情報も持ちません。

config = model.get_config()
reinitialized_model = keras.Model.from_config(config)

# Note that the model state is not preserved! We only saved the architecture.
new_predictions = reinitialized_model.predict(x_test)
assert abs(np.sum(predictions - new_predictions)) > 0.

代わりに from_json() から to_json() を使用することができます、これは config をストアするために Python 辞書の代わりに JSON 文字列を使用します。これは config をディスクにセーブするために有用です。

json_config = model.to_json()
reinitialized_model = keras.models.model_from_json(json_config)

 

重み-only セービング

時に、貴方はアーキテクチャではなくモデルの状態 — その重み値 — にだけ興味があります。この場合、get_weights() を通して重み値を Numpy 配列のリストとして取得できて、set_weights を通してモデルの状態を設定できます :

weights = model.get_weights()  # Retrieves the state of the model.
model.set_weights(weights)  # Sets the state of the model.

貴方のモデルを同じ状態で再作成するために get_config()/from_config() と get_weights()/set_weights() を組み合わせることができます。けれども、model.save() とは違い、これは訓練 config と optimizer を含みません。モデルを訓練のために使用する前に compile() を再度呼び出さなければならないでしょう。

config = model.get_config()
weights = model.get_weights()

new_model = keras.Model.from_config(config)
new_model.set_weights(weights)

# Check that the state is preserved
new_predictions = new_model.predict(x_test)
np.testing.assert_allclose(predictions, new_predictions, atol=1e-6)

# Note that the optimizer was not preserved,
# so the model should be compiled anew before training
# (and the optimizer will start from a blank state).

get_weights() と set_weights(weights) に対する save-to-disk 選択肢は save_weights(fpath) と load_weights(fpath) です。

ここにディスクにセーブするサンプルがあります :

# Save JSON config to disk
json_config = model.to_json()
with open('model_config.json', 'w') as json_file:
    json_file.write(json_config)
# Save weights to disk
model.save_weights('path_to_my_weights.h5')

# Reload the model from the 2 files we saved
with open('model_config.json') as json_file:
    json_config = json_file.read()
new_model = keras.models.model_from_json(json_config)
new_model.load_weights('path_to_my_weights.h5')

# Check that the state is preserved
new_predictions = new_model.predict(x_test)
np.testing.assert_allclose(predictions, new_predictions, atol=1e-6)

# Note that the optimizer was not preserved.
W0614 15:22:23.880016 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer.decay
W0614 15:22:23.881534 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer.learning_rate
W0614 15:22:23.882155 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer.momentum
W0614 15:22:23.882841 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer.rho
W0614 15:22:23.883504 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-0.kernel
W0614 15:22:23.884812 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-0.bias
W0614 15:22:23.885461 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-1.kernel
W0614 15:22:23.886249 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-1.bias
W0614 15:22:23.887130 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-2.kernel
W0614 15:22:23.888468 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-2.bias
W0614 15:22:23.889342 140456761870080 util.py:252] A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/alpha/guide/checkpoints#loading_mechanics for details.
W0614 15:22:23.893026 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer
W0614 15:22:23.893844 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer.iter
W0614 15:22:23.894460 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer.decay
W0614 15:22:23.895506 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer.learning_rate
W0614 15:22:23.896186 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer.momentum
W0614 15:22:23.897153 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer.rho
W0614 15:22:23.897989 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-0.kernel
W0614 15:22:23.898835 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-0.bias
W0614 15:22:23.899448 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-1.kernel
W0614 15:22:23.900100 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-1.bias
W0614 15:22:23.900975 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-2.kernel
W0614 15:22:23.902399 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-2.bias
W0614 15:22:23.902988 140456761870080 util.py:252] A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/alpha/guide/checkpoints#loading_mechanics for details.
W0614 15:22:23.904124 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer
W0614 15:22:23.904798 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer.iter
W0614 15:22:23.905418 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer.decay
W0614 15:22:23.906876 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer.learning_rate
W0614 15:22:23.907450 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer.momentum
W0614 15:22:23.908407 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer.rho
W0614 15:22:23.909109 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-0.kernel
W0614 15:22:23.909620 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-0.bias
W0614 15:22:23.910654 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-1.kernel
W0614 15:22:23.911258 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-1.bias
W0614 15:22:23.911958 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-2.kernel
W0614 15:22:23.912729 140456761870080 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-2.bias
W0614 15:22:23.914015 140456761870080 util.py:252] A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/alpha/guide/checkpoints#loading_mechanics for details.

しかし覚えておいてください、最も単純な、推奨される方法は単にこれです :

model.save('path_to_my_model.h5')
del model
model = keras.models.load_model('path_to_my_model.h5')

 

SavedModel 形式で重み-only セービング

save_weights は Keras HDF5 形式か、TensorFlow SavedModel 形式でファイルを作成できることに注意してください。このフォーマットは貴方が提供するファイル拡張子から推論されます : それが”.h5″ か “.keras” であれば、フレームワークは Keras HDF5 形式を使用します。他の任意のものは SavedModel をデフォルトとします。

model.save_weights('path_to_my_tf_savedmodel')

総合的な明瞭さのために、フォーマットは save_format 引数を通して明示的に渡すことができます、これは値 “tf” か “h5” を取ることができます :

model.save_weights('path_to_my_tf_savedmodel', save_format='tf')

 

サブクラス化されたモデルをセーブする

Sequential モデルと Functional モデルは層の DAG を表わすデータ構造です。そのようなものとして、それらは安全にシリアライズとデシリアライズされます。

サブクラス化されたモデルはそれがデータ構造ではないという点で異なります、それはコードのピースです。モデルのアーキテクチャは call メソッド本体を通して定義されます。これはモデルのアーキテクチャは安全にシリアライズ化されないことを意味します。モデルをロードするためには、それを作成したコード (モデル・サブクラスのコード) へのアクセスを持つ必要があるでしょう。代わりに、このコードをバイトコードとしてシリアライズすることもできるでしょうが (e.g. pickling を通して)、それは安全ではなく一般に可搬ではありません。

これらの違いについてのより多くの情報は、記事 “What are Symbolic and Imperative APIs in TensorFlow 2.0?” を見てください。

次のサブクラス化されたモデルを考えましょう、これは最初のセクションからのモデルと同じ構造に従います :

class ThreeLayerMLP(keras.Model):

  def __init__(self, name=None):
    super(ThreeLayerMLP, self).__init__(name=name)
    self.dense_1 = layers.Dense(64, activation='relu', name='dense_1')
    self.dense_2 = layers.Dense(64, activation='relu', name='dense_2')
    self.pred_layer = layers.Dense(10, activation='softmax', name='predictions')

  def call(self, inputs):
    x = self.dense_1(inputs)
    x = self.dense_2(x)
    return self.pred_layer(x)

def get_model():
  return ThreeLayerMLP(name='3_layer_mlp')

model = get_model()

最初に、決して使用されていないサブクラス化されたモデルはセーブできません。

それはサブクラス化されたモデルはその重みを作成するために何某かのデータの上で呼び出される必要があるからです。

モデルが作成されるまで、それはそれが期待すべき入力データの shape と dtype を知りません、そしてそれ故にその重み変数を作成できません。最初のセクションから Functional モデルでは、入力の shape と dtype が (keras.Input(…)を通して) 前もって指定されたことを覚えているかもしれません — それが Functional モデルがインスタンス化されてすぐに状態を持つ理由です。

モデルを、それに状態を与えるために、訓練しましょう :

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

model.compile(loss='sparse_categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop())
history = model.fit(x_train, y_train,
                    batch_size=64,
                    epochs=1)
Train on 60000 samples
60000/60000 [==============================] - 3s 46us/sample - loss: 0.3150

サブクラス化されたモデルをセーブする推奨される方法は TensorFlow SavedModel チェックポイントを作成するために save_weights を使用することです、これはモデルに関連する総ての変数の値を含みます :- 層の重み – optimizer の状態 – ステートフル・モデル・メトリクスに関連する任意の変数 (もしあれば)

model.save_weights('path_to_my_weights', save_format='tf')
# Save predictions for future checks
predictions = model.predict(x_test)
# Also save the loss on the first batch
# to later assert that the optimizer state was preserved
first_batch_loss = model.train_on_batch(x_train[:64], y_train[:64])

貴方のモデルをリストアするためには、モデル・オブジェクトを作成したコードへのアクセスが必要です。

optimizer 状態と任意のステートフル・メトリックの状態をリストアするためには、モデルを (前と正確に同じ引数で) compile して load_weights を呼び出す前にそれをあるデータ上で呼び出すべきです :

# Recreate the model
new_model = get_model()
new_model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=keras.optimizers.RMSprop())

# This initializes the variables used by the optimizers,
# as well as any stateful metric variables
new_model.train_on_batch(x_train[:1], y_train[:1])

# Load the state of the old model
new_model.load_weights('path_to_my_weights')

# Check that the model state has been preserved
new_predictions = new_model.predict(x_test)
np.testing.assert_allclose(predictions, new_predictions, atol=1e-6)

# The optimizer state is preserved as well,
# so you can resume training where you left off
new_first_batch_loss = new_model.train_on_batch(x_train[:64], y_train[:64])
assert first_batch_loss == new_first_batch_loss
 

以上



クラスキャット

最近の投稿

  • LangGraph on Colab : エージェント型 RAG
  • LangGraph : 例題 : エージェント型 RAG
  • LangGraph Platform : Get started : クイックスタート
  • LangGraph Platform : 概要
  • LangGraph : Prebuilt エージェント : ユーザインターフェイス

タグ

AutoGen (13) ClassCat Press Release (20) ClassCat TF/ONNX Hub (11) DGL 0.5 (14) Eager Execution (7) Edward (17) FLUX.1 (16) Gemini (20) HuggingFace Transformers 4.5 (10) HuggingFace Transformers 4.6 (7) HuggingFace Transformers 4.29 (9) Keras 2 Examples (98) Keras 2 Guide (16) Keras 3 (10) Keras Release Note (17) Kubeflow 1.0 (10) LangChain (45) LangGraph (22) MediaPipe 0.8 (11) Model Context Protocol (16) NNI 1.5 (16) OpenAI Agents SDK (8) OpenAI Cookbook (13) OpenAI platform (10) OpenAI platform 1.x (10) OpenAI ヘルプ (8) TensorFlow 2.0 Advanced Tutorials (33) TensorFlow 2.0 Advanced Tutorials (Alpha) (15) TensorFlow 2.0 Advanced Tutorials (Beta) (16) TensorFlow 2.0 Guide (10) TensorFlow 2.0 Guide (Alpha) (16) TensorFlow 2.0 Guide (Beta) (9) TensorFlow 2.0 Release Note (12) TensorFlow 2.0 Tutorials (20) TensorFlow 2.0 Tutorials (Alpha) (14) TensorFlow 2.0 Tutorials (Beta) (12) TensorFlow 2.4 Guide (24) TensorFlow Deploy (8) TensorFlow Get Started (7) TensorFlow Graphics (7) TensorFlow Probability (9) TensorFlow Programmer's Guide (22) TensorFlow Release Note (18) TensorFlow Tutorials (33) TF-Agents 0.4 (11)
2019年6月
月 火 水 木 金 土 日
 12
3456789
10111213141516
17181920212223
24252627282930
« 5月   7月 »
© 2025 ClasCat® AI Research | Powered by Minimalist Blog WordPress Theme