Keras 2 : examples : Image classification with modern MLP models (translation/commentary)
Translation: ClassCat Co., Ltd. Sales Information
Created: 12/01/2021 (keras 2.7.0)
* This page is a translation of the following Keras documentation, with supplementary notes added where appropriate:
- Code examples : Computer Vision : Image classification with modern MLP models (Author: Khalid Salama)
* The sample code has been verified to run, with modifications made where necessary.
* Feel free to link to this page, but we would appreciate a note to sales-info@classcat.com.
Keras 2 : examples : Image classification with modern MLP models
Description: Implementing the MLP-Mixer, FNet, and gMLP models for CIFAR-100 image classification.
Introduction
This example implements three modern attention-free, multi-layer perceptron (MLP) based models for image classification, demonstrated on the CIFAR-100 dataset:
- The MLP-Mixer model, by Ilya Tolstikhin et al., based on two types of MLPs.
- The FNet model, by James Lee-Thorp et al., based on unparameterized Fourier transforms.
- The gMLP model, by Hanxiao Liu et al., based on MLPs with gating.
The purpose of the example is not to compare these models, as they might perform differently on different datasets with well-tuned hyperparameters. Rather, it is to show simple implementations of their main building blocks.
This example requires TensorFlow 2.4 or higher, as well as TensorFlow Addons, which can be installed using the following command:
pip install -U tensorflow-addons
Setup
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
Prepare the data
num_classes = 100
input_shape = (32, 32, 3)
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1) x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)
Configure the hyperparameters
weight_decay = 0.0001
batch_size = 128
num_epochs = 50
dropout_rate = 0.2
image_size = 64 # We'll resize input images to this size.
patch_size = 8 # Size of the patches to be extracted from the input images.
num_patches = (image_size // patch_size) ** 2 # Size of the data array.
embedding_dim = 256 # Number of hidden units.
num_blocks = 4 # Number of blocks.
print(f"Image size: {image_size} X {image_size} = {image_size ** 2}")
print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
print(f"Patches per image: {num_patches}")
print(f"Elements per patch (3 channels): {(patch_size ** 2) * 3}")
Image size: 64 X 64 = 4096 Patch size: 8 X 8 = 64 Patches per image: 64 Elements per patch (3 channels): 192
Build a classification model
We implement a method that builds a classifier given the processing blocks.
def build_classifier(blocks, positional_encoding=False):
    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size, num_patches)(augmented)
    # Encode patches to generate a [batch_size, num_patches, embedding_dim] tensor.
    x = layers.Dense(units=embedding_dim)(patches)
    if positional_encoding:
        positions = tf.range(start=0, limit=num_patches, delta=1)
        position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=embedding_dim
        )(positions)
        x = x + position_embedding
    # Process x using the module blocks.
    x = blocks(x)
    # Apply global average pooling to generate a [batch_size, embedding_dim] representation tensor.
    representation = layers.GlobalAveragePooling1D()(x)
    # Apply dropout.
    representation = layers.Dropout(rate=dropout_rate)(representation)
    # Compute logits outputs.
    logits = layers.Dense(num_classes)(representation)
    # Create the Keras model.
    return keras.Model(inputs=inputs, outputs=logits)
Define an experiment
We implement a utility function to compile, train, and evaluate a given model.
def run_experiment(model):
    # Create Adam optimizer with weight decay.
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay,
    )
    # Compile the model.
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="acc"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),
        ],
    )
    # Create a learning rate scheduler callback.
    reduce_lr = keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=5
    )
    # Create an early stopping callback.
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=10, restore_best_weights=True
    )
    # Fit the model.
    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[early_stopping, reduce_lr],
    )

    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    # Return history to plot learning curves.
    return history
Use data augmentation
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)
Implement patch extraction as a layer
class Patches(layers.Layer):
    def __init__(self, patch_size, num_patches):
        super(Patches, self).__init__()
        self.patch_size = patch_size
        self.num_patches = num_patches

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, self.num_patches, patch_dims])
        return patches
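For intuition, the same non-overlapping patch extraction can be sketched in plain NumPy with reshapes and a transpose. The `extract_patches_np` helper below is a hypothetical stand-in, not part of the example, but it produces the same shapes as the `Patches` layer above:

```python
import numpy as np

def extract_patches_np(images, patch_size):
    # images: [batch, height, width, channels]; hypothetical NumPy analogue
    # of tf.image.extract_patches for non-overlapping patches.
    b, h, w, c = images.shape
    # Split the spatial grid into patch_size x patch_size tiles.
    x = images.reshape(b, h // patch_size, patch_size, w // patch_size, patch_size, c)
    # Bring the two tile indices together, then flatten each tile's pixels.
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(b, (h // patch_size) * (w // patch_size), patch_size * patch_size * c)

patches = extract_patches_np(np.zeros((2, 64, 64, 3)), patch_size=8)
print(patches.shape)  # (2, 64, 192): 64 patches per image, 192 elements per patch
```

These shapes match the `Patches per image` and `Elements per patch` values printed in the hyperparameter section.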
The MLP-Mixer model
The MLP-Mixer is an architecture based exclusively on multi-layer perceptrons (MLPs), containing two types of MLP layers:
- One is applied independently to image patches, which mixes the per-location features.
- The other is applied across patches (along channels), which mixes spatial information.
This is similar to a depthwise separable convolution based model such as the Xception model, but with two chained dense transforms, no max pooling, and layer normalization instead of batch normalization.
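To make the two mixing directions concrete, here is a rough NumPy shape sketch. The `token_mix` and `channel_mix` matrices are hypothetical linear stand-ins for the example's `mlp1` and `mlp2` (the real blocks are two-layer MLPs with GELU and dropout):

```python
import numpy as np

num_patches, embedding_dim = 64, 256
x = np.zeros((8, num_patches, embedding_dim))  # [batch, patches, channels]

# Token mixing: transpose so the linear map acts across the patch axis,
# mixing spatial information independently for each channel.
token_mix = np.zeros((num_patches, num_patches))       # stand-in for mlp1
tokens = np.swapaxes(x, 1, 2) @ token_mix              # [batch, channels, patches]
x = x + np.swapaxes(tokens, 1, 2)                      # back to [batch, patches, channels]

# Channel mixing: the linear map acts across the channel axis,
# mixing per-location features independently for each patch.
channel_mix = np.zeros((embedding_dim, embedding_dim))  # stand-in for mlp2
x = x + x @ channel_mix

print(x.shape)  # (8, 64, 256): shape is preserved through both mixing steps
```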
Implement the MLP-Mixer module
class MLPMixerLayer(layers.Layer):
    def __init__(self, num_patches, hidden_units, dropout_rate, *args, **kwargs):
        super(MLPMixerLayer, self).__init__(*args, **kwargs)

        self.mlp1 = keras.Sequential(
            [
                layers.Dense(units=num_patches),
                tfa.layers.GELU(),
                layers.Dense(units=num_patches),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        self.mlp2 = keras.Sequential(
            [
                layers.Dense(units=num_patches),
                tfa.layers.GELU(),
                layers.Dense(units=embedding_dim),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        self.normalize = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        # Apply layer normalization.
        x = self.normalize(inputs)
        # Transpose inputs from [num_batches, num_patches, hidden_units] to [num_batches, hidden_units, num_patches].
        x_channels = tf.linalg.matrix_transpose(x)
        # Apply mlp1 on each channel independently.
        mlp1_outputs = self.mlp1(x_channels)
        # Transpose mlp1_outputs from [num_batches, hidden_dim, num_patches] to [num_batches, num_patches, hidden_units].
        mlp1_outputs = tf.linalg.matrix_transpose(mlp1_outputs)
        # Add skip connection.
        x = mlp1_outputs + inputs
        # Apply layer normalization.
        x_patches = self.normalize(x)
        # Apply mlp2 on each patch independently.
        mlp2_outputs = self.mlp2(x_patches)
        # Add skip connection.
        x = x + mlp2_outputs
        return x
Build, train, and evaluate the MLP-Mixer model
Note that training the model with the current settings on a V100 GPU takes around 8 seconds per epoch.
mlpmixer_blocks = keras.Sequential(
[MLPMixerLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.005
mlpmixer_classifier = build_classifier(mlpmixer_blocks)
history = run_experiment(mlpmixer_classifier)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py:390: UserWarning: Default value of `approximate` is changed from `True` to `False`
  return py_builtins.overload_of(f)(*args)
Epoch 1/50
352/352 [==============================] - 13s 25ms/step - loss: 4.1703 - acc: 0.0756 - top5-acc: 0.2322 - val_loss: 3.6202 - val_acc: 0.1532 - val_top5-acc: 0.4140
Epoch 2/50
352/352 [==============================] - 8s 23ms/step - loss: 3.4165 - acc: 0.1789 - top5-acc: 0.4459 - val_loss: 3.1599 - val_acc: 0.2334 - val_top5-acc: 0.5160
...
(epochs 3-49 omitted)
...
Epoch 50/50
352/352 [==============================] - 8s 23ms/step - loss: 1.7392 - acc: 0.5201 - top5-acc: 0.8226 - val_loss: 2.0787 - val_acc: 0.4710 - val_top5-acc: 0.7734
313/313 [==============================] - 2s 8ms/step - loss: 2.0571 - acc: 0.4758 - top5-acc: 0.7718
Test accuracy: 47.58%
Test top 5 accuracy: 77.18%
(Translator's note: results of our verification run)
Epoch 1/50
352/352 [==============================] - 16s 30ms/step - loss: 3.8697 - acc: 0.1100 - top5-acc: 0.3170 - val_loss: 3.4557 - val_acc: 0.1790 - val_top5-acc: 0.4558 - lr: 0.0050
...
(epochs 2-49 omitted)
...
Epoch 50/50
352/352 [==============================] - 10s 28ms/step - loss: 1.6484 - acc: 0.5426 - top5-acc: 0.8374 - val_loss: 1.9811 - val_acc: 0.4918 - val_top5-acc: 0.7906 - lr: 0.0012
313/313 [==============================] - 4s 12ms/step - loss: 1.9585 - acc: 0.4976 - top5-acc: 0.7859
Test accuracy: 49.76%
Test top 5 accuracy: 78.59%
CPU times: user 10min 17s, sys: 49 s, total: 11min 6s
Wall time: 8min 32s
The MLP-Mixer model tends to have far fewer parameters than convolutional and transformer-based models, which leads to lower training and serving computational cost.
As mentioned in the MLP-Mixer paper, when pre-trained on large datasets, or with modern regularization schemes, the MLP-Mixer attains scores competitive with state-of-the-art models. You can obtain better results by increasing the embedding dimensions, increasing the number of mixer blocks, and training the model for longer. You may also try to increase the size of the input images and use different patch sizes.
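As a rough back-of-the-envelope illustration of why the parameter count stays small, here is the dense-weight count of one `MLPMixerLayer` as written above (biases included, layer normalization ignored; numbers assume this example's `num_patches = 64` and `embedding_dim = 256`):

```python
num_patches, embedding_dim = 64, 256

# mlp1 operates on the transposed tensor, so both Dense layers map
# num_patches -> num_patches (weights + biases, applied twice).
mlp1 = 2 * (num_patches * num_patches + num_patches)

# mlp2 maps embedding_dim -> num_patches -> embedding_dim, as written above.
mlp2 = (embedding_dim * num_patches + num_patches) + (num_patches * embedding_dim + embedding_dim)

print(mlp1 + mlp2)  # 41408 dense weights per mixer block
```

At roughly 41k dense weights per block, four blocks stay well under a million parameters, which is small next to typical convolutional or transformer backbones.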
The FNet model
The FNet uses a block similar to the Transformer block. However, FNet replaces the self-attention layer in the Transformer block with a parameter-free 2D Fourier transformation layer:
- One 1D Fourier Transform is applied along the patches.
- One 1D Fourier Transform is applied along the channels.
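To see that the two bullet points describe a single 2D transform, note that a 2D FFT over the (patches, channels) axes factorizes into two 1D FFTs. A small NumPy check, with toy shapes matching this example's dimensions:

```python
import numpy as np

x = np.random.default_rng(0).standard_normal((2, 64, 256))  # [batch, patches, channels]

# Two 1D FFTs: along the channel axis, then along the patch axis.
two_1d = np.fft.fft(np.fft.fft(x, axis=-1), axis=-2)
# One 2D FFT: np.fft.fft2 acts on the last two axes by default.
one_2d = np.fft.fft2(x)

print(np.allclose(two_1d, one_2d))  # True

# The FNet block then keeps only the real part of the complex result.
real_part = one_2d.real
```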
Implement the FNet module
class FNetLayer(layers.Layer):
    def __init__(self, num_patches, embedding_dim, dropout_rate, *args, **kwargs):
        super(FNetLayer, self).__init__(*args, **kwargs)

        self.ffn = keras.Sequential(
            [
                layers.Dense(units=embedding_dim),
                tfa.layers.GELU(),
                layers.Dropout(rate=dropout_rate),
                layers.Dense(units=embedding_dim),
            ]
        )
        self.normalize1 = layers.LayerNormalization(epsilon=1e-6)
        self.normalize2 = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        # Apply Fourier transformations.
        x = tf.cast(
            tf.signal.fft2d(tf.cast(inputs, dtype=tf.dtypes.complex64)),
            dtype=tf.dtypes.float32,
        )
        # Add skip connection.
        x = x + inputs
        # Apply layer normalization.
        x = self.normalize1(x)
        # Apply Feedforward network.
        x_ffn = self.ffn(x)
        # Add skip connection.
        x = x + x_ffn
        # Apply layer normalization.
        return self.normalize2(x)
Build, train, and evaluate the FNet model
Note that training the model with the current settings on a V100 GPU takes around 8 seconds per epoch.
fnet_blocks = keras.Sequential(
[FNetLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.001
fnet_classifier = build_classifier(fnet_blocks, positional_encoding=True)
history = run_experiment(fnet_classifier)
Epoch 1/50
352/352 [==============================] - 11s 23ms/step - loss: 4.3419 - acc: 0.0470 - top5-acc: 0.1652 - val_loss: 3.8279 - val_acc: 0.1178 - val_top5-acc: 0.3268
Epoch 2/50
352/352 [==============================] - 8s 22ms/step - loss: 3.7814 - acc: 0.1202 - top5-acc: 0.3341 - val_loss: 3.5981 - val_acc: 0.1540 - val_top5-acc: 0.3914
...
(epochs 3-44 omitted)
...
Epoch 45/50
352/352 [==============================] - 8s 22ms/step - loss: 1.7611 - acc: 0.5188 - top5-acc: 0.8224 - val_loss: 2.0371 - val_acc: 0.4650 - val_top5-acc: 0.7628
1.7713 - acc: 0.5189 - top5-acc: 0.8226 - val_loss: 2.0245 - val_acc: 0.4630 - val_top5-acc: 0.7644 Epoch 47/50 352/352 [==============================] - 8s 22ms/step - loss: 1.7809 - acc: 0.5130 - top5-acc: 0.8215 - val_loss: 2.0471 - val_acc: 0.4618 - val_top5-acc: 0.7618 Epoch 48/50 352/352 [==============================] - 8s 22ms/step - loss: 1.8052 - acc: 0.5112 - top5-acc: 0.8165 - val_loss: 2.0441 - val_acc: 0.4596 - val_top5-acc: 0.7658 Epoch 49/50 352/352 [==============================] - 8s 22ms/step - loss: 1.8128 - acc: 0.5039 - top5-acc: 0.8178 - val_loss: 2.0569 - val_acc: 0.4600 - val_top5-acc: 0.7614 Epoch 50/50 352/352 [==============================] - 8s 22ms/step - loss: 1.8179 - acc: 0.5089 - top5-acc: 0.8155 - val_loss: 2.0514 - val_acc: 0.4576 - val_top5-acc: 0.7566 313/313 [==============================] - 2s 6ms/step - loss: 2.0142 - acc: 0.4663 - top5-acc: 0.7647 Test accuracy: 46.63% Test top 5 accuracy: 76.47%
Epoch 1/50
352/352 [==============================] - 14s 30ms/step - loss: 4.1290 - acc: 0.0735 - top5-acc: 0.2303 - val_loss: 3.7880 - val_acc: 0.1228 - val_top5-acc: 0.3322 - lr: 0.0010
Epoch 2/50
352/352 [==============================] - 10s 28ms/step - loss: 3.7066 - acc: 0.1335 - top5-acc: 0.3566 - val_loss: 3.4879 - val_acc: 0.1730 - val_top5-acc: 0.4260 - lr: 0.0010
...
Epoch 50/50
352/352 [==============================] - 10s 28ms/step - loss: 1.7798 - acc: 0.5150 - top5-acc: 0.8181 - val_loss: 2.0947 - val_acc: 0.4478 - val_top5-acc: 0.7528 - lr: 5.0000e-04
313/313 [==============================] - 3s 9ms/step - loss: 2.0490 - acc: 0.4629 - top5-acc: 0.7694
Test accuracy: 46.29%
Test top 5 accuracy: 76.94%
CPU times: user 7min 51s, sys: 1min 14s, total: 9min 5s
Wall time: 8min 29s
As shown in the FNet paper, better results can be achieved by increasing the embedding dimensions, increasing the number of FNet blocks, and training the model for longer. You may also try to increase the size of the input images and use different patch sizes. FNet scales very efficiently to long inputs, runs much faster than attention-based Transformer models, and produces competitive results.
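The parameter-free token mixing that makes FNet scale so well can be illustrated with a small NumPy sketch (the shapes and the `fnet_mix` name are illustrative; the actual Keras layer applies the Fourier transform with `tf.signal` inside the model):

```python
import numpy as np

def fnet_mix(x):
    """Mix tokens with a 2D Fourier transform, keeping only the real part.

    x: array of shape (batch, num_patches, embedding_dim).
    The FFT is applied over the patch and embedding axes. No parameters
    are learned in this step, which is why FNet is so efficient on long inputs.
    """
    return np.real(np.fft.fft2(x, axes=(1, 2)))

# A dummy batch of 4 inputs, each split into 64 patches of dimension 128.
x = np.random.randn(4, 64, 128)
y = fnet_mix(x)
assert y.shape == x.shape  # token mixing preserves the tensor shape
```

Because the mixing step has no weights, the only learned parameters in an FNet block come from the feedforward sublayer that follows it.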
The gMLP model
gMLP is an MLP architecture that features a Spatial Gating Unit (SGU). The SGU enables cross-patch interactions across the spatial (channel) dimension, by:
- Transforming the input spatially by applying a linear projection across patches (along channels).
- Applying element-wise multiplication of the input and its spatial transformation.
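In NumPy terms, the two steps above can be sketched as follows (the shapes, the `spatial_gating_unit` signature, and the explicit weight arguments are illustrative; the Keras implementation below also applies layer normalization and learns the projection as a `Dense` layer whose bias is initialized to ones):

```python
import numpy as np

def spatial_gating_unit(x, w_spatial, b_spatial):
    """x: (batch, num_patches, 2 * embedding_dim).

    1. Split x into u and v along the channel axis.
    2. Project v across the *patch* axis (a linear map over patches).
    3. Gate: multiply u element-wise by the projected v.
    """
    u, v = np.split(x, 2, axis=2)  # each (batch, num_patches, embedding_dim)
    # Linear projection across patches: contract the patch axis of v.
    v_projected = np.einsum("bpd,qp->bqd", v, w_spatial) + b_spatial[None, :, None]
    return u * v_projected

batch, patches, dim = 2, 16, 8
x = np.random.randn(batch, patches, 2 * dim)
w = np.random.randn(patches, patches)
b = np.ones(patches)  # bias initialized to ones, as in the gMLP paper
out = spatial_gating_unit(x, w, b)
assert out.shape == (batch, patches, dim)
```

Initializing the bias to ones means the gate starts out close to an identity-like scaling of `u`, so early in training the block behaves similarly to a plain channel MLP.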
Implement the gMLP module
class gMLPLayer(layers.Layer):
    def __init__(self, num_patches, embedding_dim, dropout_rate, *args, **kwargs):
        super(gMLPLayer, self).__init__(*args, **kwargs)

        self.channel_projection1 = keras.Sequential(
            [
                layers.Dense(units=embedding_dim * 2),
                tfa.layers.GELU(),
                layers.Dropout(rate=dropout_rate),
            ]
        )

        self.channel_projection2 = layers.Dense(units=embedding_dim)

        self.spatial_projection = layers.Dense(
            units=num_patches, bias_initializer="Ones"
        )

        self.normalize1 = layers.LayerNormalization(epsilon=1e-6)
        self.normalize2 = layers.LayerNormalization(epsilon=1e-6)

    def spatial_gating_unit(self, x):
        # Split x along the channel dimension.
        # Tensors u and v will be in the shape of [batch_size, num_patches, embedding_dim].
        u, v = tf.split(x, num_or_size_splits=2, axis=2)
        # Apply layer normalization.
        v = self.normalize2(v)
        # Apply spatial projection.
        v_channels = tf.linalg.matrix_transpose(v)
        v_projected = self.spatial_projection(v_channels)
        v_projected = tf.linalg.matrix_transpose(v_projected)
        # Apply element-wise multiplication.
        return u * v_projected

    def call(self, inputs):
        # Apply layer normalization.
        x = self.normalize1(inputs)
        # Apply the first channel projection. x_projected shape: [batch_size, num_patches, embedding_dim * 2].
        x_projected = self.channel_projection1(x)
        # Apply the spatial gating unit. x_spatial shape: [batch_size, num_patches, embedding_dim].
        x_spatial = self.spatial_gating_unit(x_projected)
        # Apply the second channel projection. x_projected shape: [batch_size, num_patches, embedding_dim].
        x_projected = self.channel_projection2(x_spatial)
        # Add skip connection.
        return x + x_projected
Build, train, and evaluate the gMLP model
Note that training the model with the current settings on a V100 GPU takes around 9 seconds per epoch.
gmlp_blocks = keras.Sequential(
    [gMLPLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.003
gmlp_classifier = build_classifier(gmlp_blocks)
history = run_experiment(gmlp_classifier)
Epoch 1/50
352/352 [==============================] - 13s 28ms/step - loss: 4.1713 - acc: 0.0704 - top5-acc: 0.2206 - val_loss: 3.5629 - val_acc: 0.1548 - val_top5-acc: 0.4086
Epoch 2/50
352/352 [==============================] - 9s 27ms/step - loss: 3.5146 - acc: 0.1633 - top5-acc: 0.4172 - val_loss: 3.2899 - val_acc: 0.2066 - val_top5-acc: 0.4900
...
Epoch 50/50
352/352 [==============================] - 9s 26ms/step - loss: 1.2731 - acc: 0.6352 - top5-acc: 0.8976 - val_loss: 2.0897 - val_acc: 0.4982 - val_top5-acc: 0.7788
313/313 [==============================] - 2s 7ms/step - loss: 2.0743 - acc: 0.5064 - top5-acc: 0.7828
Test accuracy: 50.64%
Test top 5 accuracy: 78.28%
Epoch 1/50
352/352 [==============================] - 16s 33ms/step - loss: 3.9500 - acc: 0.0943 - top5-acc: 0.2878 - val_loss: 3.5389 - val_acc: 0.1556 - val_top5-acc: 0.4148 - lr: 0.0030
Epoch 2/50
352/352 [==============================] - 11s 31ms/step - loss: 3.4664 - acc: 0.1684 - top5-acc: 0.4303 - val_loss: 3.2514 - val_acc: 0.2138 - val_top5-acc: 0.4884 - lr: 0.0030
...
Epoch 50/50
352/352 [==============================] - 11s 31ms/step - loss: 1.4927 - acc: 0.5791 - top5-acc: 0.8659 - val_loss: 2.1429 - val_acc: 0.4796 - val_top5-acc: 0.7790 - lr: 0.0015
313/313 [==============================] - 3s 9ms/step - loss: 2.0755 - acc: 0.4963 - top5-acc: 0.7826
Test accuracy: 49.63%
Test top 5 accuracy: 78.26%
CPU times: user 10min 9s, sys: 39.4 s, total: 10min 48s
Wall time: 9min 20s
As shown in the gMLP paper, better results can be achieved by increasing the embedding dimensions, increasing the number of gMLP blocks, and training the model for longer. You may also try to increase the size of the input images and use different patch sizes. Note that the paper used advanced regularization strategies, such as MixUp and CutMix, as well as AutoAugment.
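As a minimal sketch of one of those regularization strategies, a MixUp step on a batch of images and one-hot labels can be written in NumPy as follows (the `mixup` function, its `alpha` default, and the seeded generator are illustrative choices, not part of the original example):

```python
import numpy as np

def mixup(x, y, alpha=0.2, rng=None):
    """Blend a batch of images and one-hot labels with a shuffled copy.

    x: (batch, H, W, C) images; y: (batch, num_classes) one-hot labels.
    Returns convex combinations controlled by lambda ~ Beta(alpha, alpha).
    """
    rng = rng or np.random.default_rng(0)
    lam = rng.beta(alpha, alpha)
    idx = rng.permutation(len(x))
    return lam * x + (1 - lam) * x[idx], lam * y + (1 - lam) * y[idx]

x = np.random.rand(8, 32, 32, 3)
y = np.eye(100)[np.random.default_rng(1).integers(0, 100, size=8)]
x_mix, y_mix = mixup(x, y)
assert x_mix.shape == x.shape and y_mix.shape == y.shape
# Mixed labels remain valid probability distributions.
assert np.allclose(y_mix.sum(axis=1), 1.0)
```

In a real training loop this step would run per batch (typically inside the `tf.data` pipeline or a custom `train_step`), and the small `alpha` keeps most sampled lambdas near 0 or 1, so each mixed image stays close to one of its two sources.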