Keras 2 : examples : Image classification with EANet (External Attention Transformer)
Description: Image classification with a Transformer that leverages external attention.
Introduction
This example implements the EANet model for image classification and demonstrates it on the CIFAR-100 dataset. EANet introduces a novel attention mechanism called external attention, based on two external, small, learnable and shared memories, which can be implemented easily by simply using two cascaded linear layers and two normalization layers. It conveniently replaces the self-attention used in existing architectures. External attention has linear complexity, as it only implicitly considers the correlations between all samples.
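In matrix form (a sketch following the notation of the EANet paper, where F is the N x d matrix of patch features and M_K, M_V are the two small S x d external memories; the double normalization below matches the one implemented later in this example):

\tilde{A}_{ij} = \frac{\exp\big((F M_K^\top)_{ij}\big)}{\sum_{i'} \exp\big((F M_K^\top)_{i'j}\big)},
\qquad
A_{ij} = \frac{\tilde{A}_{ij}}{\sum_{j'} \tilde{A}_{ij'}},
\qquad
F_{\text{out}} = A\, M_V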
This example requires TensorFlow 2.5 or higher, as well as the TensorFlow Addons package, which can be installed using the following command:
pip install -U tensorflow-addons
Setup
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
Prepare the data
num_classes = 100
input_shape = (32, 32, 3)
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 100)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 100)
Configure the hyperparameters
weight_decay = 0.0001
learning_rate = 0.001
label_smoothing = 0.1
validation_split = 0.2
batch_size = 128
num_epochs = 50
patch_size = 2 # Size of the patches to be extracted from the input images.
num_patches = (input_shape[0] // patch_size) ** 2  # Number of patches.
embedding_dim = 64 # Number of hidden units.
mlp_dim = 64
dim_coefficient = 4
num_heads = 4
attention_dropout = 0.2
projection_dropout = 0.2
num_transformer_blocks = 8 # Number of repetitions of the transformer layer
print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
print(f"Patches per image: {num_patches}")
Patch size: 2 X 2 = 4
Patches per image: 256
Use data augmentation
data_augmentation = keras.Sequential(
[
layers.Normalization(),
layers.RandomFlip("horizontal"),
layers.RandomRotation(factor=0.1),
layers.RandomContrast(factor=0.1),
layers.RandomZoom(height_factor=0.2, width_factor=0.2),
],
name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)
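As a quick sanity check (a sketch; `sample_batch` is a hypothetical name): when the pipeline is called outside of training, the random layers act as identity while Normalization applies the adapted statistics, and the image shape is preserved.

sample_batch = data_augmentation(x_train[:4])
print(sample_batch.shape)  # (4, 32, 32, 3)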
Implement the patch extraction and encoding layer
class PatchExtract(layers.Layer):
def __init__(self, patch_size, **kwargs):
super(PatchExtract, self).__init__(**kwargs)
self.patch_size = patch_size
def call(self, images):
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=(1, self.patch_size, self.patch_size, 1),
strides=(1, self.patch_size, self.patch_size, 1),
rates=(1, 1, 1, 1),
padding="VALID",
)
patch_dim = patches.shape[-1]
patch_num = patches.shape[1]
return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))
class PatchEmbedding(layers.Layer):
def __init__(self, num_patch, embed_dim, **kwargs):
super(PatchEmbedding, self).__init__(**kwargs)
self.num_patch = num_patch
self.proj = layers.Dense(embed_dim)
self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)
def call(self, patch):
pos = tf.range(start=0, limit=self.num_patch, delta=1)
return self.proj(patch) + self.pos_embed(pos)
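As a quick shape check (a sketch; `images` is a hypothetical toy batch): extracting 2 x 2 patches from a 32 x 32 x 3 image yields 256 patches of 2 * 2 * 3 = 12 raw values each, which PatchEmbedding then projects to embedding_dim and augments with positional embeddings.

images = tf.random.normal((2, 32, 32, 3))
patches = PatchExtract(patch_size)(images)
print(patches.shape)  # (2, 256, 12)
embeddings = PatchEmbedding(num_patches, embedding_dim)(patches)
print(embeddings.shape)  # (2, 256, 64)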
Implement the external attention block
def external_attention(
x, dim, num_heads, dim_coefficient=4, attention_dropout=0, projection_dropout=0
):
_, num_patch, channel = x.shape
assert dim % num_heads == 0
num_heads = num_heads * dim_coefficient
x = layers.Dense(dim * dim_coefficient)(x)
# create tensor [batch_size, num_patches, num_heads, dim*dim_coefficient//num_heads]
x = tf.reshape(
x, shape=(-1, num_patch, num_heads, dim * dim_coefficient // num_heads)
)
x = tf.transpose(x, perm=[0, 2, 1, 3])
# a linear layer M_k
attn = layers.Dense(dim // dim_coefficient)(x)
# normalize attention map
attn = layers.Softmax(axis=2)(attn)
    # double normalization
attn = attn / (1e-9 + tf.reduce_sum(attn, axis=-1, keepdims=True))
attn = layers.Dropout(attention_dropout)(attn)
# a linear layer M_v
x = layers.Dense(dim * dim_coefficient // num_heads)(attn)
x = tf.transpose(x, perm=[0, 2, 1, 3])
x = tf.reshape(x, [-1, num_patch, dim * dim_coefficient])
# a linear layer to project original dim
x = layers.Dense(dim)(x)
x = layers.Dropout(projection_dropout)(x)
return x
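Because the block only mixes information through the two small memories, its output keeps the input shape. A quick check (a sketch; `tokens` is a hypothetical toy tensor):

tokens = tf.random.normal((2, num_patches, embedding_dim))
out = external_attention(
    tokens,
    embedding_dim,
    num_heads,
    dim_coefficient,
    attention_dropout,
    projection_dropout,
)
print(out.shape)  # (2, 256, 64)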
Implement the MLP block
def mlp(x, embedding_dim, mlp_dim, drop_rate=0.2):
x = layers.Dense(mlp_dim, activation=tf.nn.gelu)(x)
x = layers.Dropout(drop_rate)(x)
x = layers.Dense(embedding_dim)(x)
x = layers.Dropout(drop_rate)(x)
return x
Implement the Transformer block
def transformer_encoder(
x,
embedding_dim,
mlp_dim,
num_heads,
dim_coefficient,
attention_dropout,
projection_dropout,
attention_type="external_attention",
):
residual_1 = x
x = layers.LayerNormalization(epsilon=1e-5)(x)
if attention_type == "external_attention":
x = external_attention(
x,
embedding_dim,
num_heads,
dim_coefficient,
attention_dropout,
projection_dropout,
)
elif attention_type == "self_attention":
x = layers.MultiHeadAttention(
num_heads=num_heads, key_dim=embedding_dim, dropout=attention_dropout
)(x, x)
x = layers.add([x, residual_1])
residual_2 = x
x = layers.LayerNormalization(epsilon=1e-5)(x)
x = mlp(x, embedding_dim, mlp_dim)
x = layers.add([x, residual_2])
return x
Implement the EANet model
The EANet model leverages external attention. The computational complexity of traditional self-attention is O(d * N ** 2), where d is the embedding size and N is the number of patches. The authors found that most pixels are closely related to only a few other pixels, so an N-to-N attention matrix may be redundant. As an alternative, they proposed the external attention module, whose computational complexity is O(d * S * N). Because d and S are hyperparameters, the proposed algorithm is linear in the number of pixels. In effect, this is equivalent to a drop-patch operation, because much of the information contained in the patches of an image is redundant and unimportant.
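With the hyperparameters above the saving is easy to quantify (a rough sketch; here S corresponds to dim // dim_coefficient = 16, the width of the M_k layer in this implementation):

d = embedding_dim                     # 64
N = num_patches                       # 256
S = embedding_dim // dim_coefficient  # 16
print("self-attention     ~ d * N**2  =", d * N ** 2)  # 4,194,304
print("external attention ~ d * S * N =", d * S * N)   # 262,144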
def get_model(attention_type="external_attention"):
inputs = layers.Input(shape=input_shape)
# Image augment
x = data_augmentation(inputs)
# Extract patches.
x = PatchExtract(patch_size)(x)
# Create patch embedding.
x = PatchEmbedding(num_patches, embedding_dim)(x)
# Create Transformer block.
for _ in range(num_transformer_blocks):
x = transformer_encoder(
x,
embedding_dim,
mlp_dim,
num_heads,
dim_coefficient,
attention_dropout,
projection_dropout,
attention_type,
)
x = layers.GlobalAvgPool1D()(x)
outputs = layers.Dense(num_classes, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
return model
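The parameter count quoted in the conclusion below can be checked before training (a minimal sketch):

get_model(attention_type="external_attention").summary()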
Train on CIFAR-100
model = get_model(attention_type="external_attention")
model.compile(
loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),
optimizer=tfa.optimizers.AdamW(
learning_rate=learning_rate, weight_decay=weight_decay
),
metrics=[
keras.metrics.CategoricalAccuracy(name="accuracy"),
keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
],
)
history = model.fit(
x_train,
y_train,
batch_size=batch_size,
epochs=num_epochs,
validation_split=validation_split,
)
Epoch 1/50
313/313 [==============================] - 40s 95ms/step - loss: 4.2091 - accuracy: 0.0723 - top-5-accuracy: 0.2384 - val_loss: 3.9706 - val_accuracy: 0.1153 - val_top-5-accuracy: 0.3336
Epoch 2/50
313/313 [==============================] - 29s 91ms/step - loss: 3.8028 - accuracy: 0.1427 - top-5-accuracy: 0.3871 - val_loss: 3.6672 - val_accuracy: 0.1829 - val_top-5-accuracy: 0.4513
Epoch 3/50
313/313 [==============================] - 29s 93ms/step - loss: 3.5493 - accuracy: 0.1978 - top-5-accuracy: 0.4805 - val_loss: 3.5402 - val_accuracy: 0.2141 - val_top-5-accuracy: 0.5038
Epoch 4/50
313/313 [==============================] - 29s 93ms/step - loss: 3.4029 - accuracy: 0.2355 - top-5-accuracy: 0.5328 - val_loss: 3.4496 - val_accuracy: 0.2354 - val_top-5-accuracy: 0.5316
Epoch 5/50
313/313 [==============================] - 29s 92ms/step - loss: 3.2917 - accuracy: 0.2636 - top-5-accuracy: 0.5678 - val_loss: 3.3342 - val_accuracy: 0.2699 - val_top-5-accuracy: 0.5679
Epoch 6/50
313/313 [==============================] - 29s 92ms/step - loss: 3.2116 - accuracy: 0.2830 - top-5-accuracy: 0.5921 - val_loss: 3.2896 - val_accuracy: 0.2749 - val_top-5-accuracy: 0.5874
Epoch 7/50
313/313 [==============================] - 28s 90ms/step - loss: 3.1453 - accuracy: 0.2980 - top-5-accuracy: 0.6100 - val_loss: 3.3090 - val_accuracy: 0.2857 - val_top-5-accuracy: 0.5831
Epoch 8/50
313/313 [==============================] - 29s 94ms/step - loss: 3.0889 - accuracy: 0.3121 - top-5-accuracy: 0.6266 - val_loss: 3.1969 - val_accuracy: 0.2975 - val_top-5-accuracy: 0.6082
Epoch 9/50
313/313 [==============================] - 29s 92ms/step - loss: 3.0390 - accuracy: 0.3252 - top-5-accuracy: 0.6441 - val_loss: 3.1249 - val_accuracy: 0.3175 - val_top-5-accuracy: 0.6330
Epoch 10/50
313/313 [==============================] - 29s 92ms/step - loss: 2.9871 - accuracy: 0.3365 - top-5-accuracy: 0.6615 - val_loss: 3.1121 - val_accuracy: 0.3200 - val_top-5-accuracy: 0.6374
Epoch 11/50
313/313 [==============================] - 29s 92ms/step - loss: 2.9476 - accuracy: 0.3489 - top-5-accuracy: 0.6697 - val_loss: 3.1156 - val_accuracy: 0.3268 - val_top-5-accuracy: 0.6421
Epoch 12/50
313/313 [==============================] - 29s 91ms/step - loss: 2.9106 - accuracy: 0.3576 - top-5-accuracy: 0.6783 - val_loss: 3.1337 - val_accuracy: 0.3226 - val_top-5-accuracy: 0.6389
Epoch 13/50
313/313 [==============================] - 29s 92ms/step - loss: 2.8772 - accuracy: 0.3662 - top-5-accuracy: 0.6871 - val_loss: 3.0373 - val_accuracy: 0.3348 - val_top-5-accuracy: 0.6624
Epoch 14/50
313/313 [==============================] - 29s 92ms/step - loss: 2.8508 - accuracy: 0.3756 - top-5-accuracy: 0.6944 - val_loss: 3.0297 - val_accuracy: 0.3441 - val_top-5-accuracy: 0.6643
Epoch 15/50
313/313 [==============================] - 28s 90ms/step - loss: 2.8211 - accuracy: 0.3821 - top-5-accuracy: 0.7034 - val_loss: 2.9680 - val_accuracy: 0.3604 - val_top-5-accuracy: 0.6847
Epoch 16/50
313/313 [==============================] - 28s 90ms/step - loss: 2.8017 - accuracy: 0.3864 - top-5-accuracy: 0.7090 - val_loss: 2.9746 - val_accuracy: 0.3584 - val_top-5-accuracy: 0.6855
Epoch 17/50
313/313 [==============================] - 29s 91ms/step - loss: 2.7714 - accuracy: 0.3962 - top-5-accuracy: 0.7169 - val_loss: 2.9104 - val_accuracy: 0.3738 - val_top-5-accuracy: 0.6940
Epoch 18/50
313/313 [==============================] - 29s 92ms/step - loss: 2.7523 - accuracy: 0.4008 - top-5-accuracy: 0.7204 - val_loss: 2.8560 - val_accuracy: 0.3861 - val_top-5-accuracy: 0.7115
Epoch 19/50
313/313 [==============================] - 28s 91ms/step - loss: 2.7320 - accuracy: 0.4051 - top-5-accuracy: 0.7263 - val_loss: 2.8780 - val_accuracy: 0.3820 - val_top-5-accuracy: 0.7101
Epoch 20/50
313/313 [==============================] - 28s 90ms/step - loss: 2.7139 - accuracy: 0.4114 - top-5-accuracy: 0.7290 - val_loss: 2.9831 - val_accuracy: 0.3694 - val_top-5-accuracy: 0.6922
Epoch 21/50
313/313 [==============================] - 28s 91ms/step - loss: 2.6991 - accuracy: 0.4142 - top-5-accuracy: 0.7335 - val_loss: 2.8420 - val_accuracy: 0.3968 - val_top-5-accuracy: 0.7138
Epoch 22/50
313/313 [==============================] - 29s 91ms/step - loss: 2.6842 - accuracy: 0.4195 - top-5-accuracy: 0.7377 - val_loss: 2.7965 - val_accuracy: 0.4088 - val_top-5-accuracy: 0.7266
Epoch 23/50
313/313 [==============================] - 28s 91ms/step - loss: 2.6571 - accuracy: 0.4273 - top-5-accuracy: 0.7436 - val_loss: 2.8620 - val_accuracy: 0.3947 - val_top-5-accuracy: 0.7155
Epoch 24/50
313/313 [==============================] - 29s 91ms/step - loss: 2.6508 - accuracy: 0.4277 - top-5-accuracy: 0.7469 - val_loss: 2.8459 - val_accuracy: 0.3963 - val_top-5-accuracy: 0.7150
Epoch 25/50
313/313 [==============================] - 28s 90ms/step - loss: 2.6403 - accuracy: 0.4283 - top-5-accuracy: 0.7520 - val_loss: 2.7886 - val_accuracy: 0.4128 - val_top-5-accuracy: 0.7283
Epoch 26/50
313/313 [==============================] - 29s 92ms/step - loss: 2.6281 - accuracy: 0.4353 - top-5-accuracy: 0.7523 - val_loss: 2.8493 - val_accuracy: 0.4026 - val_top-5-accuracy: 0.7153
Epoch 27/50
313/313 [==============================] - 29s 92ms/step - loss: 2.6092 - accuracy: 0.4403 - top-5-accuracy: 0.7580 - val_loss: 2.7539 - val_accuracy: 0.4186 - val_top-5-accuracy: 0.7392
Epoch 28/50
313/313 [==============================] - 29s 91ms/step - loss: 2.5992 - accuracy: 0.4423 - top-5-accuracy: 0.7600 - val_loss: 2.8625 - val_accuracy: 0.3964 - val_top-5-accuracy: 0.7174
Epoch 29/50
313/313 [==============================] - 28s 90ms/step - loss: 2.5913 - accuracy: 0.4456 - top-5-accuracy: 0.7598 - val_loss: 2.7911 - val_accuracy: 0.4162 - val_top-5-accuracy: 0.7329
Epoch 30/50
313/313 [==============================] - 29s 92ms/step - loss: 2.5780 - accuracy: 0.4480 - top-5-accuracy: 0.7649 - val_loss: 2.8158 - val_accuracy: 0.4118 - val_top-5-accuracy: 0.7288
Epoch 31/50
313/313 [==============================] - 28s 91ms/step - loss: 2.5657 - accuracy: 0.4547 - top-5-accuracy: 0.7661 - val_loss: 2.8651 - val_accuracy: 0.4056 - val_top-5-accuracy: 0.7217
Epoch 32/50
313/313 [==============================] - 29s 91ms/step - loss: 2.5637 - accuracy: 0.4480 - top-5-accuracy: 0.7681 - val_loss: 2.8190 - val_accuracy: 0.4094 - val_top-5-accuracy: 0.7267
Epoch 33/50
313/313 [==============================] - 29s 92ms/step - loss: 2.5525 - accuracy: 0.4545 - top-5-accuracy: 0.7693 - val_loss: 2.7985 - val_accuracy: 0.4216 - val_top-5-accuracy: 0.7303
Epoch 34/50
313/313 [==============================] - 28s 91ms/step - loss: 2.5462 - accuracy: 0.4579 - top-5-accuracy: 0.7721 - val_loss: 2.8865 - val_accuracy: 0.4016 - val_top-5-accuracy: 0.7204
Epoch 35/50
313/313 [==============================] - 29s 92ms/step - loss: 2.5329 - accuracy: 0.4616 - top-5-accuracy: 0.7740 - val_loss: 2.7862 - val_accuracy: 0.4232 - val_top-5-accuracy: 0.7389
Epoch 36/50
313/313 [==============================] - 28s 90ms/step - loss: 2.5234 - accuracy: 0.4610 - top-5-accuracy: 0.7765 - val_loss: 2.8234 - val_accuracy: 0.4134 - val_top-5-accuracy: 0.7312
Epoch 37/50
313/313 [==============================] - 29s 91ms/step - loss: 2.5152 - accuracy: 0.4663 - top-5-accuracy: 0.7774 - val_loss: 2.7894 - val_accuracy: 0.4161 - val_top-5-accuracy: 0.7376
Epoch 38/50
313/313 [==============================] - 29s 92ms/step - loss: 2.5117 - accuracy: 0.4674 - top-5-accuracy: 0.7790 - val_loss: 2.8091 - val_accuracy: 0.4142 - val_top-5-accuracy: 0.7360
Epoch 39/50
313/313 [==============================] - 28s 90ms/step - loss: 2.5047 - accuracy: 0.4681 - top-5-accuracy: 0.7805 - val_loss: 2.8199 - val_accuracy: 0.4167 - val_top-5-accuracy: 0.7299
Epoch 40/50
313/313 [==============================] - 28s 90ms/step - loss: 2.4974 - accuracy: 0.4697 - top-5-accuracy: 0.7819 - val_loss: 2.7864 - val_accuracy: 0.4247 - val_top-5-accuracy: 0.7402
Epoch 41/50
313/313 [==============================] - 28s 90ms/step - loss: 2.4889 - accuracy: 0.4749 - top-5-accuracy: 0.7854 - val_loss: 2.8120 - val_accuracy: 0.4217 - val_top-5-accuracy: 0.7358
Epoch 42/50
313/313 [==============================] - 28s 90ms/step - loss: 2.4799 - accuracy: 0.4771 - top-5-accuracy: 0.7866 - val_loss: 2.9003 - val_accuracy: 0.4038 - val_top-5-accuracy: 0.7170
Epoch 43/50
313/313 [==============================] - 28s 90ms/step - loss: 2.4814 - accuracy: 0.4770 - top-5-accuracy: 0.7868 - val_loss: 2.7504 - val_accuracy: 0.4260 - val_top-5-accuracy: 0.7457
Epoch 44/50
313/313 [==============================] - 28s 91ms/step - loss: 2.4747 - accuracy: 0.4757 - top-5-accuracy: 0.7870 - val_loss: 2.8207 - val_accuracy: 0.4166 - val_top-5-accuracy: 0.7363
Epoch 45/50
313/313 [==============================] - 28s 90ms/step - loss: 2.4653 - accuracy: 0.4809 - top-5-accuracy: 0.7924 - val_loss: 2.8663 - val_accuracy: 0.4130 - val_top-5-accuracy: 0.7209
Epoch 46/50
313/313 [==============================] - 28s 90ms/step - loss: 2.4554 - accuracy: 0.4825 - top-5-accuracy: 0.7929 - val_loss: 2.8145 - val_accuracy: 0.4250 - val_top-5-accuracy: 0.7357
Epoch 47/50
313/313 [==============================] - 29s 91ms/step - loss: 2.4602 - accuracy: 0.4823 - top-5-accuracy: 0.7919 - val_loss: 2.8352 - val_accuracy: 0.4189 - val_top-5-accuracy: 0.7365
Epoch 48/50
313/313 [==============================] - 28s 91ms/step - loss: 2.4493 - accuracy: 0.4848 - top-5-accuracy: 0.7933 - val_loss: 2.8246 - val_accuracy: 0.4160 - val_top-5-accuracy: 0.7362
Epoch 49/50
313/313 [==============================] - 28s 91ms/step - loss: 2.4454 - accuracy: 0.4846 - top-5-accuracy: 0.7958 - val_loss: 2.7731 - val_accuracy: 0.4320 - val_top-5-accuracy: 0.7436
Epoch 50/50
313/313 [==============================] - 29s 92ms/step - loss: 2.4418 - accuracy: 0.4848 - top-5-accuracy: 0.7951 - val_loss: 2.7926 - val_accuracy: 0.4317 - val_top-5-accuracy: 0.7410
Let's visualize the training progress of the model.
plt.plot(history.history["loss"], label="train_loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Train and Validation Losses Over Epochs", fontsize=14)
plt.legend()
plt.grid()
plt.show()
Let's display the final results of the test on CIFAR-100.
loss, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
print(f"Test loss: {round(loss, 2)}")
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
313/313 [==============================] - 6s 21ms/step - loss: 2.7574 - accuracy: 0.4391 - top-5-accuracy: 0.7471
Test loss: 2.76
Test accuracy: 43.91%
Test top 5 accuracy: 74.71%
EANet simply replaces the self-attention in ViT with external attention. The traditional ViT model reaches ~73% test top-5 accuracy and ~41% top-1 accuracy after training for 50 epochs, but uses 0.6M parameters. Under the same experimental environment and the same hyperparameters, the EANet model we just trained has only 0.3M parameters, and it gets us to ~73% test top-5 accuracy and ~43% top-1 accuracy. This fully demonstrates the effectiveness of external attention. We only show the training process of EANet; you can train a ViT under the same experimental conditions and observe the test results (as sketched below).
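To run that comparison (a sketch; `vit_model` and `vit_history` are hypothetical names, reusing the compile and fit settings from above):

vit_model = get_model(attention_type="self_attention")
vit_model.compile(
    loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),
    optimizer=tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    ),
    metrics=[
        keras.metrics.CategoricalAccuracy(name="accuracy"),
        keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
    ],
)
vit_history = vit_model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=num_epochs,
    validation_split=validation_split,
)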