TensorFlow 2.4 : ガイド : Keras :- 貴方自身のコールバックを書く (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 01/18/2021
* 本ページは、TensorFlow org サイトの Guide – Keras の以下のページを翻訳した上で
適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。 |
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
Facebook: https://www.facebook.com/ClassCatJP/ |
ガイド : Keras :- 貴方自身のコールバックを書く
イントロダクション
コールバックは 訓練、評価や推論の間の Keras モデルの動作をカスタマイズするためのパワフルなツールです。サンプルは訓練進捗と結果が TensorBoard でエクスポートされて可視化できる tf.keras.callbacks.TensorBoard や、訓練の間にモデルを定期的にセーブする tf.keras.callbacks.ModelCheckpoint を含みます。
このガイドでは、Keras コールバックが何か、それは何ができるか、そして貴方自身のものをどのように構築できるかを学習します。貴方を始めさせるために単純なコールバック・アプリケーションの幾つかのデモを提供します。
セットアップ
import tensorflow as tf from tensorflow import keras
Keras コールバック概要
総てのコールバックは keras.callbacks.Callback クラスをサブクラス化して、そして訓練、テストと予測の様々なステージで呼び出されるメソッドのセットを override します。コールバックは訓練の間のモデルの内部状態と統計上のビューを得るために有用です。
コールバックのリストを (キーワード引数 callbacks として) 以下のモデル・メソッドに渡すことができます :
コールバック・メソッドの概要
グローバル・メソッド
on_(train|test|predict)_begin(self, logs=None)
fit/evaluate/predict の最初に呼び出されます。
on_(train|test|predict)_end(self, logs=None)
fit/evaluate/predict の最後に呼び出されます。
訓練/テスト/予測のためのバッチレベル・メソッド
on_(train|test|predict)_batch_begin(self, batch, logs=None)
訓練/テスト/予測の間のバッチを処理するすぐ前に呼び出されます。
on_(train|test|predict)_batch_end(self, batch, logs=None)
バッチを訓練/テスト/予測する最後に呼び出されます。このメソッド内では、logs はメトリクス結果を含む辞書です。
エポックレベル・メソッド (訓練 only)
on_epoch_begin(self, epoch, logs=None)
訓練の間にエポックの最初に呼び出されます。
on_epoch_end(self, epoch, logs=None)
訓練の間にエポックの最後に呼び出されます。
基本的なサンプル
具体的なサンプルを見ましょう。始めるために、tensorflow をインポートして単純な Sequential Keras モデルを定義しましょう :
# Define the Keras model to add callbacks to def get_model(): model = keras.Sequential() model.add(keras.layers.Dense(1, input_dim=784)) model.compile( optimizer=keras.optimizers.RMSprop(learning_rate=0.1), loss="mean_squared_error", metrics=["mean_absolute_error"], ) return model
それから訓練とテストのために Keras datasets API から MNIST データをロードします :
# Load example MNIST data and pre-process it (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() x_train = x_train.reshape(-1, 784).astype("float32") / 255.0 x_test = x_test.reshape(-1, 784).astype("float32") / 255.0 # Limit the data to 1000 samples x_train = x_train[:1000] y_train = y_train[:1000] x_test = x_test[:1000] y_test = y_test[:1000]
今は、以下をログ出力する単純なカスタム・コールバックを定義します :
- fit/evaluate/predict の最初と最後
- 各エポックの最初 & 最後
- 各訓練バッチの最初 & 最後
- 各評価 (テスト) バッチの最初 & 最後
- 各推論 (予測) バッチの最初 & 最後
class CustomCallback(keras.callbacks.Callback): def on_train_begin(self, logs=None): keys = list(logs.keys()) print("Starting training; got log keys: {}".format(keys)) def on_train_end(self, logs=None): keys = list(logs.keys()) print("Stop training; got log keys: {}".format(keys)) def on_epoch_begin(self, epoch, logs=None): keys = list(logs.keys()) print("Start epoch {} of training; got log keys: {}".format(epoch, keys)) def on_epoch_end(self, epoch, logs=None): keys = list(logs.keys()) print("End epoch {} of training; got log keys: {}".format(epoch, keys)) def on_test_begin(self, logs=None): keys = list(logs.keys()) print("Start testing; got log keys: {}".format(keys)) def on_test_end(self, logs=None): keys = list(logs.keys()) print("Stop testing; got log keys: {}".format(keys)) def on_predict_begin(self, logs=None): keys = list(logs.keys()) print("Start predicting; got log keys: {}".format(keys)) def on_predict_end(self, logs=None): keys = list(logs.keys()) print("Stop predicting; got log keys: {}".format(keys)) def on_train_batch_begin(self, batch, logs=None): keys = list(logs.keys()) print("...Training: start of batch {}; got log keys: {}".format(batch, keys)) def on_train_batch_end(self, batch, logs=None): keys = list(logs.keys()) print("...Training: end of batch {}; got log keys: {}".format(batch, keys)) def on_test_batch_begin(self, batch, logs=None): keys = list(logs.keys()) print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys)) def on_test_batch_end(self, batch, logs=None): keys = list(logs.keys()) print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys)) def on_predict_batch_begin(self, batch, logs=None): keys = list(logs.keys()) print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys)) def on_predict_batch_end(self, batch, logs=None): keys = list(logs.keys()) print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))
Let’s try it out:
model = get_model() model.fit( x_train, y_train, batch_size=128, epochs=1, verbose=0, validation_split=0.5, callbacks=[CustomCallback()], ) res = model.evaluate( x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()] ) res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])
Starting training; got log keys: [] Start epoch 0 of training; got log keys: [] ...Training: start of batch 0; got log keys: [] ...Training: end of batch 0; got log keys: ['loss', 'mean_absolute_error'] ...Training: start of batch 1; got log keys: [] ...Training: end of batch 1; got log keys: ['loss', 'mean_absolute_error'] ...Training: start of batch 2; got log keys: [] ...Training: end of batch 2; got log keys: ['loss', 'mean_absolute_error'] ...Training: start of batch 3; got log keys: [] ...Training: end of batch 3; got log keys: ['loss', 'mean_absolute_error'] Start testing; got log keys: [] ...Evaluating: start of batch 0; got log keys: [] ...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 1; got log keys: [] ...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 2; got log keys: [] ...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 3; got log keys: [] ...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error'] Stop testing; got log keys: ['loss', 'mean_absolute_error'] End epoch 0 of training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error'] Stop training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error'] Start testing; got log keys: [] ...Evaluating: start of batch 0; got log keys: [] ...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 1; got log keys: [] ...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 2; got log keys: [] ...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 3; got log keys: [] ...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 4; got log keys: [] ...Evaluating: end of batch 4; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 5; got log keys: [] ...Evaluating: end of batch 5; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 6; got log keys: [] ...Evaluating: end of batch 6; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 7; got log keys: [] ...Evaluating: end of batch 7; got log keys: ['loss', 'mean_absolute_error'] Stop testing; got log keys: ['loss', 'mean_absolute_error'] Start predicting; got log keys: [] ...Predicting: start of batch 0; got log keys: [] ...Predicting: end of batch 0; got log keys: ['outputs'] ...Predicting: start of batch 1; got log keys: [] ...Predicting: end of batch 1; got log keys: ['outputs'] ...Predicting: start of batch 2; got log keys: [] ...Predicting: end of batch 2; got log keys: ['outputs'] ...Predicting: start of batch 3; got log keys: [] ...Predicting: end of batch 3; got log keys: ['outputs'] ...Predicting: start of batch 4; got log keys: [] ...Predicting: end of batch 4; got log keys: ['outputs'] ...Predicting: start of batch 5; got log keys: [] ...Predicting: end of batch 5; got log keys: ['outputs'] ...Predicting: start of batch 6; got log keys: [] ...Predicting: end of batch 6; got log keys: ['outputs'] ...Predicting: start of batch 7; got log keys: [] ...Predicting: end of batch 7; got log keys: ['outputs'] Stop predicting; got log keys: []
logs 辞書の使用方法
logs 辞書はバッチかエポックの最後の損失値と総てのメトリックを含みます。サンプルは損失と平均絶対誤差 (MAE) を含みます。
class LossAndErrorPrintingCallback(keras.callbacks.Callback): def on_train_batch_end(self, batch, logs=None): print("For batch {}, loss is {:7.2f}.".format(batch, logs["loss"])) def on_test_batch_end(self, batch, logs=None): print("For batch {}, loss is {:7.2f}.".format(batch, logs["loss"])) def on_epoch_end(self, epoch, logs=None): print( "The average loss for epoch {} is {:7.2f} " "and mean absolute error is {:7.2f}.".format( epoch, logs["loss"], logs["mean_absolute_error"] ) ) model = get_model() model.fit( x_train, y_train, batch_size=128, epochs=2, verbose=0, callbacks=[LossAndErrorPrintingCallback()], ) res = model.evaluate( x_test, y_test, batch_size=128, verbose=0, callbacks=[LossAndErrorPrintingCallback()], )
For batch 0, loss is 28.35. For batch 1, loss is 495.31. For batch 2, loss is 338.58. For batch 3, loss is 256.09. For batch 4, loss is 206.61. For batch 5, loss is 173.36. For batch 6, loss is 149.41. For batch 7, loss is 134.53. The average loss for epoch 0 is 134.53 and mean absolute error is 6.17. For batch 0, loss is 5.50. For batch 1, loss is 5.29. For batch 2, loss is 5.42. For batch 3, loss is 5.37. For batch 4, loss is 5.01. For batch 5, loss is 4.77. For batch 6, loss is 4.67. For batch 7, loss is 4.74. The average loss for epoch 1 is 4.74 and mean absolute error is 1.78. For batch 0, loss is 5.57. For batch 1, loss is 5.10. For batch 2, loss is 5.00. For batch 3, loss is 4.99. For batch 4, loss is 5.14. For batch 5, loss is 5.18. For batch 6, loss is 5.13. For batch 7, loss is 5.07.
self.model の使用方法
これらのメソッドの一つが呼び出されるとき、ログ情報を受け取ることに加えて、コールバックは訓練/評価/推論の現在のラウンドに関連するモデルへのアクセスを持ちます : self.model です。
コールバック内で self.model で行なうことができることの幾つかがあります :
- 訓練を直ちに中断するために self.model.stop_training = True を設定します。
- self.model.optimizer.learning_rate のような、(self.model.optimizer として利用可能な) optimizer のハイパーパラメータを変化させます。
- 定期的な間隔でモデルをセーブします。
- 訓練の間の正当性チェックとして使用するため、各エポックの最後で幾つかのテストサンプル上 model.predict() の出力を記録します。
- 時間につれてモデルが何を学習しているか監視するため、各エポックの最後に中間特徴の可視化を抽出します。
- 等々。
これを 2, 3 のサンプルで実際に見ましょう。
Keras コールバック・アプリケーションのサンプル
最小損失における Early Stopping
最初の例はコールバックの作成を示します、これは損失の最小に達したとき属性 self.model.stop_training (boolean) を設定することにより訓練を停止します。オプションで、ローカル最小に達した後で停止する前に幾つのエポックを待つべきかを指定するための引数 patience を貴方は提供できます。
tf.keras.callbacks.EarlyStopping はより完全で一般的な実装を提供します。
import numpy as np class EarlyStoppingAtMinLoss(keras.callbacks.Callback): """Stop training when the loss is at its min, i.e. the loss stops decreasing. Arguments: patience: Number of epochs to wait after min has been hit. After this number of no improvement, training stops. """ def __init__(self, patience=0): super(EarlyStoppingAtMinLoss, self).__init__() self.patience = patience # best_weights to store the weights at which the minimum loss occurs. self.best_weights = None def on_train_begin(self, logs=None): # The number of epoch it has waited when loss is no longer minimum. self.wait = 0 # The epoch the training stops at. self.stopped_epoch = 0 # Initialize the best as infinity. self.best = np.Inf def on_epoch_end(self, epoch, logs=None): current = logs.get("loss") if np.less(current, self.best): self.best = current self.wait = 0 # Record the best weights if current results is better (less). self.best_weights = self.model.get_weights() else: self.wait += 1 if self.wait >= self.patience: self.stopped_epoch = epoch self.model.stop_training = True print("Restoring model weights from the end of the best epoch.") self.model.set_weights(self.best_weights) def on_train_end(self, logs=None): if self.stopped_epoch > 0: print("Epoch %05d: early stopping" % (self.stopped_epoch + 1)) model = get_model() model.fit( x_train, y_train, batch_size=64, steps_per_epoch=5, epochs=30, verbose=0, callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()], )
For batch 0, loss is 31.46. For batch 1, loss is 431.85. For batch 2, loss is 299.28. For batch 3, loss is 227.94. For batch 4, loss is 184.11. The average loss for epoch 0 is 184.11 and mean absolute error is 8.47. For batch 0, loss is 7.22. For batch 1, loss is 6.92. For batch 2, loss is 6.74. For batch 3, loss is 6.27. For batch 4, loss is 6.40. The average loss for epoch 1 is 6.40 and mean absolute error is 2.04. For batch 0, loss is 5.56. For batch 1, loss is 6.04. For batch 2, loss is 5.59. For batch 3, loss is 5.10. For batch 4, loss is 4.92. The average loss for epoch 2 is 4.92 and mean absolute error is 1.79. For batch 0, loss is 5.98. For batch 1, loss is 4.58. For batch 2, loss is 6.00. For batch 3, loss is 8.78. For batch 4, loss is 13.73. The average loss for epoch 3 is 13.73 and mean absolute error is 2.96. Restoring model weights from the end of the best epoch. Epoch 00004: early stopping <tensorflow.python.keras.callbacks.History at 0x7f77f173fe48>
学習率スケジューリング
この例では、カスタム・コールバックが訓練の過程において optimizer の学習率を動的に変更するためにどのように使用できるかを示しています。
より一般的な実装については callbacks.LearningRateScheduler を見てください。
class CustomLearningRateScheduler(keras.callbacks.Callback): """Learning rate scheduler which sets the learning rate according to schedule. Arguments: schedule: a function that takes an epoch index (integer, indexed from 0) and current learning rate as inputs and returns a new learning rate as output (float). """ def __init__(self, schedule): super(CustomLearningRateScheduler, self).__init__() self.schedule = schedule def on_epoch_begin(self, epoch, logs=None): if not hasattr(self.model.optimizer, "lr"): raise ValueError('Optimizer must have a "lr" attribute.') # Get the current learning rate from model's optimizer. lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate)) # Call schedule function to get the scheduled learning rate. scheduled_lr = self.schedule(epoch, lr) # Set the value back to the optimizer before this epoch starts tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr) print("\nEpoch %05d: Learning rate is %6.4f." % (epoch, scheduled_lr)) LR_SCHEDULE = [ # (epoch to start, learning rate) tuples (3, 0.05), (6, 0.01), (9, 0.005), (12, 0.001), ] def lr_schedule(epoch, lr): """Helper function to retrieve the scheduled learning rate based on epoch.""" if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]: return lr for i in range(len(LR_SCHEDULE)): if epoch == LR_SCHEDULE[i][0]: return LR_SCHEDULE[i][1] return lr model = get_model() model.fit( x_train, y_train, batch_size=64, steps_per_epoch=5, epochs=15, verbose=0, callbacks=[ LossAndErrorPrintingCallback(), CustomLearningRateScheduler(lr_schedule), ], )
Epoch 00000: Learning rate is 0.1000. For batch 0, loss is 32.28. For batch 1, loss is 477.32. For batch 2, loss is 327.87. For batch 3, loss is 248.14. For batch 4, loss is 199.64. The average loss for epoch 0 is 199.64 and mean absolute error is 8.51. Epoch 00001: Learning rate is 0.1000. For batch 0, loss is 7.94. For batch 1, loss is 6.56. For batch 2, loss is 6.11. For batch 3, loss is 5.98. For batch 4, loss is 5.82. The average loss for epoch 1 is 5.82 and mean absolute error is 1.97. Epoch 00002: Learning rate is 0.1000. For batch 0, loss is 5.04. For batch 1, loss is 4.87. For batch 2, loss is 5.04. For batch 3, loss is 4.71. For batch 4, loss is 4.68. The average loss for epoch 2 is 4.68 and mean absolute error is 1.76. Epoch 00003: Learning rate is 0.0500. For batch 0, loss is 5.27. For batch 1, loss is 4.22. For batch 2, loss is 4.10. For batch 3, loss is 4.00. For batch 4, loss is 3.93. The average loss for epoch 3 is 3.93 and mean absolute error is 1.59. Epoch 00004: Learning rate is 0.0500. For batch 0, loss is 3.29. For batch 1, loss is 3.94. For batch 2, loss is 3.96. For batch 3, loss is 4.24. For batch 4, loss is 4.32. The average loss for epoch 4 is 4.32 and mean absolute error is 1.63. Epoch 00005: Learning rate is 0.0500. For batch 0, loss is 4.12. For batch 1, loss is 4.32. For batch 2, loss is 4.80. For batch 3, loss is 5.20. For batch 4, loss is 5.33. The average loss for epoch 5 is 5.33 and mean absolute error is 1.84. Epoch 00006: Learning rate is 0.0100. For batch 0, loss is 5.05. For batch 1, loss is 4.55. For batch 2, loss is 3.81. For batch 3, loss is 3.58. For batch 4, loss is 3.53. The average loss for epoch 6 is 3.53 and mean absolute error is 1.47. Epoch 00007: Learning rate is 0.0100. For batch 0, loss is 3.30. For batch 1, loss is 4.03. For batch 2, loss is 3.57. For batch 3, loss is 3.48. For batch 4, loss is 3.39. The average loss for epoch 7 is 3.39 and mean absolute error is 1.45. Epoch 00008: Learning rate is 0.0100. For batch 0, loss is 2.76. For batch 1, loss is 3.08. For batch 2, loss is 2.95. For batch 3, loss is 3.00. For batch 4, loss is 3.16. The average loss for epoch 8 is 3.16 and mean absolute error is 1.39. Epoch 00009: Learning rate is 0.0050. For batch 0, loss is 3.62. For batch 1, loss is 4.01. For batch 2, loss is 4.27. For batch 3, loss is 3.95. For batch 4, loss is 3.71. The average loss for epoch 9 is 3.71 and mean absolute error is 1.48. Epoch 00010: Learning rate is 0.0050. For batch 0, loss is 3.41. For batch 1, loss is 2.97. For batch 2, loss is 2.78. For batch 3, loss is 2.98. For batch 4, loss is 3.01. The average loss for epoch 10 is 3.01 and mean absolute error is 1.34. Epoch 00011: Learning rate is 0.0050. For batch 0, loss is 3.48. For batch 1, loss is 3.39. For batch 2, loss is 3.24. For batch 3, loss is 3.20. For batch 4, loss is 3.43. The average loss for epoch 11 is 3.43 and mean absolute error is 1.44. Epoch 00012: Learning rate is 0.0010. For batch 0, loss is 3.49. For batch 1, loss is 3.63. For batch 2, loss is 3.23. For batch 3, loss is 3.22. For batch 4, loss is 3.19. The average loss for epoch 12 is 3.19 and mean absolute error is 1.37. Epoch 00013: Learning rate is 0.0010. For batch 0, loss is 2.18. For batch 1, loss is 2.91. For batch 2, loss is 2.87. For batch 3, loss is 3.00. For batch 4, loss is 3.16. The average loss for epoch 13 is 3.16 and mean absolute error is 1.34. Epoch 00014: Learning rate is 0.0010. For batch 0, loss is 2.55. For batch 1, loss is 3.37. For batch 2, loss is 3.27. For batch 3, loss is 3.41. For batch 4, loss is 3.29. The average loss for epoch 14 is 3.29 and mean absolute error is 1.39. <tensorflow.python.keras.callbacks.History at 0x7f77f1683b38>
組込み Keras コールバック
API doc を読むことにより既存の Keras コールバックを確実に調べてください。アプリケーションは CSV へのロギング、モデルのセーブ、TensorBoard 内の可視化、そしてそれ以上を含みます!
以上