Keras 2 : examples : 訓練性能向上のための勾配集中化 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 11/19/2021 (keras 2.7.0)
* 本ページは、Keras の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
- Code examples : Computer Vision : Gradient Centralization for Better Training Performance (Author: Rishit Dagli)
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- テレワーク & オンライン授業を支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
Keras 2 : examples : 訓練性能向上のための勾配集中化
Description: DNN の訓練性能を向上させるための勾配集中化の実装。
イントロダクション
このサンプルは Yong et al. による深層ニューラルネットワークのための新しい最適化テクニック、勾配集中化 を実装してそれを Laurence Moroney の Horses or Humans データセット上で実演します。勾配集中化は訓練プロセスを高速化して DNN の最終的な汎用性能を向上させます。それは勾配ベクトルがゼロ平均を持つように集中化することにより勾配上で直接的に作用します。勾配集中化は更に、訓練プロセスがより効率的に安定的になるように、損失関数の Lipschitzness (リプシッツ性) とその勾配を改良します。
このサンプルは TensorFlow 2.2 またはそれ以上、そして tensorflow_datasets を必要とします、これはこのコマンドでインストールできます :
pip install tensorflow-datasets
このサンプルでは勾配集中化を実装していきますが、私が構築したパッケージ, gradient-centralization-tf でこれを非常に簡単に使用することもできます。
セットアップ
from time import time
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers
from tensorflow.keras.optimizers import RMSprop
データの準備
このサンプルのためには、Horses or Humans データセットを使用します。
num_classes = 2
input_shape = (300, 300, 3)
dataset_name = "horses_or_humans"
batch_size = 128
AUTOTUNE = tf.data.AUTOTUNE
(train_ds, test_ds), metadata = tfds.load(
name=dataset_name,
split=[tfds.Split.TRAIN, tfds.Split.TEST],
with_info=True,
as_supervised=True,
)
print(f"Image shape: {metadata.features['image'].shape}")
print(f"Training images: {metadata.splits['train'].num_examples}")
print(f"Test images: {metadata.splits['test'].num_examples}")
Image shape: (300, 300, 3) Training images: 1027 Test images: 256
データ増強の使用
データを [0, 1] に再スケールしてデータに単純な増強を実行します。
rescale = layers.Rescaling(1.0 / 255)
data_augmentation = tf.keras.Sequential(
[
layers.RandomFlip("horizontal_and_vertical"),
layers.RandomRotation(0.3),
layers.RandomZoom(0.2),
]
)
def prepare(ds, shuffle=False, augment=False):
# Rescale dataset
ds = ds.map(lambda x, y: (rescale(x), y), num_parallel_calls=AUTOTUNE)
if shuffle:
ds = ds.shuffle(1024)
# Batch dataset
ds = ds.batch(batch_size)
# Use data augmentation only on the training set
if augment:
ds = ds.map(
lambda x, y: (data_augmentation(x, training=True), y),
num_parallel_calls=AUTOTUNE,
)
# Use buffered prefecting
return ds.prefetch(buffer_size=AUTOTUNE)
データを再スケールして増強します。
train_ds = prepare(train_ds, shuffle=True, augment=True)
test_ds = prepare(test_ds)
モデルの定義
このセクションでは畳み込みニューラルネットワークを定義します。
model = tf.keras.Sequential(
[
layers.Conv2D(16, (3, 3), activation="relu", input_shape=(300, 300, 3)),
layers.MaxPooling2D(2, 2),
layers.Conv2D(32, (3, 3), activation="relu"),
layers.Dropout(0.5),
layers.MaxPooling2D(2, 2),
layers.Conv2D(64, (3, 3), activation="relu"),
layers.Dropout(0.5),
layers.MaxPooling2D(2, 2),
layers.Conv2D(64, (3, 3), activation="relu"),
layers.MaxPooling2D(2, 2),
layers.Conv2D(64, (3, 3), activation="relu"),
layers.MaxPooling2D(2, 2),
layers.Flatten(),
layers.Dropout(0.5),
layers.Dense(512, activation="relu"),
layers.Dense(1, activation="sigmoid"),
]
)
勾配集中化の実装
今度は RMSProp optimizer クラスをサブクラス化して tf.keras.optimizers.Optimizer.get_gradients() メソッドを変更します、今そこで勾配集中化を実装します。高位ではそのアイデアは、Dense や Convolution 層のための逆伝播を通して勾配を得るとすると、それから重み行列の列ベクトルの平均を計算してから各列ベクトルから平均を取り除くというものです。
様々な応用での この論文 の実験は、一般的な画像分類、極め細かい画像分類、検出とセグメンテーションを含み、そして Person ReID は GC が DNN 学習のパフォーマンスを一貫して改良できることを実演します。
また、単純化のために当座は勾配クリッピング機能を実装していませんが、これは非常に簡単に実装できます。
現時点では RPSProp optimizer のサブクラスだけを作成していますが、同じ方法で任意の他の optimizer のためやカスタム optimizer 上でこれを簡単に再現できるでしょう。勾配集中化でモデルを訓練する後のセクションでこのクラスを使用していきます。
class GCRMSprop(RMSprop):
def get_gradients(self, loss, params):
# We here just provide a modified get_gradients() function since we are
# trying to just compute the centralized gradients.
grads = []
gradients = super().get_gradients()
for grad in gradients:
grad_len = len(grad.shape)
if grad_len > 1:
axis = list(range(grad_len - 1))
grad -= tf.reduce_mean(grad, axis=axis, keep_dims=True)
grads.append(grad)
return grads
optimizer = GCRMSprop(learning_rate=1e-4)
訓練ユティリティ
合計訓練時間と各エポックのためにかかる時間を簡単に測定することを可能にするコールバックも作成します、何故ならば上記で構築したモデル上で勾配集中化の効果を比較することに関心があるからです。
class TimeHistory(tf.keras.callbacks.Callback):
def on_train_begin(self, logs={}):
self.times = []
def on_epoch_begin(self, batch, logs={}):
self.epoch_time_start = time()
def on_epoch_end(self, batch, logs={}):
self.times.append(time() - self.epoch_time_start)
GC なしでモデルを訓練する
次に勾配集中化なしで先に構築したモデルを訓練します、これを勾配集中化で訓練されたモデルの訓練性能と比較できます。
time_callback_no_gc = TimeHistory()
model.compile(
loss="binary_crossentropy",
optimizer=RMSprop(learning_rate=1e-4),
metrics=["accuracy"],
)
model.summary()
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d (Conv2D) (None, 298, 298, 16) 448 _________________________________________________________________ max_pooling2d (MaxPooling2D) (None, 149, 149, 16) 0 _________________________________________________________________ conv2d_1 (Conv2D) (None, 147, 147, 32) 4640 _________________________________________________________________ dropout (Dropout) (None, 147, 147, 32) 0 _________________________________________________________________ max_pooling2d_1 (MaxPooling2 (None, 73, 73, 32) 0 _________________________________________________________________ conv2d_2 (Conv2D) (None, 71, 71, 64) 18496 _________________________________________________________________ dropout_1 (Dropout) (None, 71, 71, 64) 0 _________________________________________________________________ max_pooling2d_2 (MaxPooling2 (None, 35, 35, 64) 0 _________________________________________________________________ conv2d_3 (Conv2D) (None, 33, 33, 64) 36928 _________________________________________________________________ max_pooling2d_3 (MaxPooling2 (None, 16, 16, 64) 0 _________________________________________________________________ conv2d_4 (Conv2D) (None, 14, 14, 64) 36928 _________________________________________________________________ max_pooling2d_4 (MaxPooling2 (None, 7, 7, 64) 0 _________________________________________________________________ flatten (Flatten) (None, 3136) 0 _________________________________________________________________ dropout_2 (Dropout) (None, 3136) 0 _________________________________________________________________ dense (Dense) (None, 512) 1606144 _________________________________________________________________ dense_1 (Dense) (None, 1) 513 ================================================================= Total params: 1,704,097 Trainable params: 1,704,097 Non-trainable params: 0 _________________________________________________________________
後で勾配集中化で訓練されたモデルと訓練されていないモデルを比較することを望むために、履歴も保存します。
history_no_gc = model.fit(
train_ds, epochs=10, verbose=1, callbacks=[time_callback_no_gc]
)
Epoch 1/10 9/9 [==============================] - 5s 571ms/step - loss: 0.7427 - accuracy: 0.5073 Epoch 2/10 9/9 [==============================] - 6s 667ms/step - loss: 0.6757 - accuracy: 0.5433 Epoch 3/10 9/9 [==============================] - 6s 660ms/step - loss: 0.6616 - accuracy: 0.6144 Epoch 4/10 9/9 [==============================] - 6s 642ms/step - loss: 0.6598 - accuracy: 0.6203 Epoch 5/10 9/9 [==============================] - 6s 666ms/step - loss: 0.6782 - accuracy: 0.6329 Epoch 6/10 9/9 [==============================] - 6s 655ms/step - loss: 0.6550 - accuracy: 0.6524 Epoch 7/10 9/9 [==============================] - 6s 645ms/step - loss: 0.6157 - accuracy: 0.7186 Epoch 8/10 9/9 [==============================] - 6s 654ms/step - loss: 0.6095 - accuracy: 0.6913 Epoch 9/10 9/9 [==============================] - 6s 677ms/step - loss: 0.5880 - accuracy: 0.7147 Epoch 10/10 9/9 [==============================] - 6s 663ms/step - loss: 0.5814 - accuracy: 0.6933
(訳注: 実験結果)
Epoch 1/10 9/9 [==============================] - 32s 845ms/step - loss: 0.8171 - accuracy: 0.5122 Epoch 2/10 9/9 [==============================] - 16s 1s/step - loss: 0.6904 - accuracy: 0.5355 Epoch 3/10 9/9 [==============================] - 16s 1s/step - loss: 0.6769 - accuracy: 0.5706 Epoch 4/10 9/9 [==============================] - 16s 1s/step - loss: 0.6558 - accuracy: 0.6173 Epoch 5/10 9/9 [==============================] - 16s 1s/step - loss: 0.6951 - accuracy: 0.6203 Epoch 6/10 9/9 [==============================] - 16s 1s/step - loss: 0.6367 - accuracy: 0.6485 Epoch 7/10 9/9 [==============================] - 16s 1s/step - loss: 0.6208 - accuracy: 0.6680 Epoch 8/10 9/9 [==============================] - 16s 1s/step - loss: 0.6172 - accuracy: 0.6504 Epoch 9/10 9/9 [==============================] - 16s 1s/step - loss: 0.5874 - accuracy: 0.6952 Epoch 10/10 9/9 [==============================] - 16s 1s/step - loss: 0.5466 - accuracy: 0.7556 CPU times: user 5min 16s, sys: 5.42 s, total: 5min 21s Wall time: 3min 23s
GC でモデルを訓練する
次に、同じモデルを訓練しますが、今回は勾配集中化を使用しています、私達の optimizer が今回は勾配集中化を使用しているものであることに注意してください。
time_callback_gc = TimeHistory()
model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=["accuracy"])
model.summary()
history_gc = model.fit(train_ds, epochs=10, verbose=1, callbacks=[time_callback_gc])
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d (Conv2D) (None, 298, 298, 16) 448 _________________________________________________________________ max_pooling2d (MaxPooling2D) (None, 149, 149, 16) 0 _________________________________________________________________ conv2d_1 (Conv2D) (None, 147, 147, 32) 4640 _________________________________________________________________ dropout (Dropout) (None, 147, 147, 32) 0 _________________________________________________________________ max_pooling2d_1 (MaxPooling2 (None, 73, 73, 32) 0 _________________________________________________________________ conv2d_2 (Conv2D) (None, 71, 71, 64) 18496 _________________________________________________________________ dropout_1 (Dropout) (None, 71, 71, 64) 0 _________________________________________________________________ max_pooling2d_2 (MaxPooling2 (None, 35, 35, 64) 0 _________________________________________________________________ conv2d_3 (Conv2D) (None, 33, 33, 64) 36928 _________________________________________________________________ max_pooling2d_3 (MaxPooling2 (None, 16, 16, 64) 0 _________________________________________________________________ conv2d_4 (Conv2D) (None, 14, 14, 64) 36928 _________________________________________________________________ max_pooling2d_4 (MaxPooling2 (None, 7, 7, 64) 0 _________________________________________________________________ flatten (Flatten) (None, 3136) 0 _________________________________________________________________ dropout_2 (Dropout) (None, 3136) 0 _________________________________________________________________ dense (Dense) (None, 512) 1606144 _________________________________________________________________ dense_1 (Dense) (None, 1) 513 ================================================================= Total params: 1,704,097 Trainable params: 1,704,097 Non-trainable params: 0 _________________________________________________________________ Epoch 1/10 9/9 [==============================] - 6s 673ms/step - loss: 0.6022 - accuracy: 0.7147 Epoch 2/10 9/9 [==============================] - 6s 662ms/step - loss: 0.5385 - accuracy: 0.7371 Epoch 3/10 9/9 [==============================] - 6s 673ms/step - loss: 0.4832 - accuracy: 0.7945 Epoch 4/10 9/9 [==============================] - 6s 645ms/step - loss: 0.4692 - accuracy: 0.7799 Epoch 5/10 9/9 [==============================] - 6s 720ms/step - loss: 0.4792 - accuracy: 0.7799 Epoch 6/10 9/9 [==============================] - 6s 658ms/step - loss: 0.4623 - accuracy: 0.7838 Epoch 7/10 9/9 [==============================] - 6s 651ms/step - loss: 0.4413 - accuracy: 0.8072 Epoch 8/10 9/9 [==============================] - 6s 682ms/step - loss: 0.4542 - accuracy: 0.8014 Epoch 9/10 9/9 [==============================] - 6s 649ms/step - loss: 0.4235 - accuracy: 0.8053 Epoch 10/10 9/9 [==============================] - 6s 686ms/step - loss: 0.4445 - accuracy: 0.7936
_________________________________________________________________ Epoch 1/10 9/9 [==============================] - 17s 1s/step - loss: 0.6150 - accuracy: 0.7186 Epoch 2/10 9/9 [==============================] - 16s 1s/step - loss: 0.5451 - accuracy: 0.7790 Epoch 3/10 9/9 [==============================] - 16s 1s/step - loss: 0.4838 - accuracy: 0.7848 Epoch 4/10 9/9 [==============================] - 16s 1s/step - loss: 0.4931 - accuracy: 0.7877 Epoch 5/10 9/9 [==============================] - 16s 1s/step - loss: 0.4724 - accuracy: 0.7868 Epoch 6/10 9/9 [==============================] - 16s 1s/step - loss: 0.4608 - accuracy: 0.7907 Epoch 7/10 9/9 [==============================] - 16s 1s/step - loss: 0.4318 - accuracy: 0.8072 Epoch 8/10 9/9 [==============================] - 16s 1s/step - loss: 0.4191 - accuracy: 0.8218 Epoch 9/10 9/9 [==============================] - 16s 1s/step - loss: 0.4374 - accuracy: 0.8140 Epoch 10/10 9/9 [==============================] - 16s 1s/step - loss: 0.4185 - accuracy: 0.8286 CPU times: user 5min 13s, sys: 3.65 s, total: 5min 17s Wall time: 3min 17s
パフォーマンスの比較
print("Not using Gradient Centralization")
print(f"Loss: {history_no_gc.history['loss'][-1]}")
print(f"Accuracy: {history_no_gc.history['accuracy'][-1]}")
print(f"Training Time: {sum(time_callback_no_gc.times)}")
print("Using Gradient Centralization")
print(f"Loss: {history_gc.history['loss'][-1]}")
print(f"Accuracy: {history_gc.history['accuracy'][-1]}")
print(f"Training Time: {sum(time_callback_gc.times)}")
Not using Gradient Centralization Loss: 0.5814347863197327 Accuracy: 0.6932814121246338 Training Time: 136.35903406143188 Using Gradient Centralization Loss: 0.4444807469844818 Accuracy: 0.7935734987258911 Training Time: 131.61780261993408
Not using Gradient Centralization Loss: 0.5465568900108337 Accuracy: 0.7555988430976868 Training Time: 177.4813995361328 Using Gradient Centralization Loss: 0.4184599816799164 Accuracy: 0.8286270499229431 Training Time: 162.773184299469
読者は異なるドメインからの異なるデータセットで勾配集中化を試してその効果で実験することを勧めます。元の論文 を確認することも強く勧めます – 著者は、勾配集中化がそれがどのように一般的な性能、汎化性能、訓練時間を改良できてそしてより効率的であるかを示す幾つかの研究を提示しています。
Many thanks to Ali Mustufa Shaikh for reviewing this implementation.
以上