Keras 2 : examples : コンパクトな畳込み Transformer (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 11/12/2021 (keras 2.6.0)
* 本ページは、Keras の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
- Code examples : Computer Vision : Compact Convolutional Transformers (Author: Sayak Paul)
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- テレワーク & オンライン授業を支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
Keras 2 : examples : コンパクトな畳込み Transformer
Description : 効率的な画像分類のためのコンパクトな畳込み Transformer
ビジョン Transformers (VIT) 論文で議論されているように、ビジョンのための Transformer ベースのアーキテクチャは典型的には通常より大規模なデータセット、そして長い事前訓練スケジュールを必要とします。ImageNet-1k (およそ 100 万枚の画像を持ちます) は ViT に関しては中規模サイズのデータに該当すると考えられます。これは主として、CNN とは違い、ViT (典型的な Transformer ベースのアーキテクチャ) は (画像を処理するための畳込みのような) 十分な情報を持つ inductive (誘導的, 帰納的) なバイアスを持たないためです。This begs the question: 畳込みの利点と Transformer の利点を単一ネットワーク・アーキテクチャで組合せられないか?これらの利点はパラメータ効率性、そして long-range とグローバル依存性 (画像の異なる領域間の相互作用) を扱う self-attention を含みます。
Escaping the Big Data Paradigm with Compact Transformers, Hassani et al. では、これを正確に行なうためのアプローチを提示しています。Compact 畳込み Transformer (CCT) アーキテクチャを提案しました。このサンプルでは、CCT の実装で作業して CIFAR-10 データセット上でどのくらい上手く実行するかを見ます。
self-attention や Transformer の概念に馴染みがない場合は、François Chollet の書籍 Deep Learning with Python からの この章 を読むことができます。このサンプルは別のサンプル Image classification with Vision Transformer からのコードスニペットを使用しています。
このサンプルは TensorFlow 2.5 またはそれ以上、そして TensorFlow Addons を必要とします、これは次のコマンドを使用してインストールできます :
!pip install -U -q tensorflow-addons
|████████████████████████████████| 1.1 MB 5.1 MB/s
インポート
from tensorflow.keras import layers
from tensorflow import keras
import matplotlib.pyplot as plt
import tensorflow_addons as tfa
import tensorflow as tf
import numpy as np
ハイパーパラメータと定数
positional_emb = True
conv_layers = 2
projection_dim = 128
num_heads = 2
transformer_units = [
projection_dim,
projection_dim,
]
transformer_layers = 2
stochastic_depth_rate = 0.1
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 128
num_epochs = 30
image_size = 32
CIFAR-10 データセットをロードする
num_classes = 10
input_shape = (32, 32, 3)
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 170500096/170498071 [==============================] - 11s 0us/step x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 10) x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 10)
CCT トークナイザー
CCT の作者により導入された最初のレシピは画像を処理するためのトークナイザーです。標準的な ViT では、画像は一様なオーバーラップしないパッチに構造化されます。これは異なるパッチ間に存在する境界レベルの情報を除外します。これはニューラルネットワークが位置関係の情報を効果的に利用するために重要です。下図は画像がパッチに構造化される方法の実例を表しています。
畳込みが位置関係の情報を利用するのがかなり上手いことを私達は既に知っています。そこで、これに基づいて、著者は画像パッチを生成するために総て畳込みのミニネットワークを導入しています。
class CCTTokenizer(layers.Layer):
def __init__(
self,
kernel_size=3,
stride=1,
padding=1,
pooling_kernel_size=3,
pooling_stride=2,
num_conv_layers=conv_layers,
num_output_channels=[64, 128],
positional_emb=positional_emb,
**kwargs,
):
super(CCTTokenizer, self).__init__(**kwargs)
# This is our tokenizer.
self.conv_model = keras.Sequential()
for i in range(num_conv_layers):
self.conv_model.add(
layers.Conv2D(
num_output_channels[i],
kernel_size,
stride,
padding="valid",
use_bias=False,
activation="relu",
kernel_initializer="he_normal",
)
)
self.conv_model.add(layers.ZeroPadding2D(padding))
self.conv_model.add(
layers.MaxPool2D(pooling_kernel_size, pooling_stride, "same")
)
self.positional_emb = positional_emb
def call(self, images):
outputs = self.conv_model(images)
# After passing the images through our mini-network the spatial dimensions
# are flattened to form sequences.
reshaped = tf.reshape(
outputs,
(-1, tf.shape(outputs)[1] * tf.shape(outputs)[2], tf.shape(outputs)[-1]),
)
return reshaped
def positional_embedding(self, image_size):
# Positional embeddings are optional in CCT. Here, we calculate
# the number of sequences and initialize an `Embedding` layer to
# compute the positional embeddings later.
if self.positional_emb:
dummy_inputs = tf.ones((1, image_size, image_size, 3))
dummy_outputs = self.call(dummy_inputs)
sequence_length = tf.shape(dummy_outputs)[1]
projection_dim = tf.shape(dummy_outputs)[-1]
embed_layer = layers.Embedding(
input_dim=sequence_length, output_dim=projection_dim
)
return embed_layer, sequence_length
else:
return None
正則化のための確率的 depth
確率的 depth は層のセットをランダムにドロップする正則化テクニックです。推論の間は、層はそのまま保持されます。それは Dropout に非常に類似していますが、層の内部にある個々のノードではなく層のブロック上で作用するという点だけが異なります。CCT では、確率的 depth は Transformer エンコーダの残差ブロックの直前で使用されます。
# Referred from: github.com:rwightman/pytorch-image-models.
class StochasticDepth(layers.Layer):
def __init__(self, drop_prop, **kwargs):
super(StochasticDepth, self).__init__(**kwargs)
self.drop_prob = drop_prop
def call(self, x, training=None):
if training:
keep_prob = 1 - self.drop_prob
shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
random_tensor = tf.floor(random_tensor)
return (x / keep_prob) * random_tensor
return x
Transformer エンコーダのための MLP
def mlp(x, hidden_units, dropout_rate):
for units in hidden_units:
x = layers.Dense(units, activation=tf.nn.gelu)(x)
x = layers.Dropout(dropout_rate)(x)
return x
データ増強
元の論文 では、著者はより強い正則化を誘導するために AutoAugment を使用しています。このサンプルのためには、ランダムなクロッピングや反転のような標準的な幾何学的増強を使用していきます。
# Note the rescaling layer. These layers have pre-defined inference behavior.
data_augmentation = keras.Sequential(
[
layers.Rescaling(scale=1.0 / 255),
layers.RandomCrop(image_size, image_size),
layers.RandomFlip("horizontal"),
],
name="data_augmentation",
)
最終的な CCT モデル
CCT で導入されたもう一つのレシピは attention プーリング or シークエンス・プーリングです。ViT では、クラストークンに対応する特徴マップだけがプールされてそして続く分類タスク (or 任意の他の下流タスク) のために使用されます。CCT では、Transformer エンコーダからの出力は重み付けられてから最後のタスク固有の層に渡されます (この例では、分類を行ないます)。
def create_cct_model(
image_size=image_size,
input_shape=input_shape,
num_heads=num_heads,
projection_dim=projection_dim,
transformer_units=transformer_units,
):
inputs = layers.Input(input_shape)
# Augment data.
augmented = data_augmentation(inputs)
# Encode patches.
cct_tokenizer = CCTTokenizer()
encoded_patches = cct_tokenizer(augmented)
# Apply positional embedding.
if positional_emb:
pos_embed, seq_length = cct_tokenizer.positional_embedding(image_size)
positions = tf.range(start=0, limit=seq_length, delta=1)
position_embeddings = pos_embed(positions)
encoded_patches += position_embeddings
# Calculate Stochastic Depth probabilities.
dpr = [x for x in np.linspace(0, stochastic_depth_rate, transformer_layers)]
# Create multiple layers of the Transformer block.
for i in range(transformer_layers):
# Layer normalization 1.
x1 = layers.LayerNormalization(epsilon=1e-5)(encoded_patches)
# Create a multi-head attention layer.
attention_output = layers.MultiHeadAttention(
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
)(x1, x1)
# Skip connection 1.
attention_output = StochasticDepth(dpr[i])(attention_output)
x2 = layers.Add()([attention_output, encoded_patches])
# Layer normalization 2.
x3 = layers.LayerNormalization(epsilon=1e-5)(x2)
# MLP.
x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
# Skip connection 2.
x3 = StochasticDepth(dpr[i])(x3)
encoded_patches = layers.Add()([x3, x2])
# Apply sequence pooling.
representation = layers.LayerNormalization(epsilon=1e-5)(encoded_patches)
attention_weights = tf.nn.softmax(layers.Dense(1)(representation), axis=1)
weighted_representation = tf.matmul(
attention_weights, representation, transpose_a=True
)
weighted_representation = tf.squeeze(weighted_representation, -2)
# Classify outputs.
logits = layers.Dense(num_classes)(weighted_representation)
# Create the Keras model.
model = keras.Model(inputs=inputs, outputs=logits)
return model
モデル訓練と評価
def run_experiment(model):
optimizer = tfa.optimizers.AdamW(learning_rate=0.001, weight_decay=0.0001)
model.compile(
optimizer=optimizer,
loss=keras.losses.CategoricalCrossentropy(
from_logits=True, label_smoothing=0.1
),
metrics=[
keras.metrics.CategoricalAccuracy(name="accuracy"),
keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
],
)
checkpoint_filepath = "/tmp/checkpoint"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
checkpoint_filepath,
monitor="val_accuracy",
save_best_only=True,
save_weights_only=True,
)
history = model.fit(
x=x_train,
y=y_train,
batch_size=batch_size,
epochs=num_epochs,
validation_split=0.1,
callbacks=[checkpoint_callback],
)
model.load_weights(checkpoint_filepath)
_, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
return history
cct_model = create_cct_model()
history = run_experiment(cct_model)
Epoch 1/30 352/352 [==============================] - 10s 18ms/step - loss: 1.9181 - accuracy: 0.3277 - top-5-accuracy: 0.8296 - val_loss: 1.7123 - val_accuracy: 0.4250 - val_top-5-accuracy: 0.9028 Epoch 2/30 352/352 [==============================] - 6s 16ms/step - loss: 1.5725 - accuracy: 0.5010 - top-5-accuracy: 0.9295 - val_loss: 1.5026 - val_accuracy: 0.5530 - val_top-5-accuracy: 0.9364 Epoch 3/30 352/352 [==============================] - 6s 16ms/step - loss: 1.4492 - accuracy: 0.5633 - top-5-accuracy: 0.9476 - val_loss: 1.3744 - val_accuracy: 0.6038 - val_top-5-accuracy: 0.9558 Epoch 4/30 352/352 [==============================] - 6s 16ms/step - loss: 1.3658 - accuracy: 0.6055 - top-5-accuracy: 0.9576 - val_loss: 1.3258 - val_accuracy: 0.6148 - val_top-5-accuracy: 0.9648 Epoch 5/30 352/352 [==============================] - 6s 16ms/step - loss: 1.3142 - accuracy: 0.6302 - top-5-accuracy: 0.9640 - val_loss: 1.2723 - val_accuracy: 0.6468 - val_top-5-accuracy: 0.9710 Epoch 6/30 352/352 [==============================] - 6s 16ms/step - loss: 1.2729 - accuracy: 0.6489 - top-5-accuracy: 0.9684 - val_loss: 1.2490 - val_accuracy: 0.6640 - val_top-5-accuracy: 0.9704 Epoch 7/30 352/352 [==============================] - 6s 16ms/step - loss: 1.2371 - accuracy: 0.6664 - top-5-accuracy: 0.9711 - val_loss: 1.1822 - val_accuracy: 0.6906 - val_top-5-accuracy: 0.9744 Epoch 8/30 352/352 [==============================] - 6s 16ms/step - loss: 1.1899 - accuracy: 0.6942 - top-5-accuracy: 0.9735 - val_loss: 1.1799 - val_accuracy: 0.6982 - val_top-5-accuracy: 0.9768 Epoch 9/30 352/352 [==============================] - 6s 16ms/step - loss: 1.1706 - accuracy: 0.6972 - top-5-accuracy: 0.9767 - val_loss: 1.1390 - val_accuracy: 0.7148 - val_top-5-accuracy: 0.9768 Epoch 10/30 352/352 [==============================] - 6s 16ms/step - loss: 1.1524 - accuracy: 0.7054 - top-5-accuracy: 0.9783 - val_loss: 1.1803 - val_accuracy: 0.7000 - val_top-5-accuracy: 0.9740 Epoch 11/30 352/352 [==============================] - 6s 16ms/step - loss: 1.1219 - accuracy: 0.7222 - top-5-accuracy: 0.9798 - val_loss: 1.1066 - val_accuracy: 0.7254 - val_top-5-accuracy: 0.9812 Epoch 12/30 352/352 [==============================] - 6s 16ms/step - loss: 1.1029 - accuracy: 0.7287 - top-5-accuracy: 0.9811 - val_loss: 1.0844 - val_accuracy: 0.7388 - val_top-5-accuracy: 0.9814 Epoch 13/30 352/352 [==============================] - 6s 16ms/step - loss: 1.0841 - accuracy: 0.7380 - top-5-accuracy: 0.9825 - val_loss: 1.1159 - val_accuracy: 0.7280 - val_top-5-accuracy: 0.9792 Epoch 14/30 352/352 [==============================] - 6s 16ms/step - loss: 1.0677 - accuracy: 0.7462 - top-5-accuracy: 0.9832 - val_loss: 1.0862 - val_accuracy: 0.7444 - val_top-5-accuracy: 0.9834 Epoch 15/30 352/352 [==============================] - 6s 16ms/step - loss: 1.0511 - accuracy: 0.7535 - top-5-accuracy: 0.9846 - val_loss: 1.0613 - val_accuracy: 0.7494 - val_top-5-accuracy: 0.9832 Epoch 16/30 352/352 [==============================] - 6s 16ms/step - loss: 1.0377 - accuracy: 0.7608 - top-5-accuracy: 0.9854 - val_loss: 1.0379 - val_accuracy: 0.7606 - val_top-5-accuracy: 0.9834 Epoch 17/30 352/352 [==============================] - 6s 16ms/step - loss: 1.0304 - accuracy: 0.7650 - top-5-accuracy: 0.9849 - val_loss: 1.0602 - val_accuracy: 0.7562 - val_top-5-accuracy: 0.9814 Epoch 18/30 352/352 [==============================] - 6s 16ms/step - loss: 1.0121 - accuracy: 0.7746 - top-5-accuracy: 0.9869 - val_loss: 1.0430 - val_accuracy: 0.7630 - val_top-5-accuracy: 0.9834 Epoch 19/30 352/352 [==============================] - 6s 16ms/step - loss: 1.0037 - accuracy: 0.7760 - top-5-accuracy: 0.9872 - val_loss: 1.0951 - val_accuracy: 0.7460 - val_top-5-accuracy: 0.9826 Epoch 20/30 352/352 [==============================] - 6s 16ms/step - loss: 0.9964 - accuracy: 0.7805 - top-5-accuracy: 0.9871 - val_loss: 1.0683 - val_accuracy: 0.7538 - val_top-5-accuracy: 0.9834 Epoch 21/30 352/352 [==============================] - 6s 16ms/step - loss: 0.9838 - accuracy: 0.7850 - top-5-accuracy: 0.9886 - val_loss: 1.0185 - val_accuracy: 0.7770 - val_top-5-accuracy: 0.9876 Epoch 22/30 352/352 [==============================] - 6s 16ms/step - loss: 0.9742 - accuracy: 0.7904 - top-5-accuracy: 0.9894 - val_loss: 1.0253 - val_accuracy: 0.7738 - val_top-5-accuracy: 0.9838 Epoch 23/30 352/352 [==============================] - 6s 16ms/step - loss: 0.9662 - accuracy: 0.7935 - top-5-accuracy: 0.9889 - val_loss: 1.0107 - val_accuracy: 0.7786 - val_top-5-accuracy: 0.9860 Epoch 24/30 352/352 [==============================] - 6s 16ms/step - loss: 0.9549 - accuracy: 0.7994 - top-5-accuracy: 0.9897 - val_loss: 1.0089 - val_accuracy: 0.7790 - val_top-5-accuracy: 0.9852 Epoch 25/30 352/352 [==============================] - 6s 16ms/step - loss: 0.9522 - accuracy: 0.8018 - top-5-accuracy: 0.9896 - val_loss: 1.0214 - val_accuracy: 0.7780 - val_top-5-accuracy: 0.9866 Epoch 26/30 352/352 [==============================] - 6s 16ms/step - loss: 0.9469 - accuracy: 0.8023 - top-5-accuracy: 0.9897 - val_loss: 0.9993 - val_accuracy: 0.7816 - val_top-5-accuracy: 0.9882 Epoch 27/30 352/352 [==============================] - 6s 16ms/step - loss: 0.9463 - accuracy: 0.8022 - top-5-accuracy: 0.9906 - val_loss: 1.0071 - val_accuracy: 0.7848 - val_top-5-accuracy: 0.9850 Epoch 28/30 352/352 [==============================] - 6s 16ms/step - loss: 0.9336 - accuracy: 0.8077 - top-5-accuracy: 0.9909 - val_loss: 1.0113 - val_accuracy: 0.7868 - val_top-5-accuracy: 0.9856 Epoch 29/30 352/352 [==============================] - 6s 16ms/step - loss: 0.9352 - accuracy: 0.8071 - top-5-accuracy: 0.9909 - val_loss: 1.0073 - val_accuracy: 0.7856 - val_top-5-accuracy: 0.9830 Epoch 30/30 352/352 [==============================] - 6s 16ms/step - loss: 0.9273 - accuracy: 0.8112 - top-5-accuracy: 0.9908 - val_loss: 1.0144 - val_accuracy: 0.7792 - val_top-5-accuracy: 0.9836 313/313 [==============================] - 2s 6ms/step - loss: 1.0396 - accuracy: 0.7676 - top-5-accuracy: 0.9839 Test accuracy: 76.76% Test top 5 accuracy: 98.39%
(訳注: 実験結果)
Epoch 1/30 352/352 [==============================] - 15s 30ms/step - loss: 1.8991 - accuracy: 0.3368 - top-5-accuracy: 0.8364 - val_loss: 1.5731 - val_accuracy: 0.5028 - val_top-5-accuracy: 0.9304 Epoch 2/30 352/352 [==============================] - 10s 28ms/step - loss: 1.5602 - accuracy: 0.5088 - top-5-accuracy: 0.9332 - val_loss: 1.4789 - val_accuracy: 0.5504 - val_top-5-accuracy: 0.9398 Epoch 3/30 352/352 [==============================] - 10s 28ms/step - loss: 1.4368 - accuracy: 0.5692 - top-5-accuracy: 0.9517 - val_loss: 1.4297 - val_accuracy: 0.5770 - val_top-5-accuracy: 0.9496 Epoch 4/30 352/352 [==============================] - 10s 28ms/step - loss: 1.3666 - accuracy: 0.6047 - top-5-accuracy: 0.9589 - val_loss: 1.3147 - val_accuracy: 0.6190 - val_top-5-accuracy: 0.9686 Epoch 5/30 352/352 [==============================] - 10s 28ms/step - loss: 1.3103 - accuracy: 0.6325 - top-5-accuracy: 0.9640 - val_loss: 1.2908 - val_accuracy: 0.6370 - val_top-5-accuracy: 0.9670 Epoch 6/30 352/352 [==============================] - 10s 28ms/step - loss: 1.2633 - accuracy: 0.6525 - top-5-accuracy: 0.9694 - val_loss: 1.2480 - val_accuracy: 0.6624 - val_top-5-accuracy: 0.9670 Epoch 7/30 352/352 [==============================] - 10s 28ms/step - loss: 1.2261 - accuracy: 0.6739 - top-5-accuracy: 0.9727 - val_loss: 1.2029 - val_accuracy: 0.6882 - val_top-5-accuracy: 0.9730 Epoch 8/30 352/352 [==============================] - 10s 28ms/step - loss: 1.1958 - accuracy: 0.6863 - top-5-accuracy: 0.9747 - val_loss: 1.1896 - val_accuracy: 0.6926 - val_top-5-accuracy: 0.9762 Epoch 9/30 352/352 [==============================] - 10s 28ms/step - loss: 1.1636 - accuracy: 0.7007 - top-5-accuracy: 0.9775 - val_loss: 1.1412 - val_accuracy: 0.7154 - val_top-5-accuracy: 0.9790 Epoch 10/30 352/352 [==============================] - 10s 28ms/step - loss: 1.1439 - accuracy: 0.7121 - top-5-accuracy: 0.9785 - val_loss: 1.1552 - val_accuracy: 0.7046 - val_top-5-accuracy: 0.9788 Epoch 11/30 352/352 [==============================] - 10s 28ms/step - loss: 1.1266 - accuracy: 0.7196 - top-5-accuracy: 0.9799 - val_loss: 1.1423 - val_accuracy: 0.7154 - val_top-5-accuracy: 0.9778 Epoch 12/30 352/352 [==============================] - 10s 28ms/step - loss: 1.0951 - accuracy: 0.7343 - top-5-accuracy: 0.9814 - val_loss: 1.1165 - val_accuracy: 0.7298 - val_top-5-accuracy: 0.9804 Epoch 13/30 352/352 [==============================] - 10s 28ms/step - loss: 1.0906 - accuracy: 0.7367 - top-5-accuracy: 0.9827 - val_loss: 1.0669 - val_accuracy: 0.7510 - val_top-5-accuracy: 0.9820 Epoch 14/30 352/352 [==============================] - 10s 28ms/step - loss: 1.0696 - accuracy: 0.7458 - top-5-accuracy: 0.9828 - val_loss: 1.0707 - val_accuracy: 0.7468 - val_top-5-accuracy: 0.9806 Epoch 15/30 352/352 [==============================] - 10s 28ms/step - loss: 1.0516 - accuracy: 0.7541 - top-5-accuracy: 0.9853 - val_loss: 1.1005 - val_accuracy: 0.7266 - val_top-5-accuracy: 0.9836 Epoch 16/30 352/352 [==============================] - 10s 28ms/step - loss: 1.0401 - accuracy: 0.7607 - top-5-accuracy: 0.9844 - val_loss: 1.0733 - val_accuracy: 0.7588 - val_top-5-accuracy: 0.9810 Epoch 17/30 352/352 [==============================] - 10s 28ms/step - loss: 1.0302 - accuracy: 0.7640 - top-5-accuracy: 0.9857 - val_loss: 1.0731 - val_accuracy: 0.7458 - val_top-5-accuracy: 0.9840 Epoch 18/30 352/352 [==============================] - 10s 28ms/step - loss: 1.0219 - accuracy: 0.7701 - top-5-accuracy: 0.9848 - val_loss: 1.0563 - val_accuracy: 0.7560 - val_top-5-accuracy: 0.9832 Epoch 19/30 352/352 [==============================] - 10s 28ms/step - loss: 1.0132 - accuracy: 0.7719 - top-5-accuracy: 0.9864 - val_loss: 1.0902 - val_accuracy: 0.7448 - val_top-5-accuracy: 0.9806 Epoch 20/30 352/352 [==============================] - 10s 28ms/step - loss: 0.9972 - accuracy: 0.7798 - top-5-accuracy: 0.9864 - val_loss: 1.0518 - val_accuracy: 0.7604 - val_top-5-accuracy: 0.9832 Epoch 21/30 352/352 [==============================] - 10s 28ms/step - loss: 0.9881 - accuracy: 0.7842 - top-5-accuracy: 0.9878 - val_loss: 1.0460 - val_accuracy: 0.7604 - val_top-5-accuracy: 0.9812 Epoch 22/30 352/352 [==============================] - 10s 28ms/step - loss: 0.9837 - accuracy: 0.7844 - top-5-accuracy: 0.9874 - val_loss: 1.0264 - val_accuracy: 0.7712 - val_top-5-accuracy: 0.9838 Epoch 23/30 352/352 [==============================] - 10s 28ms/step - loss: 0.9709 - accuracy: 0.7922 - top-5-accuracy: 0.9892 - val_loss: 1.0193 - val_accuracy: 0.7760 - val_top-5-accuracy: 0.9846 Epoch 24/30 352/352 [==============================] - 10s 28ms/step - loss: 0.9609 - accuracy: 0.7968 - top-5-accuracy: 0.9896 - val_loss: 1.0386 - val_accuracy: 0.7686 - val_top-5-accuracy: 0.9816 Epoch 25/30 352/352 [==============================] - 10s 28ms/step - loss: 0.9498 - accuracy: 0.8014 - top-5-accuracy: 0.9901 - val_loss: 1.0118 - val_accuracy: 0.7820 - val_top-5-accuracy: 0.9854 Epoch 26/30 352/352 [==============================] - 10s 28ms/step - loss: 0.9471 - accuracy: 0.8032 - top-5-accuracy: 0.9893 - val_loss: 1.0069 - val_accuracy: 0.7802 - val_top-5-accuracy: 0.9860 Epoch 27/30 352/352 [==============================] - 10s 28ms/step - loss: 0.9420 - accuracy: 0.8052 - top-5-accuracy: 0.9904 - val_loss: 1.0296 - val_accuracy: 0.7678 - val_top-5-accuracy: 0.9850 Epoch 28/30 352/352 [==============================] - 10s 28ms/step - loss: 0.9410 - accuracy: 0.8061 - top-5-accuracy: 0.9898 - val_loss: 1.0020 - val_accuracy: 0.7818 - val_top-5-accuracy: 0.9842 Epoch 29/30 352/352 [==============================] - 10s 28ms/step - loss: 0.9281 - accuracy: 0.8107 - top-5-accuracy: 0.9904 - val_loss: 1.0106 - val_accuracy: 0.7786 - val_top-5-accuracy: 0.9830 Epoch 30/30 352/352 [==============================] - 10s 28ms/step - loss: 0.9307 - accuracy: 0.8104 - top-5-accuracy: 0.9903 - val_loss: 1.0378 - val_accuracy: 0.7744 - val_top-5-accuracy: 0.9808 313/313 [==============================] - 2s 7ms/step - loss: 1.0255 - accuracy: 0.7751 - top-5-accuracy: 0.9832 Test accuracy: 77.51% Test top 5 accuracy: 98.32% CPU times: user 4min 54s, sys: 16.7 s, total: 5min 11s Wall time: 5min 29s
次はモデルの訓練プロセスを可視化しましょう。
plt.plot(history.history["loss"], label="train_loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Train and Validation Losses Over Epochs", fontsize=14)
plt.legend()
plt.grid()
plt.show()
ちょうど訓練した CCT モデルは僅か 40 万 パラメータで、30 エポック内に ~78% top-1 精度に到達します。上のプロットは過剰適合の兆候も示していません。これは、このネットワークを (おそらくもう少しの正則化とともに) より長い間訓練できてより良い性能さえ得られる可能性があることを意味します。この性能は、コサイン減衰学習率スケジュール、AutoAugment, MixUp or Cutmix のような他のデータ増強テクニックのような追加のレシピで更に改良できます。これらの変更により、著者は CIFAR-10 データセットで 95.1% top-1 精度を提示しています。著者はまた、畳込みブロックの数、Transformer 層, etc. が CCT の最終的な性能にどのように影響するか研究する多くの実験を提示しています。
比較のため、ViT モデルは CIFAR-10 データセット上 78.22% の top-1 精度に到達するためにおよそ 470 万 パラメータと 100 エポック の訓練が必要です。実験のセットアップについて知るには このノートブック を参照できます。
著者はまた NLP タスク上でコンパクトな畳込み Transformer の性能を示していてそこでは競争力のある結果を報告しています。
以上