Keras 2 : examples : Supervised Contrastive Learning (translation/commentary)
Translation: ClassCat Sales Information
Created: 12/19/2021 (keras 2.7.0)
* This page is a translation, with supplementary explanations added where appropriate, of the following Keras document:
- Code examples : Computer Vision : Supervised Contrastive Learning (Author: Khalid Salama)
* The sample code has been verified to run, and has been modified where necessary.
Keras 2 : examples : Supervised Contrastive Learning
Description : Using supervised contrastive learning for image classification.
Introduction
Supervised Contrastive Learning (Prannay Khosla et al.) is a training methodology that outperforms supervised training with crossentropy on classification tasks.
Essentially, training an image classification model with supervised contrastive learning is performed in two phases, as sketched right after this list:
- Training an encoder to learn to produce vector representations of input images such that representations of images in the same class are more similar than representations of images in different classes.
- Training a classifier on top of the frozen encoder.
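At a high level, the two phases combine as in the following minimal sketch. Every name used here (create_encoder, add_projection_head, SupervisedContrastiveLoss, and the hyperparameters) is defined later in this example, so treat this as a roadmap rather than code to run at this point:
# Phase 1 (sketch): pretrain the encoder with the supervised contrastive loss.
encoder = create_encoder()
encoder_with_projection_head = add_projection_head(encoder)
encoder_with_projection_head.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=SupervisedContrastiveLoss(temperature),
)
encoder_with_projection_head.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs)

# Phase 2 (sketch): freeze the pretrained encoder and train a classifier on top.
classifier = create_classifier(encoder, trainable=False)
classifier.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs)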
Note that this example requires TensorFlow Addons, which you can install using the following command:
pip install tensorflow-addons
Setup
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
Prepare the data
num_classes = 10
input_shape = (32, 32, 3)
# Load the train and test data splits
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
# Display shapes of train and test datasets
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)
Using image data augmentation
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.02),
        layers.RandomWidth(0.2),
        layers.RandomHeight(0.2),
    ]
)
# Setting the state of the normalization layer.
data_augmentation.layers[0].adapt(x_train)
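The random augmentation layers are only active when called in training mode. As a quick check (a usage sketch we added, assuming the cell above has been run):
# The Random* layers are no-ops at inference time, so pass training=True to
# see the augmentation applied. The spatial dimensions of the output vary
# because RandomWidth/RandomHeight resize each batch by a random factor.
augmented_batch = data_augmentation(x_train[:4], training=True)
print(augmented_batch.shape)  # e.g. (4, 36, 29, 3) -- height/width are random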
Build the encoder model
The encoder model takes the image as input and turns it into a 2048-dimensional feature vector.
def create_encoder():
    resnet = keras.applications.ResNet50V2(
        include_top=False, weights=None, input_shape=input_shape, pooling="avg"
    )

    inputs = keras.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    outputs = resnet(augmented)
    model = keras.Model(inputs=inputs, outputs=outputs, name="cifar10-encoder")
    return model
encoder = create_encoder()
encoder.summary()
learning_rate = 0.001
batch_size = 265
hidden_units = 512
projection_units = 128
num_epochs = 50
dropout_rate = 0.5
temperature = 0.05
Model: "cifar10-encoder" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_2 (InputLayer) [(None, 32, 32, 3)] 0 _________________________________________________________________ sequential (Sequential) (None, None, None, 3) 7 _________________________________________________________________ resnet50v2 (Functional) (None, 2048) 23564800 ================================================================= Total params: 23,564,807 Trainable params: 23,519,360 Non-trainable params: 45,447 _________________________________________________________________
Build the classification model
The classification model adds a fully-connected layer on top of the encoder, plus a softmax layer with the target classes.
def create_classifier(encoder, trainable=True):

    for layer in encoder.layers:
        layer.trainable = trainable

    inputs = keras.Input(shape=input_shape)
    features = encoder(inputs)
    features = layers.Dropout(dropout_rate)(features)
    features = layers.Dense(hidden_units, activation="relu")(features)
    features = layers.Dropout(dropout_rate)(features)
    outputs = layers.Dense(num_classes, activation="softmax")(features)

    model = keras.Model(inputs=inputs, outputs=outputs, name="cifar10-classifier")
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    return model
Experiment 1 : Train the baseline classification model
In this experiment, a baseline classifier is trained as usual, i.e., the encoder and the classifier parts are trained together as a single model to minimize the crossentropy loss.
encoder = create_encoder()
classifier = create_classifier(encoder)
classifier.summary()
history = classifier.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs)
accuracy = classifier.evaluate(x_test, y_test)[1]
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
Model: "cifar10-classifier" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_5 (InputLayer) [(None, 32, 32, 3)] 0 _________________________________________________________________ cifar10-encoder (Functional) (None, 2048) 23564807 _________________________________________________________________ dropout (Dropout) (None, 2048) 0 _________________________________________________________________ dense (Dense) (None, 512) 1049088 _________________________________________________________________ dropout_1 (Dropout) (None, 512) 0 _________________________________________________________________ dense_1 (Dense) (None, 10) 5130 ================================================================= Total params: 24,619,025 Trainable params: 24,573,578 Non-trainable params: 45,447 _________________________________________________________________ Epoch 1/50 189/189 [==============================] - 15s 77ms/step - loss: 1.9369 - sparse_categorical_accuracy: 0.2874 Epoch 2/50 189/189 [==============================] - 11s 57ms/step - loss: 1.5133 - sparse_categorical_accuracy: 0.4505 Epoch 3/50 189/189 [==============================] - 11s 57ms/step - loss: 1.3468 - sparse_categorical_accuracy: 0.5204 Epoch 4/50 189/189 [==============================] - 11s 60ms/step - loss: 1.2159 - sparse_categorical_accuracy: 0.5733 Epoch 5/50 189/189 [==============================] - 11s 56ms/step - loss: 1.1516 - sparse_categorical_accuracy: 0.6032 Epoch 6/50 189/189 [==============================] - 11s 58ms/step - loss: 1.0769 - sparse_categorical_accuracy: 0.6254 Epoch 7/50 189/189 [==============================] - 11s 58ms/step - loss: 0.9964 - sparse_categorical_accuracy: 0.6547 Epoch 8/50 189/189 [==============================] - 10s 55ms/step - loss: 0.9563 - sparse_categorical_accuracy: 0.6703 Epoch 9/50 189/189 [==============================] - 10s 55ms/step - loss: 0.8952 - sparse_categorical_accuracy: 0.6925 Epoch 10/50 189/189 [==============================] - 11s 56ms/step - loss: 0.8986 - sparse_categorical_accuracy: 0.6922 Epoch 11/50 189/189 [==============================] - 10s 55ms/step - loss: 0.8381 - sparse_categorical_accuracy: 0.7145 Epoch 12/50 189/189 [==============================] - 10s 55ms/step - loss: 0.8513 - sparse_categorical_accuracy: 0.7086 Epoch 13/50 189/189 [==============================] - 11s 56ms/step - loss: 0.7557 - sparse_categorical_accuracy: 0.7448 Epoch 14/50 189/189 [==============================] - 11s 56ms/step - loss: 0.7168 - sparse_categorical_accuracy: 0.7548 Epoch 15/50 189/189 [==============================] - 10s 55ms/step - loss: 0.6772 - sparse_categorical_accuracy: 0.7690 Epoch 16/50 189/189 [==============================] - 11s 56ms/step - loss: 0.7587 - sparse_categorical_accuracy: 0.7416 Epoch 17/50 189/189 [==============================] - 10s 55ms/step - loss: 0.6873 - sparse_categorical_accuracy: 0.7665 Epoch 18/50 189/189 [==============================] - 11s 56ms/step - loss: 0.6418 - sparse_categorical_accuracy: 0.7804 Epoch 19/50 189/189 [==============================] - 11s 56ms/step - loss: 0.6086 - sparse_categorical_accuracy: 0.7927 Epoch 20/50 189/189 [==============================] - 10s 55ms/step - loss: 0.5903 - sparse_categorical_accuracy: 0.7978 Epoch 21/50 189/189 [==============================] - 11s 56ms/step - loss: 0.5636 - sparse_categorical_accuracy: 0.8083 Epoch 22/50 189/189 
[==============================] - 11s 56ms/step - loss: 0.5527 - sparse_categorical_accuracy: 0.8123 Epoch 23/50 189/189 [==============================] - 11s 56ms/step - loss: 0.5308 - sparse_categorical_accuracy: 0.8191 Epoch 24/50 189/189 [==============================] - 10s 55ms/step - loss: 0.5282 - sparse_categorical_accuracy: 0.8223 Epoch 25/50 189/189 [==============================] - 10s 55ms/step - loss: 0.5090 - sparse_categorical_accuracy: 0.8263 Epoch 26/50 189/189 [==============================] - 10s 55ms/step - loss: 0.5497 - sparse_categorical_accuracy: 0.8181 Epoch 27/50 189/189 [==============================] - 10s 55ms/step - loss: 0.4950 - sparse_categorical_accuracy: 0.8332 Epoch 28/50 189/189 [==============================] - 11s 56ms/step - loss: 0.4727 - sparse_categorical_accuracy: 0.8391 Epoch 29/50 167/189 [=========================>....] - ETA: 1s - loss: 0.4594 - sparse_categorical_accuracy: 0.8444
(Translator's note: results from our own run)
Epoch 1/50
189/189 [==============================] - 34s 100ms/step - loss: 1.9656 - sparse_categorical_accuracy: 0.2784
Epoch 2/50
189/189 [==============================] - 13s 66ms/step - loss: 1.5228 - sparse_categorical_accuracy: 0.4455
...
Epoch 49/50
189/189 [==============================] - 11s 57ms/step - loss: 0.4010 - sparse_categorical_accuracy: 0.8634
Epoch 50/50
189/189 [==============================] - 11s 57ms/step - loss: 0.2894 - sparse_categorical_accuracy: 0.9007
313/313 [==============================] - 5s 12ms/step - loss: 0.8723 - sparse_categorical_accuracy: 0.7954
Test accuracy: 79.54%
CPU times: user 8min 53s, sys: 13.7 s, total: 9min 7s
Wall time: 10min 36s
Experiment 2 : Use supervised contrastive learning
In this experiment, the model is trained in two phases. In the first phase, the encoder is pretrained to optimize the supervised contrastive loss described in Prannay Khosla et al.
In the second phase, the classifier is trained using the trained encoder with its weights frozen; only the weights of the fully-connected layers with the softmax are optimized.
1. Supervised contrastive learning loss function
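For reference, the supervised contrastive loss of the paper pulls together the L2-normalized embeddings z_i of samples that share a label and pushes apart all others. For an anchor i with positive set P(i) (the other samples in the batch with the same label), the set A(i) of all samples other than i, and temperature \tau, it can be written in LaTeX as:
\mathcal{L}^{\mathrm{sup}} = \sum_{i \in I} \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)}
The implementation below computes the temperature-scaled pairwise cosine similarities of the projected features and feeds them to the N-pairs loss from TensorFlow Addons, a closely related formulation that treats all samples sharing a label as positives.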
class SupervisedContrastiveLoss(keras.losses.Loss):
    def __init__(self, temperature=1, name=None):
        super(SupervisedContrastiveLoss, self).__init__(name=name)
        self.temperature = temperature

    def __call__(self, labels, feature_vectors, sample_weight=None):
        # Normalize feature vectors
        feature_vectors_normalized = tf.math.l2_normalize(feature_vectors, axis=1)
        # Compute logits
        logits = tf.divide(
            tf.matmul(
                feature_vectors_normalized, tf.transpose(feature_vectors_normalized)
            ),
            self.temperature,
        )
        return tfa.losses.npairs_loss(tf.squeeze(labels), logits)
def add_projection_head(encoder):
    inputs = keras.Input(shape=input_shape)
    features = encoder(inputs)
    outputs = layers.Dense(projection_units, activation="relu")(features)
    model = keras.Model(
        inputs=inputs, outputs=outputs, name="cifar-encoder_with_projection-head"
    )
    return model
2. Pretrain the encoder
encoder = create_encoder()
encoder_with_projection_head = add_projection_head(encoder)
encoder_with_projection_head.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=SupervisedContrastiveLoss(temperature),
)
encoder_with_projection_head.summary()
history = encoder_with_projection_head.fit(
    x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs
)
Model: "cifar-encoder_with_projection-head" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_8 (InputLayer) [(None, 32, 32, 3)] 0 _________________________________________________________________ cifar10-encoder (Functional) (None, 2048) 23564807 _________________________________________________________________ dense_2 (Dense) (None, 128) 262272 ================================================================= Total params: 23,827,079 Trainable params: 23,781,632 Non-trainable params: 45,447 _________________________________________________________________ Epoch 1/50 189/189 [==============================] - 11s 56ms/step - loss: 5.3730 Epoch 2/50 189/189 [==============================] - 11s 56ms/step - loss: 5.1583 Epoch 3/50 189/189 [==============================] - 10s 55ms/step - loss: 5.0368 Epoch 4/50 189/189 [==============================] - 11s 56ms/step - loss: 4.9349 Epoch 5/50 189/189 [==============================] - 10s 55ms/step - loss: 4.8262 Epoch 6/50 189/189 [==============================] - 11s 56ms/step - loss: 4.7470 Epoch 7/50 189/189 [==============================] - 11s 56ms/step - loss: 4.6835 Epoch 8/50 189/189 [==============================] - 11s 56ms/step - loss: 4.6120 Epoch 9/50 189/189 [==============================] - 11s 56ms/step - loss: 4.5608 Epoch 10/50 189/189 [==============================] - 10s 55ms/step - loss: 4.5075 Epoch 11/50 189/189 [==============================] - 11s 56ms/step - loss: 4.4674 Epoch 12/50 189/189 [==============================] - 10s 56ms/step - loss: 4.4362 Epoch 13/50 189/189 [==============================] - 11s 56ms/step - loss: 4.3899 Epoch 14/50 189/189 [==============================] - 10s 55ms/step - loss: 4.3664 Epoch 15/50 189/189 [==============================] - 11s 56ms/step - loss: 4.3188 Epoch 16/50 189/189 [==============================] - 10s 56ms/step - loss: 4.3030 Epoch 17/50 189/189 [==============================] - 11s 57ms/step - loss: 4.2725 Epoch 18/50 189/189 [==============================] - 10s 55ms/step - loss: 4.2523 Epoch 19/50 189/189 [==============================] - 11s 56ms/step - loss: 4.2100 Epoch 20/50 189/189 [==============================] - 10s 55ms/step - loss: 4.2033 Epoch 21/50 189/189 [==============================] - 11s 56ms/step - loss: 4.1741 Epoch 22/50 189/189 [==============================] - 11s 56ms/step - loss: 4.1443 Epoch 23/50 189/189 [==============================] - 11s 56ms/step - loss: 4.1350 Epoch 24/50 189/189 [==============================] - 11s 57ms/step - loss: 4.1192 Epoch 25/50 189/189 [==============================] - 11s 56ms/step - loss: 4.1002 Epoch 26/50 189/189 [==============================] - 11s 57ms/step - loss: 4.0797 Epoch 27/50 189/189 [==============================] - 11s 56ms/step - loss: 4.0547 Epoch 28/50 189/189 [==============================] - 11s 56ms/step - loss: 4.0336 Epoch 29/50 189/189 [==============================] - 11s 56ms/step - loss: 4.0299 Epoch 30/50 189/189 [==============================] - 11s 56ms/step - loss: 4.0031 Epoch 31/50 189/189 [==============================] - 11s 56ms/step - loss: 3.9979 Epoch 32/50 189/189 [==============================] - 11s 56ms/step - loss: 3.9777 Epoch 33/50 189/189 [==============================] - 10s 55ms/step - loss: 3.9800 Epoch 34/50 189/189 [==============================] - 11s 56ms/step - loss: 3.9538 Epoch 35/50 189/189 
[==============================] - 11s 56ms/step - loss: 3.9298 Epoch 36/50 189/189 [==============================] - 11s 57ms/step - loss: 3.9241 Epoch 37/50 189/189 [==============================] - 11s 56ms/step - loss: 3.9102 Epoch 38/50 189/189 [==============================] - 11s 56ms/step - loss: 3.9075 Epoch 39/50 189/189 [==============================] - 11s 56ms/step - loss: 3.8897 Epoch 40/50 189/189 [==============================] - 11s 57ms/step - loss: 3.8871 Epoch 41/50 189/189 [==============================] - 11s 56ms/step - loss: 3.8596 Epoch 42/50 189/189 [==============================] - 10s 56ms/step - loss: 3.8526 Epoch 43/50 189/189 [==============================] - 11s 56ms/step - loss: 3.8417 Epoch 44/50 189/189 [==============================] - 10s 55ms/step - loss: 3.8239 Epoch 45/50 189/189 [==============================] - 11s 56ms/step - loss: 3.8178 Epoch 46/50 189/189 [==============================] - 11s 56ms/step - loss: 3.8065 Epoch 47/50 189/189 [==============================] - 11s 56ms/step - loss: 3.8185 Epoch 48/50 189/189 [==============================] - 11s 56ms/step - loss: 3.8022 Epoch 49/50 189/189 [==============================] - 11s 56ms/step - loss: 3.7815 Epoch 50/50 189/189 [==============================] - 11s 56ms/step - loss: 3.7601
Model: "cifar-encoder_with_projection-head" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_8 (InputLayer) [(None, 32, 32, 3)] 0 cifar10-encoder (Functional (None, 2048) 23564807 ) dense_2 (Dense) (None, 128) 262272 ================================================================= Total params: 23,827,079 Trainable params: 23,781,632 Non-trainable params: 45,447 _________________________________________________________________ Epoch 1/50 189/189 [==============================] - 16s 57ms/step - loss: 5.3859 Epoch 2/50 189/189 [==============================] - 11s 58ms/step - loss: 5.1560 Epoch 3/50 189/189 [==============================] - 11s 57ms/step - loss: 5.0352 Epoch 4/50 189/189 [==============================] - 11s 57ms/step - loss: 4.9236 Epoch 5/50 189/189 [==============================] - 11s 58ms/step - loss: 4.8284 Epoch 6/50 189/189 [==============================] - 11s 57ms/step - loss: 4.7470 Epoch 7/50 189/189 [==============================] - 11s 58ms/step - loss: 4.6875 Epoch 8/50 189/189 [==============================] - 11s 57ms/step - loss: 4.6289 Epoch 9/50 189/189 [==============================] - 11s 57ms/step - loss: 4.5714 Epoch 10/50 189/189 [==============================] - 11s 58ms/step - loss: 4.5163 Epoch 11/50 189/189 [==============================] - 11s 58ms/step - loss: 4.4755 Epoch 12/50 189/189 [==============================] - 11s 58ms/step - loss: 4.4294 Epoch 13/50 189/189 [==============================] - 11s 58ms/step - loss: 4.3943 Epoch 14/50 189/189 [==============================] - 11s 58ms/step - loss: 4.3669 Epoch 15/50 189/189 [==============================] - 11s 58ms/step - loss: 4.3246 Epoch 16/50 189/189 [==============================] - 11s 57ms/step - loss: 4.2921 Epoch 17/50 189/189 [==============================] - 11s 57ms/step - loss: 4.2576 Epoch 18/50 189/189 [==============================] - 11s 58ms/step - loss: 4.2367 Epoch 19/50 189/189 [==============================] - 11s 57ms/step - loss: 4.2170 Epoch 20/50 189/189 [==============================] - 11s 57ms/step - loss: 4.1998 Epoch 21/50 189/189 [==============================] - 11s 57ms/step - loss: 4.1656 Epoch 22/50 189/189 [==============================] - 11s 57ms/step - loss: 4.1398 Epoch 23/50 189/189 [==============================] - 11s 57ms/step - loss: 4.1203 Epoch 24/50 189/189 [==============================] - 11s 57ms/step - loss: 4.1022 Epoch 25/50 189/189 [==============================] - 11s 57ms/step - loss: 4.0824 Epoch 26/50 189/189 [==============================] - 11s 57ms/step - loss: 4.0664 Epoch 27/50 189/189 [==============================] - 11s 57ms/step - loss: 4.0583 Epoch 28/50 189/189 [==============================] - 11s 57ms/step - loss: 4.0313 Epoch 29/50 189/189 [==============================] - 11s 58ms/step - loss: 4.0114 Epoch 30/50 189/189 [==============================] - 11s 58ms/step - loss: 3.9954 Epoch 31/50 189/189 [==============================] - 11s 57ms/step - loss: 3.9716 Epoch 32/50 189/189 [==============================] - 11s 57ms/step - loss: 3.9646 Epoch 33/50 189/189 [==============================] - 11s 57ms/step - loss: 3.9440 Epoch 34/50 189/189 [==============================] - 11s 57ms/step - loss: 3.9356 Epoch 35/50 189/189 [==============================] - 11s 57ms/step - loss: 3.9167 Epoch 36/50 189/189 [==============================] - 11s 57ms/step 
- loss: 3.8941 Epoch 37/50 189/189 [==============================] - 11s 57ms/step - loss: 3.8983 Epoch 38/50 189/189 [==============================] - 11s 57ms/step - loss: 3.8905 Epoch 39/50 189/189 [==============================] - 11s 57ms/step - loss: 3.8671 Epoch 40/50 189/189 [==============================] - 11s 57ms/step - loss: 3.8540 Epoch 41/50 189/189 [==============================] - 11s 58ms/step - loss: 3.8406 Epoch 42/50 189/189 [==============================] - 11s 58ms/step - loss: 3.8311 Epoch 43/50 189/189 [==============================] - 11s 57ms/step - loss: 3.8190 Epoch 44/50 189/189 [==============================] - 11s 57ms/step - loss: 3.8140 Epoch 45/50 189/189 [==============================] - 11s 57ms/step - loss: 3.8091 Epoch 46/50 189/189 [==============================] - 11s 57ms/step - loss: 3.7872 Epoch 47/50 189/189 [==============================] - 11s 57ms/step - loss: 3.7743 Epoch 48/50 189/189 [==============================] - 11s 56ms/step - loss: 3.7913 Epoch 49/50 189/189 [==============================] - 11s 57ms/step - loss: 3.7578 Epoch 50/50 189/189 [==============================] - 11s 57ms/step - loss: 3.7589
3. Train the classifier with the frozen encoder
classifier = create_classifier(encoder, trainable=False)
history = classifier.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs)
accuracy = classifier.evaluate(x_test, y_test)[1]
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
Epoch 1/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3979 - sparse_categorical_accuracy: 0.8869
Epoch 2/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3422 - sparse_categorical_accuracy: 0.8959
...
Epoch 49/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3195 - sparse_categorical_accuracy: 0.8981
Epoch 50/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3240 - sparse_categorical_accuracy: 0.8962
313/313 [==============================] - 2s 7ms/step - loss: 0.7332 - sparse_categorical_accuracy: 0.8162
Test accuracy: 81.62%
(Translator's note: results from our own run)
Epoch 1/50
189/189 [==============================] - 7s 20ms/step - loss: 0.3696 - sparse_categorical_accuracy: 0.8963
Epoch 2/50
189/189 [==============================] - 4s 20ms/step - loss: 0.3153 - sparse_categorical_accuracy: 0.9045
...
Epoch 49/50
189/189 [==============================] - 4s 20ms/step - loss: 0.2832 - sparse_categorical_accuracy: 0.9085
Epoch 50/50
189/189 [==============================] - 4s 20ms/step - loss: 0.2873 - sparse_categorical_accuracy: 0.9087
313/313 [==============================] - 5s 11ms/step - loss: 0.6967 - sparse_categorical_accuracy: 0.8076
Test accuracy: 80.76%
We get an improved test accuracy.
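To use the trained classifier on new images, inference is a standard predict call (a usage sketch we added; the printed values are illustrative):
# Illustrative inference sketch: predict class probabilities for a few test
# images and compare the most likely class index against the labels.
probabilities = classifier.predict(x_test[:5])
predicted_classes = np.argmax(probabilities, axis=1)
print(predicted_classes)     # e.g. [3 8 8 0 6]
print(y_test[:5].flatten())  # ground-truth CIFAR-10 class indices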
Conclusion
As shown in the experiments, using the supervised contrastive learning technique outperformed the conventional technique in terms of test accuracy. Note that the same training budget (i.e., number of epochs) was given to each technique. Supervised contrastive learning pays off when the encoder involves a complex architecture, like ResNet, and on multi-class problems with many labels. In addition, large batch sizes and multi-layer projection heads improve its effectiveness. See the Supervised Contrastive Learning paper for more details.
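As a sketch of that last point (our own assumption, not code from the original example), a multi-layer projection head could be defined along the following lines and used in place of add_projection_head above:
# Hypothetical two-layer (MLP) projection head; the layer sizes reuse the
# hidden_units / projection_units hyperparameters defined earlier.
def add_mlp_projection_head(encoder):
    inputs = keras.Input(shape=input_shape)
    features = encoder(inputs)
    features = layers.Dense(hidden_units, activation="relu")(features)
    outputs = layers.Dense(projection_units, activation="relu")(features)
    return keras.Model(
        inputs=inputs, outputs=outputs, name="cifar-encoder_with_mlp-projection-head"
    )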