Keras 2 : examples : 画像分類のための MixUp 増強 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 12/01/2021 (keras 2.7.0)
* 本ページは、Keras の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
- Code examples : Computer Vision : MixUp augmentation for image classification (Author: Sayak Paul)
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- テレワーク & オンライン授業を支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- E-Mail:sales-info@classcat.com ; WebSite: www.classcat.com ; Facebook
Keras 2 : examples : 画像分類のための MixUp 増強
Description: 画像分類のために mixup テクニックを使用したデータ増強。
イントロダクション
mixup は mixup: Beyond Empirical Risk Minimization by Zhang et al. で提案された、ドメイン不可知なデータ増強テクニックです。それは次の式で実装されています :
(lambda 値は [0, 1] 範囲の値で Beta 分布 からサンプリングされていることに注意してください。)
このテクニックは非常にシステマティックに名前付けられています – 文字通り特徴と対応するラベルを良く混ぜあわせています (ミックスアップしています)。実装のやり方は単純です。ニューラルネットワークは 破損したラベルを記憶する 傾向にあります。mixup は異なる特徴を互いに組合せて (ラベルについても同じようにします) これを緩和します、その結果ネットワークは特徴とラベルの間の関係性について過信しません。
mixup は、例えば医用画像データセットのように、与えられたデータセットに対して増強変換のセットの選択について確信がないときに特に有用です。mixup はコンピュータビジョン, 自然言語処理, 音声認識, 等のような様々なデータ様式に拡張できます。
このサンプルは TensorFlow 2.4 またはそれ以上が必要です。
セットアップ
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers
データセットの準備
このサンプルでは、FashionMNIST データセットを使用していきます。しかしこの同じレシピは他の分類データセットに対しても使用できます。
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
y_train = tf.one_hot(y_train, 10)
x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
y_test = tf.one_hot(y_test, 10)
ハイパーパラメータの定義
AUTO = tf.data.AUTOTUNE
BATCH_SIZE = 64
EPOCHS = 10
データを TensorFlow Dataset オブジェクトに変換する
# Put aside a few samples to create our validation set
val_samples = 2000
x_val, y_val = x_train[:val_samples], y_train[:val_samples]
new_x_train, new_y_train = x_train[val_samples:], y_train[val_samples:]
train_ds_one = (
tf.data.Dataset.from_tensor_slices((new_x_train, new_y_train))
.shuffle(BATCH_SIZE * 100)
.batch(BATCH_SIZE)
)
train_ds_two = (
tf.data.Dataset.from_tensor_slices((new_x_train, new_y_train))
.shuffle(BATCH_SIZE * 100)
.batch(BATCH_SIZE)
)
# Because we will be mixing up the images and their corresponding labels, we will be
# combining two shuffled datasets from the same training data.
train_ds = tf.data.Dataset.zip((train_ds_one, train_ds_two))
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(BATCH_SIZE)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)
mixup テクニック関数を定義する
mixup ルーチンを実行するため、同じデータセットからの訓練データを使用して新しい仮想データセットを作成して、Beta 分布からサンプリングされた [0, 1] 範囲内の lambda 値を適用します — その結果、例えば、 new_x = lambda * x1 + (1 – lambda) * x2 (ここで x1 と x2 は画像) でそして同じ式がラベルにも適用されます。
def sample_beta_distribution(size, concentration_0=0.2, concentration_1=0.2):
gamma_1_sample = tf.random.gamma(shape=[size], alpha=concentration_1)
gamma_2_sample = tf.random.gamma(shape=[size], alpha=concentration_0)
return gamma_1_sample / (gamma_1_sample + gamma_2_sample)
def mix_up(ds_one, ds_two, alpha=0.2):
# Unpack two datasets
images_one, labels_one = ds_one
images_two, labels_two = ds_two
batch_size = tf.shape(images_one)[0]
# Sample lambda and reshape it to do the mixup
l = sample_beta_distribution(batch_size, alpha, alpha)
x_l = tf.reshape(l, (batch_size, 1, 1, 1))
y_l = tf.reshape(l, (batch_size, 1))
# Perform mixup on both images and labels by combining a pair of images/labels
# (one from each dataset) into one image/label
images = images_one * x_l + images_two * (1 - x_l)
labels = labels_one * y_l + labels_two * (1 - y_l)
return (images, labels)
ここでは、単一画像を作成するために 2 つの画像を組合せていることに注意してください。理論的には、望むだけの数を組み合わせることができますが、それは増加した計算コストを伴います。ある場合には、パフォーマンスを改善する役に立たない可能性もあります。
# First create the new dataset using our `mix_up` utility
train_ds_mu = train_ds.map(
lambda ds_one, ds_two: mix_up(ds_one, ds_two, alpha=0.2), num_parallel_calls=AUTO
)
# Let's preview 9 samples from the dataset
sample_images, sample_labels = next(iter(train_ds_mu))
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(zip(sample_images[:9], sample_labels[:9])):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(image.numpy().squeeze())
print(label.numpy().tolist())
plt.axis("off")
[0.01706075668334961, 0.0, 0.0, 0.9829392433166504, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] [0.0, 0.5761554837226868, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.42384451627731323, 0.0] [0.0, 0.0, 0.9999957084655762, 0.0, 4.291534423828125e-06, 0.0, 0.0, 0.0, 0.0, 0.0] [0.0, 0.0, 0.03438800573348999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.96561199426651, 0.0] [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0] [0.0, 0.0, 0.9808260202407837, 0.0, 0.0, 0.0, 0.01917397230863571, 0.0, 0.0, 0.0] [0.0, 0.9999748468399048, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.5153160095214844e-05] [0.0, 0.0, 0.0, 0.0002035107754636556, 0.0, 0.9997965097427368, 0.0, 0.0, 0.0, 0.0] [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2410212755203247, 0.0, 0.0, 0.7589787244796753]
モデル構築
def get_training_model():
model = tf.keras.Sequential(
[
layers.Conv2D(16, (5, 5), activation="relu", input_shape=(28, 28, 1)),
layers.MaxPooling2D(pool_size=(2, 2)),
layers.Conv2D(32, (5, 5), activation="relu"),
layers.MaxPooling2D(pool_size=(2, 2)),
layers.Dropout(0.2),
layers.GlobalAvgPool2D(),
layers.Dense(128, activation="relu"),
layers.Dense(10, activation="softmax"),
]
)
return model
再現性のために、浅いネットワークの初期ランダム重みをシリアライズします。
initial_model = get_training_model()
initial_model.save_weights("initial_weights.h5")
1. モデルを mix up されたデータセットで訓練する
model = get_training_model()
model.load_weights("initial_weights.h5")
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.fit(train_ds_mu, validation_data=val_ds, epochs=EPOCHS)
_, test_acc = model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc * 100))
Epoch 1/10 907/907 [==============================] - 38s 41ms/step - loss: 1.4440 - accuracy: 0.5173 - val_loss: 0.7120 - val_accuracy: 0.7405 Epoch 2/10 907/907 [==============================] - 38s 42ms/step - loss: 0.9869 - accuracy: 0.7074 - val_loss: 0.5996 - val_accuracy: 0.7780 Epoch 3/10 907/907 [==============================] - 38s 42ms/step - loss: 0.9096 - accuracy: 0.7451 - val_loss: 0.5197 - val_accuracy: 0.8285 Epoch 4/10 907/907 [==============================] - 38s 42ms/step - loss: 0.8485 - accuracy: 0.7741 - val_loss: 0.4830 - val_accuracy: 0.8380 Epoch 5/10 907/907 [==============================] - 38s 42ms/step - loss: 0.8032 - accuracy: 0.7916 - val_loss: 0.4543 - val_accuracy: 0.8445 Epoch 6/10 907/907 [==============================] - 38s 42ms/step - loss: 0.7675 - accuracy: 0.8032 - val_loss: 0.4398 - val_accuracy: 0.8470 Epoch 7/10 907/907 [==============================] - 38s 42ms/step - loss: 0.7474 - accuracy: 0.8098 - val_loss: 0.4262 - val_accuracy: 0.8495 Epoch 8/10 907/907 [==============================] - 38s 42ms/step - loss: 0.7337 - accuracy: 0.8145 - val_loss: 0.3950 - val_accuracy: 0.8650 Epoch 9/10 907/907 [==============================] - 38s 42ms/step - loss: 0.7154 - accuracy: 0.8218 - val_loss: 0.3822 - val_accuracy: 0.8725 Epoch 10/10 907/907 [==============================] - 38s 42ms/step - loss: 0.7095 - accuracy: 0.8224 - val_loss: 0.3563 - val_accuracy: 0.8720 157/157 [==============================] - 2s 14ms/step - loss: 0.3821 - accuracy: 0.8726 Test accuracy: 87.26%
(訳者注: 実験結果)
Epoch 1/10 907/907 [==============================] - 13s 5ms/step - loss: 1.1761 - accuracy: 0.6346 - val_loss: 0.6664 - val_accuracy: 0.7590 Epoch 2/10 907/907 [==============================] - 4s 5ms/step - loss: 0.9485 - accuracy: 0.7263 - val_loss: 0.5948 - val_accuracy: 0.7865 Epoch 3/10 907/907 [==============================] - 4s 5ms/step - loss: 0.8681 - accuracy: 0.7632 - val_loss: 0.4931 - val_accuracy: 0.8305 Epoch 4/10 907/907 [==============================] - 4s 5ms/step - loss: 0.8156 - accuracy: 0.7825 - val_loss: 0.4533 - val_accuracy: 0.8500 Epoch 5/10 907/907 [==============================] - 4s 5ms/step - loss: 0.7826 - accuracy: 0.7976 - val_loss: 0.4368 - val_accuracy: 0.8600 Epoch 6/10 907/907 [==============================] - 4s 5ms/step - loss: 0.7609 - accuracy: 0.8062 - val_loss: 0.4035 - val_accuracy: 0.8655 Epoch 7/10 907/907 [==============================] - 4s 5ms/step - loss: 0.7401 - accuracy: 0.8142 - val_loss: 0.3919 - val_accuracy: 0.8660 Epoch 8/10 907/907 [==============================] - 4s 5ms/step - loss: 0.7261 - accuracy: 0.8180 - val_loss: 0.3863 - val_accuracy: 0.8740 Epoch 9/10 907/907 [==============================] - 4s 5ms/step - loss: 0.7036 - accuracy: 0.8253 - val_loss: 0.3852 - val_accuracy: 0.8730 Epoch 10/10 907/907 [==============================] - 4s 5ms/step - loss: 0.7007 - accuracy: 0.8243 - val_loss: 0.3573 - val_accuracy: 0.8760 157/157 [==============================] - 0s 3ms/step - loss: 0.3892 - accuracy: 0.8679 Test accuracy: 86.79% CPU times: user 54.6 s, sys: 6.26 s, total: 1min Wall time: 1min 5s
1. モデルを mix up されたデータセット なし で訓練する
model = get_training_model()
model.load_weights("initial_weights.h5")
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
# Notice that we are NOT using the mixed up dataset here
model.fit(train_ds_one, validation_data=val_ds, epochs=EPOCHS)
_, test_acc = model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc * 100))
Epoch 1/10 907/907 [==============================] - 37s 40ms/step - loss: 1.2037 - accuracy: 0.5553 - val_loss: 0.6732 - val_accuracy: 0.7565 Epoch 2/10 907/907 [==============================] - 37s 40ms/step - loss: 0.6724 - accuracy: 0.7462 - val_loss: 0.5715 - val_accuracy: 0.7940 Epoch 3/10 907/907 [==============================] - 37s 40ms/step - loss: 0.5828 - accuracy: 0.7897 - val_loss: 0.5042 - val_accuracy: 0.8210 Epoch 4/10 907/907 [==============================] - 37s 40ms/step - loss: 0.5203 - accuracy: 0.8115 - val_loss: 0.4587 - val_accuracy: 0.8405 Epoch 5/10 907/907 [==============================] - 36s 40ms/step - loss: 0.4802 - accuracy: 0.8255 - val_loss: 0.4602 - val_accuracy: 0.8340 Epoch 6/10 907/907 [==============================] - 36s 40ms/step - loss: 0.4566 - accuracy: 0.8351 - val_loss: 0.3985 - val_accuracy: 0.8700 Epoch 7/10 907/907 [==============================] - 37s 40ms/step - loss: 0.4273 - accuracy: 0.8457 - val_loss: 0.3764 - val_accuracy: 0.8685 Epoch 8/10 907/907 [==============================] - 36s 40ms/step - loss: 0.4133 - accuracy: 0.8481 - val_loss: 0.3704 - val_accuracy: 0.8735 Epoch 9/10 907/907 [==============================] - 36s 40ms/step - loss: 0.3951 - accuracy: 0.8543 - val_loss: 0.3715 - val_accuracy: 0.8680 Epoch 10/10 907/907 [==============================] - 36s 40ms/step - loss: 0.3850 - accuracy: 0.8586 - val_loss: 0.3458 - val_accuracy: 0.8735 157/157 [==============================] - 2s 13ms/step - loss: 0.3817 - accuracy: 0.8636 Test accuracy: 86.36%
Epoch 1/10 907/907 [==============================] - 37s 40ms/step - loss: 1.2037 - accuracy: 0.5553 - val_loss: 0.6732 - val_accuracy: 0.7565 Epoch 2/10 907/907 [==============================] - 37s 40ms/step - loss: 0.6724 - accuracy: 0.7462 - val_loss: 0.5715 - val_accuracy: 0.7940 Epoch 3/10 907/907 [==============================] - 37s 40ms/step - loss: 0.5828 - accuracy: 0.7897 - val_loss: 0.5042 - val_accuracy: 0.8210 Epoch 4/10 907/907 [==============================] - 37s 40ms/step - loss: 0.5203 - accuracy: 0.8115 - val_loss: 0.4587 - val_accuracy: 0.8405 Epoch 5/10 907/907 [==============================] - 36s 40ms/step - loss: 0.4802 - accuracy: 0.8255 - val_loss: 0.4602 - val_accuracy: 0.8340 Epoch 6/10 907/907 [==============================] - 36s 40ms/step - loss: 0.4566 - accuracy: 0.8351 - val_loss: 0.3985 - val_accuracy: 0.8700 Epoch 7/10 907/907 [==============================] - 37s 40ms/step - loss: 0.4273 - accuracy: 0.8457 - val_loss: 0.3764 - val_accuracy: 0.8685 Epoch 8/10 907/907 [==============================] - 36s 40ms/step - loss: 0.4133 - accuracy: 0.8481 - val_loss: 0.3704 - val_accuracy: 0.8735 Epoch 9/10 907/907 [==============================] - 36s 40ms/step - loss: 0.3951 - accuracy: 0.8543 - val_loss: 0.3715 - val_accuracy: 0.8680 Epoch 10/10 907/907 [==============================] - 36s 40ms/step - loss: 0.3850 - accuracy: 0.8586 - val_loss: 0.3458 - val_accuracy: 0.8735 157/157 [==============================] - 2s 13ms/step - loss: 0.3817 - accuracy: 0.8636 Test accuracy: 86.36%
読者は、異なるドメインからの異なるデータセットで mixup を試して lambda パラメータで実験することを勧めます。原論文 を確認することも強く勧められます – 著者らは、mixup が汎用化をどのように改善できるかを示す幾つかのアブレーション (= ablation) 研究を提示し、そして単一画像を作成するために 2 つの画像より多くを組み合わせる結果も示しています。
ノート
- mixup で、合成サンプルを作成できます – 特に大規模なデータセットを欠いているとき – 高い計算コストをかけることなく。
- ラベル smoothing と mixup は通常は一緒には上手く機能しません、何故ならばラベル smoothing はある因子で既にハードラベルを変更しているからです。
- 教師あり対照学習 (SCL, Supervised Contrastive Learning) を使用しているとき mixup は上手く機能しません、SCL は事前訓練段階で真のラベルを想定しているからです。
- mixup の幾つかの他の利点は (この 論文 で説明されているように) 敵対的サンプルへの堅牢性と安定した GAN 訓練を含みます。
- CutMix と AugMix のような mixup を拡張したデータ増強テクニックも多くあります。
以上