Skip to content

ClasCat® AI Research

クラスキャット – 生成 AI, AI エージェント, MCP

Menu
  • ホーム
    • ClassCat® AI Research ホーム
    • クラスキャット・ホーム
  • OpenAI API
    • OpenAI Python ライブラリ 1.x : 概要
    • OpenAI ブログ
      • GPT の紹介
      • GPT ストアの紹介
      • ChatGPT Team の紹介
    • OpenAI platform 1.x
      • Get Started : イントロダクション
      • Get Started : クイックスタート (Python)
      • Get Started : クイックスタート (Node.js)
      • Get Started : モデル
      • 機能 : 埋め込み
      • 機能 : 埋め込み (ユースケース)
      • ChatGPT : アクション – イントロダクション
      • ChatGPT : アクション – Getting started
      • ChatGPT : アクション – アクション認証
    • OpenAI ヘルプ : ChatGPT
      • ChatGPTとは何ですか?
      • ChatGPT は真実を語っていますか?
      • GPT の作成
      • GPT FAQ
      • GPT vs アシスタント
      • GPT ビルダー
    • OpenAI ヘルプ : ChatGPT > メモリ
      • FAQ
    • OpenAI ヘルプ : GPT ストア
      • 貴方の GPT をフィーチャーする
    • OpenAI Python ライブラリ 0.27 : 概要
    • OpenAI platform
      • Get Started : イントロダクション
      • Get Started : クイックスタート
      • Get Started : モデル
      • ガイド : GPT モデル
      • ガイド : 画像生成 (DALL·E)
      • ガイド : GPT-3.5 Turbo 対応 微調整
      • ガイド : 微調整 1.イントロダクション
      • ガイド : 微調整 2. データセットの準備 / ケーススタディ
      • ガイド : 埋め込み
      • ガイド : 音声テキスト変換
      • ガイド : モデレーション
      • ChatGPT プラグイン : イントロダクション
    • OpenAI Cookbook
      • 概要
      • API 使用方法 : レート制限の操作
      • API 使用方法 : tiktoken でトークンを数える方法
      • GPT : ChatGPT モデルへの入力をフォーマットする方法
      • GPT : 補完をストリームする方法
      • GPT : 大規模言語モデルを扱う方法
      • 埋め込み : 埋め込みの取得
      • GPT-3 の微調整 : 分類サンプルの微調整
      • DALL-E : DALL·E で 画像を生成して編集する方法
      • DALL·E と Segment Anything で動的マスクを作成する方法
      • Whisper プロンプティング・ガイド
  • Gemini API
    • Tutorials : クイックスタート with Python (1) テキスト-to-テキスト生成
    • (2) マルチモーダル入力 / 日本語チャット
    • (3) 埋め込みの使用
    • (4) 高度なユースケース
    • クイックスタート with Node.js
    • クイックスタート with Dart or Flutter (1) 日本語動作確認
    • Gemma
      • 概要 (README)
      • Tutorials : サンプリング
      • Tutorials : KerasNLP による Getting Started
  • Keras 3
    • 新しいマルチバックエンド Keras
    • Keras 3 について
    • Getting Started : エンジニアのための Keras 入門
    • Google Colab 上のインストールと Stable Diffusion デモ
    • コンピュータビジョン – ゼロからの画像分類
    • コンピュータビジョン – 単純な MNIST convnet
    • コンピュータビジョン – EfficientNet を使用した微調整による画像分類
    • コンピュータビジョン – Vision Transformer による画像分類
    • コンピュータビジョン – 最新の MLPモデルによる画像分類
    • コンピュータビジョン – コンパクトな畳込み Transformer
    • Keras Core
      • Keras Core 0.1
        • 新しいマルチバックエンド Keras (README)
        • Keras for TensorFlow, JAX, & PyTorch
        • 開発者ガイド : Getting started with Keras Core
        • 開発者ガイド : 関数型 API
        • 開発者ガイド : シーケンシャル・モデル
        • 開発者ガイド : サブクラス化で新しい層とモデルを作成する
        • 開発者ガイド : 独自のコールバックを書く
      • Keras Core 0.1.1 & 0.1.2 : リリースノート
      • 開発者ガイド
      • Code examples
      • Keras Stable Diffusion
        • 概要
        • 基本的な使い方 (テキスト-to-画像 / 画像-to-画像変換)
        • 混合精度のパフォーマンス
        • インペインティングの簡易アプリケーション
        • (参考) KerasCV – Stable Diffusion を使用した高性能画像生成
  • TensorFlow
    • TF 2 : 初級チュートリアル
    • TF 2 : 上級チュートリアル
    • TF 2 : ガイド
    • TF 1 : チュートリアル
    • TF 1 : ガイド
  • その他
    • 🦜️🔗 LangChain ドキュメント / ユースケース
    • Stable Diffusion WebUI
      • Google Colab で Stable Diffusion WebUI 入門
      • HuggingFace モデル / VAE の導入
      • LoRA の利用
    • Diffusion Models / 拡散モデル
  • クラスキャット
    • 会社案内
    • お問合せ
    • Facebook
    • ClassCat® Blog
Menu

Keras 2 : examples : コンピュータビジョン – 現代的な MLPモデルによる画像分類

Posted on 12/01/202112/01/2021 by Sales Information

Keras 2 : examples : 現代的な MLPモデルによる画像分類 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 12/01/2021 (keras 2.7.0)

* 本ページは、Keras の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

  • Code examples : Computer Vision : Image classification with modern MLP models (Author: Khalid Salama)

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス ★ 無料 Web セミナー開催中 ★

◆ クラスキャットは人工知能・テレワークに関する各種サービスを提供しております。お気軽にご相談ください :

  • 人工知能研究開発支援
    1. 人工知能研修サービス(経営者層向けオンサイト研修)
    2. テクニカルコンサルティングサービス
    3. 実証実験(プロトタイプ構築)
    4. アプリケーションへの実装

  • 人工知能研修サービス

  • PoC(概念実証)を失敗させないための支援

  • テレワーク & オンライン授業を支援
◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。

◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

  • 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
  • E-Mail:sales-info@classcat.com  ;  WebSite: www.classcat.com  ;  Facebook

 

Keras 2 : examples : 現代的な MLPモデルによる画像分類

Description: CIFAR-100 画像分類のために MLP-Mixer, FNet と gMLP モデルを実装する。

 

イントロダクション

このサンプルは画像分類のための 3 つの現代的な attention-free な多層パーセプトロン (MLP) ベースのモデルを実装します、CIFAR-100 データセット上で実演されます :

  • Ilya Tolstikhin et al. による MLP-Mixer モデル、2 つのタイプの MLP に基づいています。

  • James Lee-Thorp et al. による FNet モデル、unparameterized フーリエ変換に基づいています。

  • Hanxiao Liu et al. による gMLP モデル、ゲーティングを持つ MLP に基づいています。

このサンプルの目的はこれらのモデルを比較することではないです、これらのモデルは上手く調整されたハイパーパラメータにより異なるデータセット上では異なって遂行される可能性があるからです。むしろ、それらの主要なビルディングブロックの単純な実装を示すことにあります。

このサンプルは TensorFlow 2.4 またはそれ以上と、TensorFlow Addons を必要とします、これは以下のコマンドでインストールできます :

pip install -U tensorflow-addons

 

セットアップ

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

 

データの準備

num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)

 

ハイパーパラメータの設定

weight_decay = 0.0001
batch_size = 128
num_epochs = 50
dropout_rate = 0.2
image_size = 64  # We'll resize input images to this size.
patch_size = 8  # Size of the patches to be extracted from the input images.
num_patches = (image_size // patch_size) ** 2  # Size of the data array.
embedding_dim = 256  # Number of hidden units.
num_blocks = 4  # Number of blocks.

print(f"Image size: {image_size} X {image_size} = {image_size ** 2}")
print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
print(f"Patches per image: {num_patches}")
print(f"Elements per patch (3 channels): {(patch_size ** 2) * 3}")
Image size: 64 X 64 = 4096
Patch size: 8 X 8 = 64 
Patches per image: 64
Elements per patch (3 channels): 192

 

分類モデルの構築

処理ブロックが与えられたときモデルを構築するメソッドを実装します。

def build_classifier(blocks, positional_encoding=False):
    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size, num_patches)(augmented)
    # Encode patches to generate a [batch_size, num_patches, embedding_dim] tensor.
    x = layers.Dense(units=embedding_dim)(patches)
    if positional_encoding:
        positions = tf.range(start=0, limit=num_patches, delta=1)
        position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=embedding_dim
        )(positions)
        x = x + position_embedding
    # Process x using the module blocks.
    x = blocks(x)
    # Apply global average pooling to generate a [batch_size, embedding_dim] representation tensor.
    representation = layers.GlobalAveragePooling1D()(x)
    # Apply dropout.
    representation = layers.Dropout(rate=dropout_rate)(representation)
    # Compute logits outputs.
    logits = layers.Dense(num_classes)(representation)
    # Create the Keras model.
    return keras.Model(inputs=inputs, outputs=logits)

 

実験の定義

与えられたモデルをコンパイル、訓練そして評価するユティリティ関数を実装します。

def run_experiment(model):
    # Create Adam optimizer with weight decay.
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay,
    )
    # Compile the model.
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="acc"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),
        ],
    )
    # Create a learning rate scheduler callback.
    reduce_lr = keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=5
    )
    # Create an early stopping callback.
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=10, restore_best_weights=True
    )
    # Fit the model.
    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[early_stopping, reduce_lr],
    )

    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    # Return history to plot learning curves.
    return history

 

データ増強の使用

data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomZoom(
            height_factor=0.2, width_factor=0.2
        ),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)

 

パッチ抽出を層として実装する

class Patches(layers.Layer):
    def __init__(self, patch_size, num_patches):
        super(Patches, self).__init__()
        self.patch_size = patch_size
        self.num_patches = num_patches

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, self.num_patches, patch_dims])
        return patches

 

MLP-Mixer モデル

MLP-Mixer は多層パーセプトロン (MLP) だけに基づいたアーキテクチャで、2 つのタイプの MLP 層を含みます :

  1. 一つは画像パッチに独立的に適用されます、これは位置ごとの特徴をミックスします。
  2. 他方は (チャネルに沿って) パッチに渡り適用され、これは空間情報をミックスします。

これは Xception モデルのような depthwise に分離可能な畳み込みベースのモデル に似ていますが、2 つの連鎖された dense 変換を持ち、最大プーリングはなく、そしてバッチ正規化の代わりに層正規化があります。

 

MLP-Mixer モジュールの実装

class MLPMixerLayer(layers.Layer):
    def __init__(self, num_patches, hidden_units, dropout_rate, *args, **kwargs):
        super(MLPMixerLayer, self).__init__(*args, **kwargs)

        self.mlp1 = keras.Sequential(
            [
                layers.Dense(units=num_patches),
                tfa.layers.GELU(),
                layers.Dense(units=num_patches),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        self.mlp2 = keras.Sequential(
            [
                layers.Dense(units=num_patches),
                tfa.layers.GELU(),
                layers.Dense(units=embedding_dim),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        self.normalize = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        # Apply layer normalization.
        x = self.normalize(inputs)
        # Transpose inputs from [num_batches, num_patches, hidden_units] to [num_batches, hidden_units, num_patches].
        x_channels = tf.linalg.matrix_transpose(x)
        # Apply mlp1 on each channel independently.
        mlp1_outputs = self.mlp1(x_channels)
        # Transpose mlp1_outputs from [num_batches, hidden_dim, num_patches] to [num_batches, num_patches, hidden_units].
        mlp1_outputs = tf.linalg.matrix_transpose(mlp1_outputs)
        # Add skip connection.
        x = mlp1_outputs + inputs
        # Apply layer normalization.
        x_patches = self.normalize(x)
        # Apply mlp2 on each patch independtenly.
        mlp2_outputs = self.mlp2(x_patches)
        # Add skip connection.
        x = x + mlp2_outputs
        return x

 

MLP-Mixer モデルを構築、訓練そして評価する

現在の設定でのモデルの訓練は V100 GPU 上でエポック毎におよそ 8 秒かかることに注意してください。

mlpmixer_blocks = keras.Sequential(
    [MLPMixerLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.005
mlpmixer_classifier = build_classifier(mlpmixer_blocks)
history = run_experiment(mlpmixer_classifier)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py:390: UserWarning: Default value of `approximate` is changed from `True` to `False`
  return py_builtins.overload_of(f)(*args)

Epoch 1/50
352/352 [==============================] - 13s 25ms/step - loss: 4.1703 - acc: 0.0756 - top5-acc: 0.2322 - val_loss: 3.6202 - val_acc: 0.1532 - val_top5-acc: 0.4140
Epoch 2/50
352/352 [==============================] - 8s 23ms/step - loss: 3.4165 - acc: 0.1789 - top5-acc: 0.4459 - val_loss: 3.1599 - val_acc: 0.2334 - val_top5-acc: 0.5160
Epoch 3/50
352/352 [==============================] - 8s 23ms/step - loss: 3.1367 - acc: 0.2328 - top5-acc: 0.5230 - val_loss: 3.0539 - val_acc: 0.2560 - val_top5-acc: 0.5664
Epoch 4/50
352/352 [==============================] - 8s 23ms/step - loss: 2.9985 - acc: 0.2624 - top5-acc: 0.5600 - val_loss: 2.9498 - val_acc: 0.2798 - val_top5-acc: 0.5856
Epoch 5/50
352/352 [==============================] - 8s 23ms/step - loss: 2.8806 - acc: 0.2809 - top5-acc: 0.5879 - val_loss: 2.8593 - val_acc: 0.2904 - val_top5-acc: 0.6050
Epoch 6/50
352/352 [==============================] - 8s 23ms/step - loss: 2.7860 - acc: 0.3024 - top5-acc: 0.6124 - val_loss: 2.7405 - val_acc: 0.3256 - val_top5-acc: 0.6364
Epoch 7/50
352/352 [==============================] - 8s 23ms/step - loss: 2.7065 - acc: 0.3152 - top5-acc: 0.6280 - val_loss: 2.7548 - val_acc: 0.3328 - val_top5-acc: 0.6450
Epoch 8/50
352/352 [==============================] - 8s 22ms/step - loss: 2.6443 - acc: 0.3263 - top5-acc: 0.6446 - val_loss: 2.6618 - val_acc: 0.3460 - val_top5-acc: 0.6578
Epoch 9/50
352/352 [==============================] - 8s 23ms/step - loss: 2.5886 - acc: 0.3406 - top5-acc: 0.6573 - val_loss: 2.6065 - val_acc: 0.3492 - val_top5-acc: 0.6650
Epoch 10/50
352/352 [==============================] - 8s 23ms/step - loss: 2.5798 - acc: 0.3404 - top5-acc: 0.6591 - val_loss: 2.6546 - val_acc: 0.3502 - val_top5-acc: 0.6630
Epoch 11/50
352/352 [==============================] - 8s 23ms/step - loss: 2.5269 - acc: 0.3498 - top5-acc: 0.6714 - val_loss: 2.6201 - val_acc: 0.3570 - val_top5-acc: 0.6710
Epoch 12/50
352/352 [==============================] - 8s 23ms/step - loss: 2.5003 - acc: 0.3569 - top5-acc: 0.6745 - val_loss: 2.5936 - val_acc: 0.3564 - val_top5-acc: 0.6662
Epoch 13/50
352/352 [==============================] - 8s 22ms/step - loss: 2.4801 - acc: 0.3619 - top5-acc: 0.6792 - val_loss: 2.5236 - val_acc: 0.3700 - val_top5-acc: 0.6786
Epoch 14/50
352/352 [==============================] - 8s 23ms/step - loss: 2.4392 - acc: 0.3676 - top5-acc: 0.6879 - val_loss: 2.4971 - val_acc: 0.3808 - val_top5-acc: 0.6926
Epoch 15/50
352/352 [==============================] - 8s 23ms/step - loss: 2.4073 - acc: 0.3790 - top5-acc: 0.6940 - val_loss: 2.5972 - val_acc: 0.3682 - val_top5-acc: 0.6750
Epoch 16/50
352/352 [==============================] - 8s 23ms/step - loss: 2.3922 - acc: 0.3754 - top5-acc: 0.6980 - val_loss: 2.4317 - val_acc: 0.3964 - val_top5-acc: 0.6992
Epoch 17/50
352/352 [==============================] - 8s 22ms/step - loss: 2.3603 - acc: 0.3891 - top5-acc: 0.7038 - val_loss: 2.4844 - val_acc: 0.3766 - val_top5-acc: 0.6964
Epoch 18/50
352/352 [==============================] - 8s 23ms/step - loss: 2.3560 - acc: 0.3849 - top5-acc: 0.7056 - val_loss: 2.4564 - val_acc: 0.3910 - val_top5-acc: 0.6990
Epoch 19/50
352/352 [==============================] - 8s 23ms/step - loss: 2.3367 - acc: 0.3900 - top5-acc: 0.7069 - val_loss: 2.4282 - val_acc: 0.3906 - val_top5-acc: 0.7058
Epoch 20/50
352/352 [==============================] - 8s 22ms/step - loss: 2.3096 - acc: 0.3945 - top5-acc: 0.7180 - val_loss: 2.4297 - val_acc: 0.3930 - val_top5-acc: 0.7082
Epoch 21/50
352/352 [==============================] - 8s 22ms/step - loss: 2.2935 - acc: 0.3996 - top5-acc: 0.7211 - val_loss: 2.4053 - val_acc: 0.3974 - val_top5-acc: 0.7076
Epoch 22/50
352/352 [==============================] - 8s 22ms/step - loss: 2.2823 - acc: 0.3991 - top5-acc: 0.7248 - val_loss: 2.4756 - val_acc: 0.3920 - val_top5-acc: 0.6988
Epoch 23/50
352/352 [==============================] - 8s 22ms/step - loss: 2.2371 - acc: 0.4126 - top5-acc: 0.7294 - val_loss: 2.3802 - val_acc: 0.3972 - val_top5-acc: 0.7100
Epoch 24/50
352/352 [==============================] - 8s 23ms/step - loss: 2.2234 - acc: 0.4140 - top5-acc: 0.7336 - val_loss: 2.4402 - val_acc: 0.3994 - val_top5-acc: 0.7096
Epoch 25/50
352/352 [==============================] - 8s 23ms/step - loss: 2.2320 - acc: 0.4088 - top5-acc: 0.7333 - val_loss: 2.4343 - val_acc: 0.3936 - val_top5-acc: 0.7052
Epoch 26/50
352/352 [==============================] - 8s 22ms/step - loss: 2.2094 - acc: 0.4193 - top5-acc: 0.7347 - val_loss: 2.4154 - val_acc: 0.4058 - val_top5-acc: 0.7192
Epoch 27/50
352/352 [==============================] - 8s 23ms/step - loss: 2.2029 - acc: 0.4180 - top5-acc: 0.7370 - val_loss: 2.3116 - val_acc: 0.4226 - val_top5-acc: 0.7268
Epoch 28/50
352/352 [==============================] - 8s 23ms/step - loss: 2.1959 - acc: 0.4234 - top5-acc: 0.7380 - val_loss: 2.4053 - val_acc: 0.4064 - val_top5-acc: 0.7168
Epoch 29/50
352/352 [==============================] - 8s 23ms/step - loss: 2.1815 - acc: 0.4227 - top5-acc: 0.7415 - val_loss: 2.4020 - val_acc: 0.4078 - val_top5-acc: 0.7192
Epoch 30/50
352/352 [==============================] - 8s 23ms/step - loss: 2.1783 - acc: 0.4245 - top5-acc: 0.7407 - val_loss: 2.4206 - val_acc: 0.3996 - val_top5-acc: 0.7234
Epoch 31/50
352/352 [==============================] - 8s 22ms/step - loss: 2.1686 - acc: 0.4248 - top5-acc: 0.7442 - val_loss: 2.3743 - val_acc: 0.4100 - val_top5-acc: 0.7162
Epoch 32/50
352/352 [==============================] - 8s 23ms/step - loss: 2.1487 - acc: 0.4317 - top5-acc: 0.7472 - val_loss: 2.3882 - val_acc: 0.4018 - val_top5-acc: 0.7266
Epoch 33/50
352/352 [==============================] - 8s 22ms/step - loss: 1.9836 - acc: 0.4644 - top5-acc: 0.7782 - val_loss: 2.1742 - val_acc: 0.4536 - val_top5-acc: 0.7506
Epoch 34/50
352/352 [==============================] - 8s 23ms/step - loss: 1.8723 - acc: 0.4950 - top5-acc: 0.7985 - val_loss: 2.1716 - val_acc: 0.4506 - val_top5-acc: 0.7546
Epoch 35/50
352/352 [==============================] - 8s 23ms/step - loss: 1.8461 - acc: 0.5009 - top5-acc: 0.8003 - val_loss: 2.1661 - val_acc: 0.4480 - val_top5-acc: 0.7542
Epoch 36/50
352/352 [==============================] - 8s 23ms/step - loss: 1.8499 - acc: 0.4944 - top5-acc: 0.8044 - val_loss: 2.1523 - val_acc: 0.4566 - val_top5-acc: 0.7628
Epoch 37/50
352/352 [==============================] - 8s 22ms/step - loss: 1.8322 - acc: 0.5000 - top5-acc: 0.8059 - val_loss: 2.1334 - val_acc: 0.4570 - val_top5-acc: 0.7560
Epoch 38/50
352/352 [==============================] - 8s 23ms/step - loss: 1.8269 - acc: 0.5027 - top5-acc: 0.8085 - val_loss: 2.1024 - val_acc: 0.4614 - val_top5-acc: 0.7674
Epoch 39/50
352/352 [==============================] - 8s 23ms/step - loss: 1.8242 - acc: 0.4990 - top5-acc: 0.8098 - val_loss: 2.0789 - val_acc: 0.4610 - val_top5-acc: 0.7792
Epoch 40/50
352/352 [==============================] - 8s 23ms/step - loss: 1.7983 - acc: 0.5067 - top5-acc: 0.8122 - val_loss: 2.1514 - val_acc: 0.4546 - val_top5-acc: 0.7628
Epoch 41/50
352/352 [==============================] - 8s 23ms/step - loss: 1.7974 - acc: 0.5112 - top5-acc: 0.8132 - val_loss: 2.1425 - val_acc: 0.4542 - val_top5-acc: 0.7630
Epoch 42/50
352/352 [==============================] - 8s 23ms/step - loss: 1.7972 - acc: 0.5128 - top5-acc: 0.8127 - val_loss: 2.0980 - val_acc: 0.4580 - val_top5-acc: 0.7724
Epoch 43/50
352/352 [==============================] - 8s 23ms/step - loss: 1.8026 - acc: 0.5066 - top5-acc: 0.8115 - val_loss: 2.0922 - val_acc: 0.4684 - val_top5-acc: 0.7678
Epoch 44/50
352/352 [==============================] - 8s 23ms/step - loss: 1.7924 - acc: 0.5092 - top5-acc: 0.8129 - val_loss: 2.0511 - val_acc: 0.4750 - val_top5-acc: 0.7726
Epoch 45/50
352/352 [==============================] - 8s 22ms/step - loss: 1.7695 - acc: 0.5106 - top5-acc: 0.8193 - val_loss: 2.0949 - val_acc: 0.4678 - val_top5-acc: 0.7708
Epoch 46/50
352/352 [==============================] - 8s 23ms/step - loss: 1.7784 - acc: 0.5106 - top5-acc: 0.8141 - val_loss: 2.1094 - val_acc: 0.4656 - val_top5-acc: 0.7704
Epoch 47/50
352/352 [==============================] - 8s 23ms/step - loss: 1.7625 - acc: 0.5155 - top5-acc: 0.8190 - val_loss: 2.0492 - val_acc: 0.4774 - val_top5-acc: 0.7744
Epoch 48/50
352/352 [==============================] - 8s 23ms/step - loss: 1.7441 - acc: 0.5217 - top5-acc: 0.8190 - val_loss: 2.0562 - val_acc: 0.4698 - val_top5-acc: 0.7828
Epoch 49/50
352/352 [==============================] - 8s 23ms/step - loss: 1.7665 - acc: 0.5113 - top5-acc: 0.8196 - val_loss: 2.0348 - val_acc: 0.4708 - val_top5-acc: 0.7730
Epoch 50/50
352/352 [==============================] - 8s 23ms/step - loss: 1.7392 - acc: 0.5201 - top5-acc: 0.8226 - val_loss: 2.0787 - val_acc: 0.4710 - val_top5-acc: 0.7734
313/313 [==============================] - 2s 8ms/step - loss: 2.0571 - acc: 0.4758 - top5-acc: 0.7718
Test accuracy: 47.58%
Test top 5 accuracy: 77.18%

(訳者注: 実験結果)

Epoch 1/50
352/352 [==============================] - 16s 30ms/step - loss: 3.8697 - acc: 0.1100 - top5-acc: 0.3170 - val_loss: 3.4557 - val_acc: 0.1790 - val_top5-acc: 0.4558 - lr: 0.0050
Epoch 2/50
352/352 [==============================] - 10s 28ms/step - loss: 3.3958 - acc: 0.1858 - top5-acc: 0.4552 - val_loss: 3.2891 - val_acc: 0.2098 - val_top5-acc: 0.4870 - lr: 0.0050
Epoch 3/50
352/352 [==============================] - 10s 28ms/step - loss: 3.1991 - acc: 0.2219 - top5-acc: 0.5061 - val_loss: 3.1089 - val_acc: 0.2374 - val_top5-acc: 0.5406 - lr: 0.0050
Epoch 4/50
352/352 [==============================] - 10s 28ms/step - loss: 3.0485 - acc: 0.2494 - top5-acc: 0.5481 - val_loss: 2.9650 - val_acc: 0.2732 - val_top5-acc: 0.5770 - lr: 0.0050
Epoch 5/50
352/352 [==============================] - 10s 28ms/step - loss: 2.9354 - acc: 0.2733 - top5-acc: 0.5764 - val_loss: 2.8154 - val_acc: 0.2944 - val_top5-acc: 0.6134 - lr: 0.0050
Epoch 6/50
352/352 [==============================] - 10s 29ms/step - loss: 2.8340 - acc: 0.2908 - top5-acc: 0.5999 - val_loss: 2.9185 - val_acc: 0.2928 - val_top5-acc: 0.6002 - lr: 0.0050
Epoch 7/50
352/352 [==============================] - 10s 28ms/step - loss: 2.7587 - acc: 0.3058 - top5-acc: 0.6162 - val_loss: 2.7058 - val_acc: 0.3280 - val_top5-acc: 0.6354 - lr: 0.0050
Epoch 8/50
352/352 [==============================] - 10s 28ms/step - loss: 2.7089 - acc: 0.3141 - top5-acc: 0.6285 - val_loss: 2.6579 - val_acc: 0.3394 - val_top5-acc: 0.6520 - lr: 0.0050
Epoch 9/50
352/352 [==============================] - 10s 28ms/step - loss: 2.6591 - acc: 0.3260 - top5-acc: 0.6398 - val_loss: 2.5549 - val_acc: 0.3518 - val_top5-acc: 0.6688 - lr: 0.0050
Epoch 10/50
352/352 [==============================] - 10s 28ms/step - loss: 2.6103 - acc: 0.3338 - top5-acc: 0.6496 - val_loss: 2.5618 - val_acc: 0.3578 - val_top5-acc: 0.6726 - lr: 0.0050
Epoch 11/50
352/352 [==============================] - 10s 28ms/step - loss: 2.5748 - acc: 0.3424 - top5-acc: 0.6580 - val_loss: 2.5788 - val_acc: 0.3588 - val_top5-acc: 0.6784 - lr: 0.0050
Epoch 12/50
352/352 [==============================] - 10s 29ms/step - loss: 2.5514 - acc: 0.3459 - top5-acc: 0.6642 - val_loss: 2.5447 - val_acc: 0.3616 - val_top5-acc: 0.6782 - lr: 0.0050
Epoch 13/50
352/352 [==============================] - 10s 29ms/step - loss: 2.5192 - acc: 0.3520 - top5-acc: 0.6708 - val_loss: 2.5563 - val_acc: 0.3616 - val_top5-acc: 0.6804 - lr: 0.0050
Epoch 14/50
352/352 [==============================] - 10s 29ms/step - loss: 2.4915 - acc: 0.3598 - top5-acc: 0.6774 - val_loss: 2.6326 - val_acc: 0.3536 - val_top5-acc: 0.6648 - lr: 0.0050
Epoch 15/50
352/352 [==============================] - 10s 29ms/step - loss: 2.4594 - acc: 0.3673 - top5-acc: 0.6853 - val_loss: 2.5152 - val_acc: 0.3662 - val_top5-acc: 0.6824 - lr: 0.0050
Epoch 16/50
352/352 [==============================] - 10s 29ms/step - loss: 2.4265 - acc: 0.3724 - top5-acc: 0.6920 - val_loss: 2.5100 - val_acc: 0.3744 - val_top5-acc: 0.6898 - lr: 0.0050
Epoch 17/50
352/352 [==============================] - 10s 29ms/step - loss: 2.4037 - acc: 0.3766 - top5-acc: 0.6960 - val_loss: 2.4445 - val_acc: 0.3796 - val_top5-acc: 0.7012 - lr: 0.0050
Epoch 18/50
352/352 [==============================] - 10s 29ms/step - loss: 2.3910 - acc: 0.3811 - top5-acc: 0.6981 - val_loss: 2.4771 - val_acc: 0.3872 - val_top5-acc: 0.6934 - lr: 0.0050
Epoch 19/50
352/352 [==============================] - 10s 29ms/step - loss: 2.3566 - acc: 0.3860 - top5-acc: 0.7044 - val_loss: 2.4498 - val_acc: 0.3874 - val_top5-acc: 0.7038 - lr: 0.0050
Epoch 20/50
352/352 [==============================] - 10s 29ms/step - loss: 2.3380 - acc: 0.3894 - top5-acc: 0.7100 - val_loss: 2.4931 - val_acc: 0.3750 - val_top5-acc: 0.6974 - lr: 0.0050
Epoch 21/50
352/352 [==============================] - 10s 29ms/step - loss: 2.3364 - acc: 0.3916 - top5-acc: 0.7092 - val_loss: 2.4935 - val_acc: 0.3856 - val_top5-acc: 0.6998 - lr: 0.0050
Epoch 22/50
352/352 [==============================] - 10s 29ms/step - loss: 2.3147 - acc: 0.3958 - top5-acc: 0.7153 - val_loss: 2.3627 - val_acc: 0.4064 - val_top5-acc: 0.7164 - lr: 0.0050
Epoch 23/50
352/352 [==============================] - 10s 29ms/step - loss: 2.2925 - acc: 0.3992 - top5-acc: 0.7216 - val_loss: 2.4123 - val_acc: 0.3942 - val_top5-acc: 0.7156 - lr: 0.0050
Epoch 24/50
352/352 [==============================] - 10s 29ms/step - loss: 2.2822 - acc: 0.4033 - top5-acc: 0.7220 - val_loss: 2.3777 - val_acc: 0.4038 - val_top5-acc: 0.7148 - lr: 0.0050
Epoch 25/50
352/352 [==============================] - 10s 29ms/step - loss: 2.2657 - acc: 0.4042 - top5-acc: 0.7254 - val_loss: 2.3927 - val_acc: 0.4036 - val_top5-acc: 0.7160 - lr: 0.0050
Epoch 26/50
352/352 [==============================] - 10s 29ms/step - loss: 2.2545 - acc: 0.4058 - top5-acc: 0.7267 - val_loss: 2.3147 - val_acc: 0.4158 - val_top5-acc: 0.7296 - lr: 0.0050
Epoch 27/50
352/352 [==============================] - 10s 29ms/step - loss: 2.2393 - acc: 0.4098 - top5-acc: 0.7288 - val_loss: 2.4425 - val_acc: 0.4008 - val_top5-acc: 0.7088 - lr: 0.0050
Epoch 28/50
352/352 [==============================] - 10s 29ms/step - loss: 2.2324 - acc: 0.4099 - top5-acc: 0.7314 - val_loss: 2.4051 - val_acc: 0.3982 - val_top5-acc: 0.7092 - lr: 0.0050
Epoch 29/50
352/352 [==============================] - 10s 29ms/step - loss: 2.2190 - acc: 0.4152 - top5-acc: 0.7347 - val_loss: 2.4149 - val_acc: 0.4068 - val_top5-acc: 0.7178 - lr: 0.0050
Epoch 30/50
352/352 [==============================] - 10s 28ms/step - loss: 2.2246 - acc: 0.4142 - top5-acc: 0.7332 - val_loss: 2.2987 - val_acc: 0.4204 - val_top5-acc: 0.7294 - lr: 0.0050
Epoch 31/50
352/352 [==============================] - 10s 28ms/step - loss: 2.2001 - acc: 0.4194 - top5-acc: 0.7410 - val_loss: 2.3440 - val_acc: 0.4222 - val_top5-acc: 0.7320 - lr: 0.0050
Epoch 32/50
352/352 [==============================] - 10s 28ms/step - loss: 2.1882 - acc: 0.4216 - top5-acc: 0.7411 - val_loss: 2.3560 - val_acc: 0.4172 - val_top5-acc: 0.7244 - lr: 0.0050
Epoch 33/50
352/352 [==============================] - 10s 28ms/step - loss: 2.1878 - acc: 0.4191 - top5-acc: 0.7403 - val_loss: 2.2540 - val_acc: 0.4294 - val_top5-acc: 0.7408 - lr: 0.0050
Epoch 34/50
352/352 [==============================] - 10s 28ms/step - loss: 2.1727 - acc: 0.4242 - top5-acc: 0.7442 - val_loss: 2.3553 - val_acc: 0.4074 - val_top5-acc: 0.7244 - lr: 0.0050
Epoch 35/50
352/352 [==============================] - 10s 28ms/step - loss: 2.1653 - acc: 0.4285 - top5-acc: 0.7452 - val_loss: 2.3520 - val_acc: 0.4132 - val_top5-acc: 0.7292 - lr: 0.0050
Epoch 36/50
352/352 [==============================] - 10s 28ms/step - loss: 2.1532 - acc: 0.4292 - top5-acc: 0.7484 - val_loss: 2.3275 - val_acc: 0.4166 - val_top5-acc: 0.7382 - lr: 0.0050
Epoch 37/50
352/352 [==============================] - 10s 29ms/step - loss: 2.1571 - acc: 0.4286 - top5-acc: 0.7473 - val_loss: 2.2717 - val_acc: 0.4332 - val_top5-acc: 0.7410 - lr: 0.0050
Epoch 38/50
352/352 [==============================] - 10s 28ms/step - loss: 2.1463 - acc: 0.4312 - top5-acc: 0.7505 - val_loss: 2.4053 - val_acc: 0.4074 - val_top5-acc: 0.7196 - lr: 0.0050
Epoch 39/50
352/352 [==============================] - 10s 29ms/step - loss: 1.9158 - acc: 0.4834 - top5-acc: 0.7912 - val_loss: 2.0828 - val_acc: 0.4694 - val_top5-acc: 0.7694 - lr: 0.0025
Epoch 40/50
352/352 [==============================] - 10s 29ms/step - loss: 1.8799 - acc: 0.4912 - top5-acc: 0.7959 - val_loss: 2.0889 - val_acc: 0.4712 - val_top5-acc: 0.7736 - lr: 0.0025
Epoch 41/50
352/352 [==============================] - 10s 28ms/step - loss: 1.8554 - acc: 0.4949 - top5-acc: 0.8021 - val_loss: 2.0817 - val_acc: 0.4624 - val_top5-acc: 0.7688 - lr: 0.0025
Epoch 42/50
352/352 [==============================] - 10s 28ms/step - loss: 1.8454 - acc: 0.4981 - top5-acc: 0.8037 - val_loss: 2.0485 - val_acc: 0.4694 - val_top5-acc: 0.7748 - lr: 0.0025
Epoch 43/50
352/352 [==============================] - 10s 28ms/step - loss: 1.8385 - acc: 0.4971 - top5-acc: 0.8052 - val_loss: 2.0994 - val_acc: 0.4650 - val_top5-acc: 0.7654 - lr: 0.0025
Epoch 44/50
352/352 [==============================] - 10s 28ms/step - loss: 1.8358 - acc: 0.4994 - top5-acc: 0.8036 - val_loss: 2.0693 - val_acc: 0.4678 - val_top5-acc: 0.7826 - lr: 0.0025
Epoch 45/50
352/352 [==============================] - 10s 28ms/step - loss: 1.8179 - acc: 0.5031 - top5-acc: 0.8076 - val_loss: 2.0791 - val_acc: 0.4644 - val_top5-acc: 0.7726 - lr: 0.0025
Epoch 46/50
352/352 [==============================] - 10s 28ms/step - loss: 1.8228 - acc: 0.5016 - top5-acc: 0.8093 - val_loss: 2.0768 - val_acc: 0.4758 - val_top5-acc: 0.7756 - lr: 0.0025
Epoch 47/50
352/352 [==============================] - 10s 28ms/step - loss: 1.8137 - acc: 0.5042 - top5-acc: 0.8071 - val_loss: 2.0507 - val_acc: 0.4710 - val_top5-acc: 0.7770 - lr: 0.0025
Epoch 48/50
352/352 [==============================] - 10s 28ms/step - loss: 1.6814 - acc: 0.5348 - top5-acc: 0.8324 - val_loss: 2.0001 - val_acc: 0.4824 - val_top5-acc: 0.7834 - lr: 0.0012
Epoch 49/50
352/352 [==============================] - 10s 28ms/step - loss: 1.6568 - acc: 0.5461 - top5-acc: 0.8350 - val_loss: 1.9448 - val_acc: 0.4884 - val_top5-acc: 0.7954 - lr: 0.0012
Epoch 50/50
352/352 [==============================] - 10s 28ms/step - loss: 1.6484 - acc: 0.5426 - top5-acc: 0.8374 - val_loss: 1.9811 - val_acc: 0.4918 - val_top5-acc: 0.7906 - lr: 0.0012
313/313 [==============================] - 4s 12ms/step - loss: 1.9585 - acc: 0.4976 - top5-acc: 0.7859
Test accuracy: 49.76%
Test top 5 accuracy: 78.59%
CPU times: user 10min 17s, sys: 49 s, total: 11min 6s
Wall time: 8min 32s

MLP-Mixer モデルは畳み込みと transformer ベースモデルに比べて少ない数のパラメータを持つ傾向にあり、これは少ない訓練と計算コストでサーブすることに繋がります。

MLP-Mixer 論文で述べられているように、大規模データセット上で事前訓練されたときや、現代的な正則化スキームを使用するとき、MLP-Mixer は最先端モデルに匹敵するスコアを達成します。埋め込み次元を増やし、mixer ブロックの数を増やし、そしてモデルをより長く訓練することでより良い結果を得られます。入力画像のサイズを大きくして異なるパッチサイズを使用することを試しても良いでしょう。

 

FNet モデル

FNet は Transformer ブロックに類似のブロックを使用します。けれども、FNet は Transformer ブロックの自己注意層をパラメータフリーな 2D フーリエ変換層と置き換えます :

  1. 一つの 1D フーリエ変換はパッチに沿って適用されます。
  2. 一つの 1D フーリエ変換はチャネルに沿って適用されます。

 

FNet モジュールの実装

class FNetLayer(layers.Layer):
    def __init__(self, num_patches, embedding_dim, dropout_rate, *args, **kwargs):
        super(FNetLayer, self).__init__(*args, **kwargs)

        self.ffn = keras.Sequential(
            [
                layers.Dense(units=embedding_dim),
                tfa.layers.GELU(),
                layers.Dropout(rate=dropout_rate),
                layers.Dense(units=embedding_dim),
            ]
        )

        self.normalize1 = layers.LayerNormalization(epsilon=1e-6)
        self.normalize2 = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        # Apply fourier transformations.
        x = tf.cast(
            tf.signal.fft2d(tf.cast(inputs, dtype=tf.dtypes.complex64)),
            dtype=tf.dtypes.float32,
        )
        # Add skip connection.
        x = x + inputs
        # Apply layer normalization.
        x = self.normalize1(x)
        # Apply Feedfowrad network.
        x_ffn = self.ffn(x)
        # Add skip connection.
        x = x + x_ffn
        # Apply layer normalization.
        return self.normalize2(x)

 

FNet モデルの構築、訓練と評価

現在の設定でのモデルの訓練は V100 GPU 上でエポック毎におよそ 8 秒かかることに注意してください。

fnet_blocks = keras.Sequential(
    [FNetLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.001
fnet_classifier = build_classifier(fnet_blocks, positional_encoding=True)
history = run_experiment(fnet_classifier)
Epoch 1/50
352/352 [==============================] - 11s 23ms/step - loss: 4.3419 - acc: 0.0470 - top5-acc: 0.1652 - val_loss: 3.8279 - val_acc: 0.1178 - val_top5-acc: 0.3268
Epoch 2/50
352/352 [==============================] - 8s 22ms/step - loss: 3.7814 - acc: 0.1202 - top5-acc: 0.3341 - val_loss: 3.5981 - val_acc: 0.1540 - val_top5-acc: 0.3914
Epoch 3/50
352/352 [==============================] - 8s 22ms/step - loss: 3.5319 - acc: 0.1603 - top5-acc: 0.4086 - val_loss: 3.3309 - val_acc: 0.1956 - val_top5-acc: 0.4656
Epoch 4/50
352/352 [==============================] - 8s 22ms/step - loss: 3.3025 - acc: 0.2001 - top5-acc: 0.4730 - val_loss: 3.1215 - val_acc: 0.2334 - val_top5-acc: 0.5234
Epoch 5/50
352/352 [==============================] - 8s 22ms/step - loss: 3.1621 - acc: 0.2224 - top5-acc: 0.5084 - val_loss: 3.0492 - val_acc: 0.2456 - val_top5-acc: 0.5322
Epoch 6/50
352/352 [==============================] - 8s 22ms/step - loss: 3.0506 - acc: 0.2469 - top5-acc: 0.5400 - val_loss: 2.9519 - val_acc: 0.2684 - val_top5-acc: 0.5652
Epoch 7/50
352/352 [==============================] - 8s 22ms/step - loss: 2.9520 - acc: 0.2618 - top5-acc: 0.5677 - val_loss: 2.8936 - val_acc: 0.2688 - val_top5-acc: 0.5864
Epoch 8/50
352/352 [==============================] - 8s 22ms/step - loss: 2.8377 - acc: 0.2828 - top5-acc: 0.5938 - val_loss: 2.7633 - val_acc: 0.2996 - val_top5-acc: 0.6068
Epoch 9/50
352/352 [==============================] - 8s 22ms/step - loss: 2.7670 - acc: 0.2969 - top5-acc: 0.6107 - val_loss: 2.7309 - val_acc: 0.3112 - val_top5-acc: 0.6136
Epoch 10/50
352/352 [==============================] - 8s 22ms/step - loss: 2.7027 - acc: 0.3148 - top5-acc: 0.6231 - val_loss: 2.6552 - val_acc: 0.3214 - val_top5-acc: 0.6436
Epoch 11/50
352/352 [==============================] - 8s 22ms/step - loss: 2.6375 - acc: 0.3256 - top5-acc: 0.6427 - val_loss: 2.6078 - val_acc: 0.3278 - val_top5-acc: 0.6434
Epoch 12/50
352/352 [==============================] - 8s 22ms/step - loss: 2.5573 - acc: 0.3424 - top5-acc: 0.6576 - val_loss: 2.5617 - val_acc: 0.3438 - val_top5-acc: 0.6534
Epoch 13/50
352/352 [==============================] - 8s 22ms/step - loss: 2.5259 - acc: 0.3488 - top5-acc: 0.6640 - val_loss: 2.5177 - val_acc: 0.3550 - val_top5-acc: 0.6652
Epoch 14/50
352/352 [==============================] - 8s 22ms/step - loss: 2.4782 - acc: 0.3586 - top5-acc: 0.6739 - val_loss: 2.5113 - val_acc: 0.3558 - val_top5-acc: 0.6718
Epoch 15/50
352/352 [==============================] - 8s 22ms/step - loss: 2.4242 - acc: 0.3712 - top5-acc: 0.6897 - val_loss: 2.4280 - val_acc: 0.3724 - val_top5-acc: 0.6880
Epoch 16/50
352/352 [==============================] - 8s 22ms/step - loss: 2.3884 - acc: 0.3741 - top5-acc: 0.6967 - val_loss: 2.4670 - val_acc: 0.3654 - val_top5-acc: 0.6794
Epoch 17/50
352/352 [==============================] - 8s 22ms/step - loss: 2.3619 - acc: 0.3797 - top5-acc: 0.7001 - val_loss: 2.3941 - val_acc: 0.3752 - val_top5-acc: 0.6922
Epoch 18/50
352/352 [==============================] - 8s 22ms/step - loss: 2.3183 - acc: 0.3931 - top5-acc: 0.7137 - val_loss: 2.4028 - val_acc: 0.3814 - val_top5-acc: 0.6954
Epoch 19/50
352/352 [==============================] - 8s 22ms/step - loss: 2.2919 - acc: 0.3955 - top5-acc: 0.7209 - val_loss: 2.3672 - val_acc: 0.3878 - val_top5-acc: 0.7022
Epoch 20/50
352/352 [==============================] - 8s 22ms/step - loss: 2.2612 - acc: 0.4038 - top5-acc: 0.7224 - val_loss: 2.3529 - val_acc: 0.3954 - val_top5-acc: 0.6934
Epoch 21/50
352/352 [==============================] - 8s 22ms/step - loss: 2.2416 - acc: 0.4068 - top5-acc: 0.7262 - val_loss: 2.3014 - val_acc: 0.3980 - val_top5-acc: 0.7158
Epoch 22/50
352/352 [==============================] - 8s 22ms/step - loss: 2.2087 - acc: 0.4162 - top5-acc: 0.7359 - val_loss: 2.2904 - val_acc: 0.4062 - val_top5-acc: 0.7120
Epoch 23/50
352/352 [==============================] - 8s 22ms/step - loss: 2.1803 - acc: 0.4200 - top5-acc: 0.7442 - val_loss: 2.3181 - val_acc: 0.4096 - val_top5-acc: 0.7120
Epoch 24/50
352/352 [==============================] - 8s 22ms/step - loss: 2.1718 - acc: 0.4246 - top5-acc: 0.7403 - val_loss: 2.2687 - val_acc: 0.4094 - val_top5-acc: 0.7234
Epoch 25/50
352/352 [==============================] - 8s 22ms/step - loss: 2.1559 - acc: 0.4198 - top5-acc: 0.7458 - val_loss: 2.2730 - val_acc: 0.4060 - val_top5-acc: 0.7190
Epoch 26/50
352/352 [==============================] - 8s 22ms/step - loss: 2.1285 - acc: 0.4300 - top5-acc: 0.7495 - val_loss: 2.2566 - val_acc: 0.4082 - val_top5-acc: 0.7306
Epoch 27/50
352/352 [==============================] - 8s 22ms/step - loss: 2.1118 - acc: 0.4386 - top5-acc: 0.7538 - val_loss: 2.2544 - val_acc: 0.4178 - val_top5-acc: 0.7218
Epoch 28/50
352/352 [==============================] - 8s 22ms/step - loss: 2.1007 - acc: 0.4408 - top5-acc: 0.7562 - val_loss: 2.2703 - val_acc: 0.4136 - val_top5-acc: 0.7172
Epoch 29/50
352/352 [==============================] - 8s 22ms/step - loss: 2.0707 - acc: 0.4446 - top5-acc: 0.7634 - val_loss: 2.2244 - val_acc: 0.4168 - val_top5-acc: 0.7332
Epoch 30/50
352/352 [==============================] - 8s 22ms/step - loss: 2.0694 - acc: 0.4428 - top5-acc: 0.7611 - val_loss: 2.2557 - val_acc: 0.4060 - val_top5-acc: 0.7270
Epoch 31/50
352/352 [==============================] - 8s 22ms/step - loss: 2.0485 - acc: 0.4502 - top5-acc: 0.7672 - val_loss: 2.2192 - val_acc: 0.4214 - val_top5-acc: 0.7308
Epoch 32/50
352/352 [==============================] - 8s 22ms/step - loss: 2.0105 - acc: 0.4617 - top5-acc: 0.7718 - val_loss: 2.2065 - val_acc: 0.4222 - val_top5-acc: 0.7286
Epoch 33/50
352/352 [==============================] - 8s 22ms/step - loss: 2.0238 - acc: 0.4556 - top5-acc: 0.7734 - val_loss: 2.1736 - val_acc: 0.4270 - val_top5-acc: 0.7368
Epoch 34/50
352/352 [==============================] - 8s 22ms/step - loss: 2.0253 - acc: 0.4547 - top5-acc: 0.7712 - val_loss: 2.2231 - val_acc: 0.4280 - val_top5-acc: 0.7308
Epoch 35/50
352/352 [==============================] - 8s 22ms/step - loss: 1.9992 - acc: 0.4593 - top5-acc: 0.7765 - val_loss: 2.1994 - val_acc: 0.4212 - val_top5-acc: 0.7358
Epoch 36/50
352/352 [==============================] - 8s 22ms/step - loss: 1.9849 - acc: 0.4636 - top5-acc: 0.7754 - val_loss: 2.2167 - val_acc: 0.4276 - val_top5-acc: 0.7308
Epoch 37/50
352/352 [==============================] - 8s 22ms/step - loss: 1.9880 - acc: 0.4677 - top5-acc: 0.7783 - val_loss: 2.1746 - val_acc: 0.4270 - val_top5-acc: 0.7416
Epoch 38/50
352/352 [==============================] - 8s 22ms/step - loss: 1.9562 - acc: 0.4720 - top5-acc: 0.7845 - val_loss: 2.1976 - val_acc: 0.4312 - val_top5-acc: 0.7356
Epoch 39/50
352/352 [==============================] - 8s 22ms/step - loss: 1.8736 - acc: 0.4924 - top5-acc: 0.8004 - val_loss: 2.0755 - val_acc: 0.4578 - val_top5-acc: 0.7586
Epoch 40/50
352/352 [==============================] - 8s 22ms/step - loss: 1.8189 - acc: 0.5042 - top5-acc: 0.8076 - val_loss: 2.0804 - val_acc: 0.4508 - val_top5-acc: 0.7600
Epoch 41/50
352/352 [==============================] - 8s 22ms/step - loss: 1.8069 - acc: 0.5062 - top5-acc: 0.8132 - val_loss: 2.0784 - val_acc: 0.4456 - val_top5-acc: 0.7578
Epoch 42/50
352/352 [==============================] - 8s 22ms/step - loss: 1.8156 - acc: 0.5052 - top5-acc: 0.8110 - val_loss: 2.0910 - val_acc: 0.4544 - val_top5-acc: 0.7542
Epoch 43/50
352/352 [==============================] - 8s 22ms/step - loss: 1.8143 - acc: 0.5046 - top5-acc: 0.8105 - val_loss: 2.1037 - val_acc: 0.4466 - val_top5-acc: 0.7562
Epoch 44/50
352/352 [==============================] - 8s 22ms/step - loss: 1.8119 - acc: 0.5032 - top5-acc: 0.8141 - val_loss: 2.0794 - val_acc: 0.4622 - val_top5-acc: 0.7532
Epoch 45/50
352/352 [==============================] - 8s 22ms/step - loss: 1.7611 - acc: 0.5188 - top5-acc: 0.8224 - val_loss: 2.0371 - val_acc: 0.4650 - val_top5-acc: 0.7628
Epoch 46/50
352/352 [==============================] - 8s 22ms/step - loss: 1.7713 - acc: 0.5189 - top5-acc: 0.8226 - val_loss: 2.0245 - val_acc: 0.4630 - val_top5-acc: 0.7644
Epoch 47/50
352/352 [==============================] - 8s 22ms/step - loss: 1.7809 - acc: 0.5130 - top5-acc: 0.8215 - val_loss: 2.0471 - val_acc: 0.4618 - val_top5-acc: 0.7618
Epoch 48/50
352/352 [==============================] - 8s 22ms/step - loss: 1.8052 - acc: 0.5112 - top5-acc: 0.8165 - val_loss: 2.0441 - val_acc: 0.4596 - val_top5-acc: 0.7658
Epoch 49/50
352/352 [==============================] - 8s 22ms/step - loss: 1.8128 - acc: 0.5039 - top5-acc: 0.8178 - val_loss: 2.0569 - val_acc: 0.4600 - val_top5-acc: 0.7614
Epoch 50/50
352/352 [==============================] - 8s 22ms/step - loss: 1.8179 - acc: 0.5089 - top5-acc: 0.8155 - val_loss: 2.0514 - val_acc: 0.4576 - val_top5-acc: 0.7566
313/313 [==============================] - 2s 6ms/step - loss: 2.0142 - acc: 0.4663 - top5-acc: 0.7647
Test accuracy: 46.63%
Test top 5 accuracy: 76.47%
Epoch 1/50
352/352 [==============================] - 14s 30ms/step - loss: 4.1290 - acc: 0.0735 - top5-acc: 0.2303 - val_loss: 3.7880 - val_acc: 0.1228 - val_top5-acc: 0.3322 - lr: 0.0010
Epoch 2/50
352/352 [==============================] - 10s 28ms/step - loss: 3.7066 - acc: 0.1335 - top5-acc: 0.3566 - val_loss: 3.4879 - val_acc: 0.1730 - val_top5-acc: 0.4260 - lr: 0.0010
Epoch 3/50
352/352 [==============================] - 10s 28ms/step - loss: 3.4506 - acc: 0.1723 - top5-acc: 0.4332 - val_loss: 3.2323 - val_acc: 0.2190 - val_top5-acc: 0.4934 - lr: 0.0010
Epoch 4/50
352/352 [==============================] - 10s 28ms/step - loss: 3.2458 - acc: 0.2079 - top5-acc: 0.4879 - val_loss: 3.1450 - val_acc: 0.2264 - val_top5-acc: 0.5162 - lr: 0.0010
Epoch 5/50
352/352 [==============================] - 10s 28ms/step - loss: 3.0989 - acc: 0.2329 - top5-acc: 0.5284 - val_loss: 3.0079 - val_acc: 0.2536 - val_top5-acc: 0.5512 - lr: 0.0010
Epoch 6/50
352/352 [==============================] - 10s 28ms/step - loss: 2.9906 - acc: 0.2529 - top5-acc: 0.5574 - val_loss: 2.9370 - val_acc: 0.2646 - val_top5-acc: 0.5720 - lr: 0.0010
Epoch 7/50
352/352 [==============================] - 10s 28ms/step - loss: 2.8943 - acc: 0.2729 - top5-acc: 0.5781 - val_loss: 2.8244 - val_acc: 0.2928 - val_top5-acc: 0.5936 - lr: 0.0010
Epoch 8/50
352/352 [==============================] - 10s 28ms/step - loss: 2.8082 - acc: 0.2910 - top5-acc: 0.6010 - val_loss: 2.7273 - val_acc: 0.3078 - val_top5-acc: 0.6210 - lr: 0.0010
Epoch 9/50
352/352 [==============================] - 10s 29ms/step - loss: 2.7475 - acc: 0.3035 - top5-acc: 0.6153 - val_loss: 2.6860 - val_acc: 0.3124 - val_top5-acc: 0.6280 - lr: 0.0010
Epoch 10/50
352/352 [==============================] - 10s 28ms/step - loss: 2.6861 - acc: 0.3144 - top5-acc: 0.6287 - val_loss: 2.6646 - val_acc: 0.3148 - val_top5-acc: 0.6342 - lr: 0.0010
Epoch 11/50
352/352 [==============================] - 10s 28ms/step - loss: 2.6354 - acc: 0.3246 - top5-acc: 0.6416 - val_loss: 2.5870 - val_acc: 0.3366 - val_top5-acc: 0.6536 - lr: 0.0010
Epoch 12/50
352/352 [==============================] - 10s 28ms/step - loss: 2.5809 - acc: 0.3361 - top5-acc: 0.6546 - val_loss: 2.5842 - val_acc: 0.3380 - val_top5-acc: 0.6520 - lr: 0.0010
Epoch 13/50
352/352 [==============================] - 10s 28ms/step - loss: 2.5361 - acc: 0.3449 - top5-acc: 0.6618 - val_loss: 2.5554 - val_acc: 0.3474 - val_top5-acc: 0.6600 - lr: 0.0010
Epoch 14/50
352/352 [==============================] - 10s 28ms/step - loss: 2.4875 - acc: 0.3563 - top5-acc: 0.6766 - val_loss: 2.4645 - val_acc: 0.3674 - val_top5-acc: 0.6846 - lr: 0.0010
Epoch 15/50
352/352 [==============================] - 10s 28ms/step - loss: 2.4554 - acc: 0.3632 - top5-acc: 0.6817 - val_loss: 2.4772 - val_acc: 0.3670 - val_top5-acc: 0.6760 - lr: 0.0010
Epoch 16/50
352/352 [==============================] - 10s 28ms/step - loss: 2.4089 - acc: 0.3732 - top5-acc: 0.6914 - val_loss: 2.4345 - val_acc: 0.3702 - val_top5-acc: 0.6892 - lr: 0.0010
Epoch 17/50
352/352 [==============================] - 10s 28ms/step - loss: 2.3810 - acc: 0.3780 - top5-acc: 0.6990 - val_loss: 2.4022 - val_acc: 0.3848 - val_top5-acc: 0.6904 - lr: 0.0010
Epoch 18/50
352/352 [==============================] - 10s 28ms/step - loss: 2.3470 - acc: 0.3843 - top5-acc: 0.7073 - val_loss: 2.3815 - val_acc: 0.3726 - val_top5-acc: 0.6964 - lr: 0.0010
Epoch 19/50
352/352 [==============================] - 10s 28ms/step - loss: 2.3267 - acc: 0.3882 - top5-acc: 0.7111 - val_loss: 2.3811 - val_acc: 0.3870 - val_top5-acc: 0.6952 - lr: 0.0010
Epoch 20/50
352/352 [==============================] - 10s 28ms/step - loss: 2.2939 - acc: 0.3962 - top5-acc: 0.7160 - val_loss: 2.3567 - val_acc: 0.3954 - val_top5-acc: 0.6976 - lr: 0.0010
Epoch 21/50
352/352 [==============================] - 10s 28ms/step - loss: 2.2768 - acc: 0.4031 - top5-acc: 0.7210 - val_loss: 2.3604 - val_acc: 0.3922 - val_top5-acc: 0.6956 - lr: 0.0010
Epoch 22/50
352/352 [==============================] - 10s 27ms/step - loss: 2.2467 - acc: 0.4042 - top5-acc: 0.7278 - val_loss: 2.3631 - val_acc: 0.3896 - val_top5-acc: 0.6892 - lr: 0.0010
Epoch 23/50
352/352 [==============================] - 10s 28ms/step - loss: 2.2360 - acc: 0.4099 - top5-acc: 0.7271 - val_loss: 2.2786 - val_acc: 0.4056 - val_top5-acc: 0.7176 - lr: 0.0010
Epoch 24/50
352/352 [==============================] - 10s 28ms/step - loss: 2.2236 - acc: 0.4127 - top5-acc: 0.7313 - val_loss: 2.3074 - val_acc: 0.3938 - val_top5-acc: 0.7136 - lr: 0.0010
Epoch 25/50
352/352 [==============================] - 10s 28ms/step - loss: 2.1944 - acc: 0.4184 - top5-acc: 0.7385 - val_loss: 2.2984 - val_acc: 0.4052 - val_top5-acc: 0.7096 - lr: 0.0010
Epoch 26/50
352/352 [==============================] - 10s 28ms/step - loss: 2.1830 - acc: 0.4223 - top5-acc: 0.7389 - val_loss: 2.2728 - val_acc: 0.4128 - val_top5-acc: 0.7178 - lr: 0.0010
Epoch 27/50
352/352 [==============================] - 10s 28ms/step - loss: 2.1633 - acc: 0.4260 - top5-acc: 0.7429 - val_loss: 2.2601 - val_acc: 0.4124 - val_top5-acc: 0.7182 - lr: 0.0010
Epoch 28/50
352/352 [==============================] - 10s 28ms/step - loss: 2.1457 - acc: 0.4271 - top5-acc: 0.7466 - val_loss: 2.2843 - val_acc: 0.4120 - val_top5-acc: 0.7136 - lr: 0.0010
Epoch 29/50
352/352 [==============================] - 10s 27ms/step - loss: 2.1145 - acc: 0.4379 - top5-acc: 0.7535 - val_loss: 2.2691 - val_acc: 0.4178 - val_top5-acc: 0.7148 - lr: 0.0010
Epoch 30/50
352/352 [==============================] - 10s 28ms/step - loss: 2.1175 - acc: 0.4338 - top5-acc: 0.7524 - val_loss: 2.2575 - val_acc: 0.4174 - val_top5-acc: 0.7174 - lr: 0.0010
Epoch 31/50
352/352 [==============================] - 10s 27ms/step - loss: 2.1032 - acc: 0.4400 - top5-acc: 0.7562 - val_loss: 2.2561 - val_acc: 0.4184 - val_top5-acc: 0.7232 - lr: 0.0010
Epoch 32/50
352/352 [==============================] - 10s 28ms/step - loss: 2.0866 - acc: 0.4436 - top5-acc: 0.7591 - val_loss: 2.2629 - val_acc: 0.4114 - val_top5-acc: 0.7234 - lr: 0.0010
Epoch 33/50
352/352 [==============================] - 10s 28ms/step - loss: 2.0717 - acc: 0.4457 - top5-acc: 0.7601 - val_loss: 2.2504 - val_acc: 0.4164 - val_top5-acc: 0.7302 - lr: 0.0010
Epoch 34/50
352/352 [==============================] - 10s 28ms/step - loss: 2.0566 - acc: 0.4497 - top5-acc: 0.7675 - val_loss: 2.2221 - val_acc: 0.4246 - val_top5-acc: 0.7282 - lr: 0.0010
Epoch 35/50
352/352 [==============================] - 10s 28ms/step - loss: 2.0396 - acc: 0.4506 - top5-acc: 0.7677 - val_loss: 2.2007 - val_acc: 0.4264 - val_top5-acc: 0.7334 - lr: 0.0010
Epoch 36/50
352/352 [==============================] - 10s 28ms/step - loss: 2.0420 - acc: 0.4508 - top5-acc: 0.7677 - val_loss: 2.2342 - val_acc: 0.4150 - val_top5-acc: 0.7306 - lr: 0.0010
Epoch 37/50
352/352 [==============================] - 10s 28ms/step - loss: 2.0221 - acc: 0.4574 - top5-acc: 0.7729 - val_loss: 2.1794 - val_acc: 0.4286 - val_top5-acc: 0.7388 - lr: 0.0010
Epoch 38/50
352/352 [==============================] - 10s 27ms/step - loss: 2.0204 - acc: 0.4576 - top5-acc: 0.7719 - val_loss: 2.1919 - val_acc: 0.4252 - val_top5-acc: 0.7302 - lr: 0.0010
Epoch 39/50
352/352 [==============================] - 10s 27ms/step - loss: 2.0141 - acc: 0.4580 - top5-acc: 0.7729 - val_loss: 2.1976 - val_acc: 0.4344 - val_top5-acc: 0.7316 - lr: 0.0010
Epoch 40/50
352/352 [==============================] - 10s 28ms/step - loss: 1.9923 - acc: 0.4621 - top5-acc: 0.7775 - val_loss: 2.1793 - val_acc: 0.4334 - val_top5-acc: 0.7434 - lr: 0.0010
Epoch 41/50
352/352 [==============================] - 10s 27ms/step - loss: 1.9835 - acc: 0.4628 - top5-acc: 0.7797 - val_loss: 2.1971 - val_acc: 0.4264 - val_top5-acc: 0.7360 - lr: 0.0010
Epoch 42/50
352/352 [==============================] - 10s 27ms/step - loss: 1.9666 - acc: 0.4679 - top5-acc: 0.7842 - val_loss: 2.1582 - val_acc: 0.4306 - val_top5-acc: 0.7420 - lr: 0.0010
Epoch 43/50
352/352 [==============================] - 10s 28ms/step - loss: 1.9683 - acc: 0.4677 - top5-acc: 0.7821 - val_loss: 2.1964 - val_acc: 0.4330 - val_top5-acc: 0.7382 - lr: 0.0010
Epoch 44/50
352/352 [==============================] - 10s 28ms/step - loss: 1.9540 - acc: 0.4691 - top5-acc: 0.7835 - val_loss: 2.1884 - val_acc: 0.4312 - val_top5-acc: 0.7374 - lr: 0.0010
Epoch 45/50
352/352 [==============================] - 10s 27ms/step - loss: 1.9559 - acc: 0.4718 - top5-acc: 0.7845 - val_loss: 2.2328 - val_acc: 0.4216 - val_top5-acc: 0.7302 - lr: 0.0010
Epoch 46/50
352/352 [==============================] - 10s 28ms/step - loss: 1.9426 - acc: 0.4734 - top5-acc: 0.7880 - val_loss: 2.1601 - val_acc: 0.4414 - val_top5-acc: 0.7376 - lr: 0.0010
Epoch 47/50
352/352 [==============================] - 10s 28ms/step - loss: 1.9308 - acc: 0.4760 - top5-acc: 0.7897 - val_loss: 2.1957 - val_acc: 0.4306 - val_top5-acc: 0.7338 - lr: 0.0010
Epoch 48/50
352/352 [==============================] - 10s 28ms/step - loss: 1.7993 - acc: 0.5082 - top5-acc: 0.8128 - val_loss: 2.0799 - val_acc: 0.4578 - val_top5-acc: 0.7566 - lr: 5.0000e-04
Epoch 49/50
352/352 [==============================] - 10s 28ms/step - loss: 1.7806 - acc: 0.5143 - top5-acc: 0.8173 - val_loss: 2.0714 - val_acc: 0.4576 - val_top5-acc: 0.7582 - lr: 5.0000e-04
Epoch 50/50
352/352 [==============================] - 10s 28ms/step - loss: 1.7798 - acc: 0.5150 - top5-acc: 0.8181 - val_loss: 2.0947 - val_acc: 0.4478 - val_top5-acc: 0.7528 - lr: 5.0000e-04
313/313 [==============================] - 3s 9ms/step - loss: 2.0490 - acc: 0.4629 - top5-acc: 0.7694
Test accuracy: 46.29%
Test top 5 accuracy: 76.94%
CPU times: user 7min 51s, sys: 1min 14s, total: 9min 5s
Wall time: 8min 29s

FNet 論文で述べられているように、埋め込み次元を増やし、FNet ブロックの数を増やし、そしてモデルをより長く訓練することでより良い結果を得られます。入力画像のサイズを大きくして異なるパッチサイズを使用することを試しても良いでしょう。FNet は長い入力に非常に効率的にスケールし、注意ベースの Transformer モデルよりも遥かに高速に動作し、そして競争力のある結果を生成します。

 

gMLP モデル

gMLP は空間ゲートユニット (SGU, Spatial Gating Unit) にフィーチャーした MLP アーキテクチャです。SGU は以下により、空間 (チャネル) 次元に渡る交差パッチの相互作用を可能にします :

  1. (チャネルに沿って) パッチに渡る線形射影を適用することにより入力を空間的に変換します。
  2. 入力とその空間変換の要素ごとの乗算を適用する。

 

gMLP モジュールの実装

class gMLPLayer(layers.Layer):
    def __init__(self, num_patches, embedding_dim, dropout_rate, *args, **kwargs):
        super(gMLPLayer, self).__init__(*args, **kwargs)

        self.channel_projection1 = keras.Sequential(
            [
                layers.Dense(units=embedding_dim * 2),
                tfa.layers.GELU(),
                layers.Dropout(rate=dropout_rate),
            ]
        )

        self.channel_projection2 = layers.Dense(units=embedding_dim)

        self.spatial_projection = layers.Dense(
            units=num_patches, bias_initializer="Ones"
        )

        self.normalize1 = layers.LayerNormalization(epsilon=1e-6)
        self.normalize2 = layers.LayerNormalization(epsilon=1e-6)

    def spatial_gating_unit(self, x):
        # Split x along the channel dimensions.
        # Tensors u and v will in th shape of [batch_size, num_patchs, embedding_dim].
        u, v = tf.split(x, num_or_size_splits=2, axis=2)
        # Apply layer normalization.
        v = self.normalize2(v)
        # Apply spatial projection.
        v_channels = tf.linalg.matrix_transpose(v)
        v_projected = self.spatial_projection(v_channels)
        v_projected = tf.linalg.matrix_transpose(v_projected)
        # Apply element-wise multiplication.
        return u * v_projected

    def call(self, inputs):
        # Apply layer normalization.
        x = self.normalize1(inputs)
        # Apply the first channel projection. x_projected shape: [batch_size, num_patches, embedding_dim * 2].
        x_projected = self.channel_projection1(x)
        # Apply the spatial gating unit. x_spatial shape: [batch_size, num_patches, embedding_dim].
        x_spatial = self.spatial_gating_unit(x_projected)
        # Apply the second channel projection. x_projected shape: [batch_size, num_patches, embedding_dim].
        x_projected = self.channel_projection2(x_spatial)
        # Add skip connection.
        return x + x_projected

 

gMLP モデルの構築、訓練と評価

現在の設定でのモデルの訓練は V100 GPU 上でエポック毎におよそ 9 秒かかることに注意してください。

gmlp_blocks = keras.Sequential(
    [gMLPLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.003
gmlp_classifier = build_classifier(gmlp_blocks)
history = run_experiment(gmlp_classifier)
Epoch 1/50
352/352 [==============================] - 13s 28ms/step - loss: 4.1713 - acc: 0.0704 - top5-acc: 0.2206 - val_loss: 3.5629 - val_acc: 0.1548 - val_top5-acc: 0.4086
Epoch 2/50
352/352 [==============================] - 9s 27ms/step - loss: 3.5146 - acc: 0.1633 - top5-acc: 0.4172 - val_loss: 3.2899 - val_acc: 0.2066 - val_top5-acc: 0.4900
Epoch 3/50
352/352 [==============================] - 9s 26ms/step - loss: 3.2588 - acc: 0.2017 - top5-acc: 0.4895 - val_loss: 3.1152 - val_acc: 0.2362 - val_top5-acc: 0.5278
Epoch 4/50
352/352 [==============================] - 9s 26ms/step - loss: 3.1037 - acc: 0.2331 - top5-acc: 0.5288 - val_loss: 2.9771 - val_acc: 0.2624 - val_top5-acc: 0.5646
Epoch 5/50
352/352 [==============================] - 9s 26ms/step - loss: 2.9483 - acc: 0.2637 - top5-acc: 0.5680 - val_loss: 2.8807 - val_acc: 0.2784 - val_top5-acc: 0.5840
Epoch 6/50
352/352 [==============================] - 9s 26ms/step - loss: 2.8411 - acc: 0.2821 - top5-acc: 0.5930 - val_loss: 2.7246 - val_acc: 0.3146 - val_top5-acc: 0.6256
Epoch 7/50
352/352 [==============================] - 9s 26ms/step - loss: 2.7221 - acc: 0.3085 - top5-acc: 0.6193 - val_loss: 2.7022 - val_acc: 0.3108 - val_top5-acc: 0.6270
Epoch 8/50
352/352 [==============================] - 9s 26ms/step - loss: 2.6296 - acc: 0.3334 - top5-acc: 0.6420 - val_loss: 2.6289 - val_acc: 0.3324 - val_top5-acc: 0.6494
Epoch 9/50
352/352 [==============================] - 9s 26ms/step - loss: 2.5691 - acc: 0.3413 - top5-acc: 0.6563 - val_loss: 2.5353 - val_acc: 0.3586 - val_top5-acc: 0.6746
Epoch 10/50
352/352 [==============================] - 9s 26ms/step - loss: 2.4854 - acc: 0.3575 - top5-acc: 0.6760 - val_loss: 2.5271 - val_acc: 0.3578 - val_top5-acc: 0.6720
Epoch 11/50
352/352 [==============================] - 9s 26ms/step - loss: 2.4252 - acc: 0.3722 - top5-acc: 0.6870 - val_loss: 2.4553 - val_acc: 0.3684 - val_top5-acc: 0.6850
Epoch 12/50
352/352 [==============================] - 9s 26ms/step - loss: 2.3814 - acc: 0.3822 - top5-acc: 0.6985 - val_loss: 2.3841 - val_acc: 0.3888 - val_top5-acc: 0.6966
Epoch 13/50
352/352 [==============================] - 9s 26ms/step - loss: 2.3119 - acc: 0.3950 - top5-acc: 0.7135 - val_loss: 2.4306 - val_acc: 0.3780 - val_top5-acc: 0.6894
Epoch 14/50
352/352 [==============================] - 9s 26ms/step - loss: 2.2886 - acc: 0.4033 - top5-acc: 0.7168 - val_loss: 2.4053 - val_acc: 0.3932 - val_top5-acc: 0.7010
Epoch 15/50
352/352 [==============================] - 9s 26ms/step - loss: 2.2455 - acc: 0.4080 - top5-acc: 0.7233 - val_loss: 2.3443 - val_acc: 0.4004 - val_top5-acc: 0.7128
Epoch 16/50
352/352 [==============================] - 9s 26ms/step - loss: 2.2128 - acc: 0.4152 - top5-acc: 0.7317 - val_loss: 2.3150 - val_acc: 0.4018 - val_top5-acc: 0.7174
Epoch 17/50
352/352 [==============================] - 9s 26ms/step - loss: 2.1990 - acc: 0.4206 - top5-acc: 0.7357 - val_loss: 2.3590 - val_acc: 0.3978 - val_top5-acc: 0.7086
Epoch 18/50
352/352 [==============================] - 9s 26ms/step - loss: 2.1574 - acc: 0.4258 - top5-acc: 0.7451 - val_loss: 2.3140 - val_acc: 0.4052 - val_top5-acc: 0.7256
Epoch 19/50
352/352 [==============================] - 9s 26ms/step - loss: 2.1369 - acc: 0.4309 - top5-acc: 0.7487 - val_loss: 2.3012 - val_acc: 0.4124 - val_top5-acc: 0.7190
Epoch 20/50
352/352 [==============================] - 9s 26ms/step - loss: 2.1222 - acc: 0.4350 - top5-acc: 0.7494 - val_loss: 2.3294 - val_acc: 0.4076 - val_top5-acc: 0.7186
Epoch 21/50
352/352 [==============================] - 9s 26ms/step - loss: 2.0822 - acc: 0.4436 - top5-acc: 0.7576 - val_loss: 2.2498 - val_acc: 0.4302 - val_top5-acc: 0.7276
Epoch 22/50
352/352 [==============================] - 9s 26ms/step - loss: 2.0609 - acc: 0.4518 - top5-acc: 0.7610 - val_loss: 2.2915 - val_acc: 0.4232 - val_top5-acc: 0.7280
Epoch 23/50
352/352 [==============================] - 9s 26ms/step - loss: 2.0482 - acc: 0.4590 - top5-acc: 0.7648 - val_loss: 2.2448 - val_acc: 0.4242 - val_top5-acc: 0.7296
Epoch 24/50
352/352 [==============================] - 9s 26ms/step - loss: 2.0292 - acc: 0.4560 - top5-acc: 0.7705 - val_loss: 2.2526 - val_acc: 0.4334 - val_top5-acc: 0.7324
Epoch 25/50
352/352 [==============================] - 9s 26ms/step - loss: 2.0316 - acc: 0.4544 - top5-acc: 0.7687 - val_loss: 2.2430 - val_acc: 0.4318 - val_top5-acc: 0.7338
Epoch 26/50
352/352 [==============================] - 9s 26ms/step - loss: 1.9988 - acc: 0.4616 - top5-acc: 0.7748 - val_loss: 2.2053 - val_acc: 0.4470 - val_top5-acc: 0.7366
Epoch 27/50
352/352 [==============================] - 9s 26ms/step - loss: 1.9788 - acc: 0.4646 - top5-acc: 0.7806 - val_loss: 2.2313 - val_acc: 0.4378 - val_top5-acc: 0.7420
Epoch 28/50
352/352 [==============================] - 9s 26ms/step - loss: 1.9702 - acc: 0.4688 - top5-acc: 0.7829 - val_loss: 2.2392 - val_acc: 0.4344 - val_top5-acc: 0.7338
Epoch 29/50
352/352 [==============================] - 9s 26ms/step - loss: 1.9488 - acc: 0.4699 - top5-acc: 0.7866 - val_loss: 2.1600 - val_acc: 0.4490 - val_top5-acc: 0.7446
Epoch 30/50
352/352 [==============================] - 9s 26ms/step - loss: 1.9302 - acc: 0.4803 - top5-acc: 0.7878 - val_loss: 2.2069 - val_acc: 0.4410 - val_top5-acc: 0.7486
Epoch 31/50
352/352 [==============================] - 9s 26ms/step - loss: 1.9135 - acc: 0.4806 - top5-acc: 0.7916 - val_loss: 2.1929 - val_acc: 0.4486 - val_top5-acc: 0.7514
Epoch 32/50
352/352 [==============================] - 9s 26ms/step - loss: 1.8890 - acc: 0.4844 - top5-acc: 0.7961 - val_loss: 2.2176 - val_acc: 0.4404 - val_top5-acc: 0.7494
Epoch 33/50
352/352 [==============================] - 9s 26ms/step - loss: 1.8844 - acc: 0.4872 - top5-acc: 0.7980 - val_loss: 2.2321 - val_acc: 0.4444 - val_top5-acc: 0.7460
Epoch 34/50
352/352 [==============================] - 9s 26ms/step - loss: 1.8588 - acc: 0.4912 - top5-acc: 0.8005 - val_loss: 2.1895 - val_acc: 0.4532 - val_top5-acc: 0.7510
Epoch 35/50
352/352 [==============================] - 9s 26ms/step - loss: 1.7259 - acc: 0.5232 - top5-acc: 0.8266 - val_loss: 2.1024 - val_acc: 0.4800 - val_top5-acc: 0.7726
Epoch 36/50
352/352 [==============================] - 9s 26ms/step - loss: 1.6262 - acc: 0.5488 - top5-acc: 0.8437 - val_loss: 2.0712 - val_acc: 0.4830 - val_top5-acc: 0.7754
Epoch 37/50
352/352 [==============================] - 9s 26ms/step - loss: 1.6164 - acc: 0.5481 - top5-acc: 0.8390 - val_loss: 2.1219 - val_acc: 0.4772 - val_top5-acc: 0.7678
Epoch 38/50
352/352 [==============================] - 9s 26ms/step - loss: 1.5850 - acc: 0.5568 - top5-acc: 0.8510 - val_loss: 2.0931 - val_acc: 0.4892 - val_top5-acc: 0.7732
Epoch 39/50
352/352 [==============================] - 9s 26ms/step - loss: 1.5741 - acc: 0.5589 - top5-acc: 0.8507 - val_loss: 2.0910 - val_acc: 0.4910 - val_top5-acc: 0.7700
Epoch 40/50
352/352 [==============================] - 9s 26ms/step - loss: 1.5546 - acc: 0.5675 - top5-acc: 0.8519 - val_loss: 2.1388 - val_acc: 0.4790 - val_top5-acc: 0.7742
Epoch 41/50
352/352 [==============================] - 9s 26ms/step - loss: 1.5464 - acc: 0.5684 - top5-acc: 0.8561 - val_loss: 2.1121 - val_acc: 0.4786 - val_top5-acc: 0.7718
Epoch 42/50
352/352 [==============================] - 9s 26ms/step - loss: 1.4494 - acc: 0.5890 - top5-acc: 0.8702 - val_loss: 2.1157 - val_acc: 0.4944 - val_top5-acc: 0.7802
Epoch 43/50
352/352 [==============================] - 9s 26ms/step - loss: 1.3847 - acc: 0.6069 - top5-acc: 0.8825 - val_loss: 2.1048 - val_acc: 0.4884 - val_top5-acc: 0.7752
Epoch 44/50
352/352 [==============================] - 9s 26ms/step - loss: 1.3724 - acc: 0.6087 - top5-acc: 0.8832 - val_loss: 2.0681 - val_acc: 0.4924 - val_top5-acc: 0.7868
Epoch 45/50
352/352 [==============================] - 9s 26ms/step - loss: 1.3643 - acc: 0.6116 - top5-acc: 0.8840 - val_loss: 2.0965 - val_acc: 0.4932 - val_top5-acc: 0.7752
Epoch 46/50
352/352 [==============================] - 9s 26ms/step - loss: 1.3517 - acc: 0.6184 - top5-acc: 0.8849 - val_loss: 2.0869 - val_acc: 0.4956 - val_top5-acc: 0.7778
Epoch 47/50
352/352 [==============================] - 9s 26ms/step - loss: 1.3377 - acc: 0.6211 - top5-acc: 0.8891 - val_loss: 2.1120 - val_acc: 0.4882 - val_top5-acc: 0.7764
Epoch 48/50
352/352 [==============================] - 9s 26ms/step - loss: 1.3369 - acc: 0.6186 - top5-acc: 0.8888 - val_loss: 2.1257 - val_acc: 0.4912 - val_top5-acc: 0.7752
Epoch 49/50
352/352 [==============================] - 9s 26ms/step - loss: 1.3266 - acc: 0.6190 - top5-acc: 0.8893 - val_loss: 2.0961 - val_acc: 0.4958 - val_top5-acc: 0.7828
Epoch 50/50
352/352 [==============================] - 9s 26ms/step - loss: 1.2731 - acc: 0.6352 - top5-acc: 0.8976 - val_loss: 2.0897 - val_acc: 0.4982 - val_top5-acc: 0.7788
313/313 [==============================] - 2s 7ms/step - loss: 2.0743 - acc: 0.5064 - top5-acc: 0.7828
Test accuracy: 50.64%
Test top 5 accuracy: 78.28%
Epoch 1/50
352/352 [==============================] - 16s 33ms/step - loss: 3.9500 - acc: 0.0943 - top5-acc: 0.2878 - val_loss: 3.5389 - val_acc: 0.1556 - val_top5-acc: 0.4148 - lr: 0.0030
Epoch 2/50
352/352 [==============================] - 11s 31ms/step - loss: 3.4664 - acc: 0.1684 - top5-acc: 0.4303 - val_loss: 3.2514 - val_acc: 0.2138 - val_top5-acc: 0.4884 - lr: 0.0030
Epoch 3/50
352/352 [==============================] - 11s 31ms/step - loss: 3.2333 - acc: 0.2090 - top5-acc: 0.4953 - val_loss: 3.0224 - val_acc: 0.2632 - val_top5-acc: 0.5558 - lr: 0.0030
Epoch 4/50
352/352 [==============================] - 11s 31ms/step - loss: 3.0509 - acc: 0.2416 - top5-acc: 0.5418 - val_loss: 2.9632 - val_acc: 0.2676 - val_top5-acc: 0.5686 - lr: 0.0030
Epoch 5/50
352/352 [==============================] - 11s 31ms/step - loss: 2.9120 - acc: 0.2712 - top5-acc: 0.5771 - val_loss: 2.8165 - val_acc: 0.2994 - val_top5-acc: 0.6044 - lr: 0.0030
Epoch 6/50
352/352 [==============================] - 11s 31ms/step - loss: 2.8101 - acc: 0.2913 - top5-acc: 0.6020 - val_loss: 2.7765 - val_acc: 0.3188 - val_top5-acc: 0.6260 - lr: 0.0030
Epoch 7/50
352/352 [==============================] - 11s 31ms/step - loss: 2.7284 - acc: 0.3098 - top5-acc: 0.6229 - val_loss: 2.6594 - val_acc: 0.3310 - val_top5-acc: 0.6378 - lr: 0.0030
Epoch 8/50
352/352 [==============================] - 11s 32ms/step - loss: 2.6453 - acc: 0.3277 - top5-acc: 0.6380 - val_loss: 2.5237 - val_acc: 0.3552 - val_top5-acc: 0.6672 - lr: 0.0030
Epoch 9/50
352/352 [==============================] - 11s 31ms/step - loss: 2.5360 - acc: 0.3508 - top5-acc: 0.6650 - val_loss: 2.4777 - val_acc: 0.3656 - val_top5-acc: 0.6828 - lr: 0.0030
Epoch 10/50
352/352 [==============================] - 11s 32ms/step - loss: 2.4609 - acc: 0.3641 - top5-acc: 0.6810 - val_loss: 2.4785 - val_acc: 0.3688 - val_top5-acc: 0.6886 - lr: 0.0030
Epoch 11/50
352/352 [==============================] - 11s 31ms/step - loss: 2.4225 - acc: 0.3691 - top5-acc: 0.6892 - val_loss: 2.4048 - val_acc: 0.3838 - val_top5-acc: 0.6954 - lr: 0.0030
Epoch 12/50
352/352 [==============================] - 11s 31ms/step - loss: 2.3725 - acc: 0.3785 - top5-acc: 0.7002 - val_loss: 2.3684 - val_acc: 0.3900 - val_top5-acc: 0.7060 - lr: 0.0030
Epoch 13/50
352/352 [==============================] - 11s 31ms/step - loss: 2.3262 - acc: 0.3930 - top5-acc: 0.7093 - val_loss: 2.3695 - val_acc: 0.3958 - val_top5-acc: 0.7060 - lr: 0.0030
Epoch 14/50
352/352 [==============================] - 11s 31ms/step - loss: 2.2951 - acc: 0.3994 - top5-acc: 0.7148 - val_loss: 2.3454 - val_acc: 0.4022 - val_top5-acc: 0.7134 - lr: 0.0030
Epoch 15/50
352/352 [==============================] - 11s 32ms/step - loss: 2.2667 - acc: 0.4046 - top5-acc: 0.7211 - val_loss: 2.3657 - val_acc: 0.4024 - val_top5-acc: 0.7124 - lr: 0.0030
Epoch 16/50
352/352 [==============================] - 11s 31ms/step - loss: 2.2309 - acc: 0.4122 - top5-acc: 0.7277 - val_loss: 2.3058 - val_acc: 0.4024 - val_top5-acc: 0.7166 - lr: 0.0030
Epoch 17/50
352/352 [==============================] - 11s 32ms/step - loss: 2.1990 - acc: 0.4182 - top5-acc: 0.7345 - val_loss: 2.2523 - val_acc: 0.4194 - val_top5-acc: 0.7296 - lr: 0.0030
Epoch 18/50
352/352 [==============================] - 11s 31ms/step - loss: 2.1832 - acc: 0.4241 - top5-acc: 0.7386 - val_loss: 2.2812 - val_acc: 0.4130 - val_top5-acc: 0.7230 - lr: 0.0030
Epoch 19/50
352/352 [==============================] - 11s 31ms/step - loss: 2.1573 - acc: 0.4281 - top5-acc: 0.7437 - val_loss: 2.2921 - val_acc: 0.4182 - val_top5-acc: 0.7276 - lr: 0.0030
Epoch 20/50
352/352 [==============================] - 11s 32ms/step - loss: 2.1399 - acc: 0.4320 - top5-acc: 0.7481 - val_loss: 2.2691 - val_acc: 0.4270 - val_top5-acc: 0.7278 - lr: 0.0030
Epoch 21/50
352/352 [==============================] - 11s 32ms/step - loss: 2.1173 - acc: 0.4381 - top5-acc: 0.7522 - val_loss: 2.2364 - val_acc: 0.4186 - val_top5-acc: 0.7364 - lr: 0.0030
Epoch 22/50
352/352 [==============================] - 11s 32ms/step - loss: 2.0932 - acc: 0.4398 - top5-acc: 0.7575 - val_loss: 2.2614 - val_acc: 0.4218 - val_top5-acc: 0.7352 - lr: 0.0030
Epoch 23/50
352/352 [==============================] - 11s 32ms/step - loss: 2.0779 - acc: 0.4454 - top5-acc: 0.7583 - val_loss: 2.2383 - val_acc: 0.4248 - val_top5-acc: 0.7370 - lr: 0.0030
Epoch 24/50
352/352 [==============================] - 11s 32ms/step - loss: 2.0566 - acc: 0.4508 - top5-acc: 0.7636 - val_loss: 2.1919 - val_acc: 0.4440 - val_top5-acc: 0.7458 - lr: 0.0030
Epoch 25/50
352/352 [==============================] - 11s 31ms/step - loss: 2.0332 - acc: 0.4550 - top5-acc: 0.7682 - val_loss: 2.1731 - val_acc: 0.4412 - val_top5-acc: 0.7398 - lr: 0.0030
Epoch 26/50
352/352 [==============================] - 11s 31ms/step - loss: 2.0127 - acc: 0.4606 - top5-acc: 0.7705 - val_loss: 2.2456 - val_acc: 0.4392 - val_top5-acc: 0.7402 - lr: 0.0030
Epoch 27/50
352/352 [==============================] - 11s 31ms/step - loss: 1.9999 - acc: 0.4626 - top5-acc: 0.7752 - val_loss: 2.1989 - val_acc: 0.4420 - val_top5-acc: 0.7488 - lr: 0.0030
Epoch 28/50
352/352 [==============================] - 11s 31ms/step - loss: 1.9818 - acc: 0.4666 - top5-acc: 0.7791 - val_loss: 2.2228 - val_acc: 0.4408 - val_top5-acc: 0.7446 - lr: 0.0030
Epoch 29/50
352/352 [==============================] - 11s 31ms/step - loss: 1.9701 - acc: 0.4687 - top5-acc: 0.7794 - val_loss: 2.1977 - val_acc: 0.4452 - val_top5-acc: 0.7518 - lr: 0.0030
Epoch 30/50
352/352 [==============================] - 11s 31ms/step - loss: 1.9478 - acc: 0.4711 - top5-acc: 0.7843 - val_loss: 2.1515 - val_acc: 0.4562 - val_top5-acc: 0.7540 - lr: 0.0030
Epoch 31/50
352/352 [==============================] - 11s 31ms/step - loss: 1.9262 - acc: 0.4799 - top5-acc: 0.7885 - val_loss: 2.1403 - val_acc: 0.4546 - val_top5-acc: 0.7574 - lr: 0.0030
Epoch 32/50
352/352 [==============================] - 11s 31ms/step - loss: 1.9224 - acc: 0.4808 - top5-acc: 0.7881 - val_loss: 2.2336 - val_acc: 0.4492 - val_top5-acc: 0.7488 - lr: 0.0030
Epoch 33/50
352/352 [==============================] - 11s 31ms/step - loss: 1.9003 - acc: 0.4831 - top5-acc: 0.7960 - val_loss: 2.1563 - val_acc: 0.4580 - val_top5-acc: 0.7518 - lr: 0.0030
Epoch 34/50
352/352 [==============================] - 11s 31ms/step - loss: 1.8849 - acc: 0.4872 - top5-acc: 0.7964 - val_loss: 2.1260 - val_acc: 0.4646 - val_top5-acc: 0.7588 - lr: 0.0030
Epoch 35/50
352/352 [==============================] - 11s 31ms/step - loss: 1.8782 - acc: 0.4892 - top5-acc: 0.8003 - val_loss: 2.1438 - val_acc: 0.4616 - val_top5-acc: 0.7590 - lr: 0.0030
Epoch 36/50
352/352 [==============================] - 11s 31ms/step - loss: 1.8659 - acc: 0.4924 - top5-acc: 0.8025 - val_loss: 2.0792 - val_acc: 0.4728 - val_top5-acc: 0.7626 - lr: 0.0030
Epoch 37/50
352/352 [==============================] - 11s 31ms/step - loss: 1.8433 - acc: 0.4976 - top5-acc: 0.8045 - val_loss: 2.2000 - val_acc: 0.4554 - val_top5-acc: 0.7602 - lr: 0.0030
Epoch 38/50
352/352 [==============================] - 11s 31ms/step - loss: 1.8371 - acc: 0.5003 - top5-acc: 0.8056 - val_loss: 2.1494 - val_acc: 0.4590 - val_top5-acc: 0.7620 - lr: 0.0030
Epoch 39/50
352/352 [==============================] - 11s 31ms/step - loss: 1.8322 - acc: 0.5011 - top5-acc: 0.8076 - val_loss: 2.1440 - val_acc: 0.4542 - val_top5-acc: 0.7572 - lr: 0.0030
Epoch 40/50
352/352 [==============================] - 11s 31ms/step - loss: 1.8199 - acc: 0.5009 - top5-acc: 0.8107 - val_loss: 2.0831 - val_acc: 0.4710 - val_top5-acc: 0.7674 - lr: 0.0030
Epoch 41/50
352/352 [==============================] - 11s 31ms/step - loss: 1.8071 - acc: 0.5068 - top5-acc: 0.8114 - val_loss: 2.0868 - val_acc: 0.4670 - val_top5-acc: 0.7700 - lr: 0.0030
Epoch 42/50
352/352 [==============================] - 11s 31ms/step - loss: 1.6200 - acc: 0.5499 - top5-acc: 0.8403 - val_loss: 2.0866 - val_acc: 0.4784 - val_top5-acc: 0.7798 - lr: 0.0015
Epoch 43/50
352/352 [==============================] - 11s 31ms/step - loss: 1.5655 - acc: 0.5599 - top5-acc: 0.8532 - val_loss: 2.0813 - val_acc: 0.4888 - val_top5-acc: 0.7778 - lr: 0.0015
Epoch 44/50
352/352 [==============================] - 11s 31ms/step - loss: 1.5466 - acc: 0.5687 - top5-acc: 0.8546 - val_loss: 2.1103 - val_acc: 0.4890 - val_top5-acc: 0.7786 - lr: 0.0015
Epoch 45/50
352/352 [==============================] - 11s 31ms/step - loss: 1.5350 - acc: 0.5717 - top5-acc: 0.8564 - val_loss: 2.1715 - val_acc: 0.4794 - val_top5-acc: 0.7674 - lr: 0.0015
Epoch 46/50
352/352 [==============================] - 11s 31ms/step - loss: 1.5246 - acc: 0.5667 - top5-acc: 0.8585 - val_loss: 2.0667 - val_acc: 0.4924 - val_top5-acc: 0.7832 - lr: 0.0015
Epoch 47/50
352/352 [==============================] - 11s 31ms/step - loss: 1.5062 - acc: 0.5740 - top5-acc: 0.8621 - val_loss: 2.0809 - val_acc: 0.4932 - val_top5-acc: 0.7870 - lr: 0.0015
Epoch 48/50
352/352 [==============================] - 11s 31ms/step - loss: 1.5102 - acc: 0.5743 - top5-acc: 0.8620 - val_loss: 2.0898 - val_acc: 0.4820 - val_top5-acc: 0.7862 - lr: 0.0015
Epoch 49/50
352/352 [==============================] - 11s 31ms/step - loss: 1.4992 - acc: 0.5784 - top5-acc: 0.8638 - val_loss: 2.1189 - val_acc: 0.4834 - val_top5-acc: 0.7750 - lr: 0.0015
Epoch 50/50
352/352 [==============================] - 11s 31ms/step - loss: 1.4927 - acc: 0.5791 - top5-acc: 0.8659 - val_loss: 2.1429 - val_acc: 0.4796 - val_top5-acc: 0.7790 - lr: 0.0015
313/313 [==============================] - 3s 9ms/step - loss: 2.0755 - acc: 0.4963 - top5-acc: 0.7826
Test accuracy: 49.63%
Test top 5 accuracy: 78.26%
CPU times: user 10min 9s, sys: 39.4 s, total: 10min 48s
Wall time: 9min 20s

gMLP 論文で述べられているように、埋め込み次元を増やし、gMLP ブロックの数を増やし、そしてモデルをより長く訓練することでより良い結果を得られます。入力画像のサイズを大きくして異なるパッチサイズを使用することを試しても良いでしょう。論文は MixUp と CutMix、そして AutoAugment のような高度な正則化ストラテジーを使用したことに注意してください。

 

以上



クラスキャット

最近の投稿

  • LangGraph 0.5 : エージェント開発 : ワークフローとエージェント
  • LangGraph 0.5 : エージェント開発 : エージェントの実行
  • LangGraph 0.5 : エージェント開発 : prebuilt コンポーネントを使用したエージェント開発
  • LangGraph 0.5 : Get started : ローカルサーバの実行
  • LangGraph 0.5 on Colab : Get started : human-in-the-loop 制御の追加

タグ

AutoGen (13) ClassCat Press Release (20) ClassCat TF/ONNX Hub (11) DGL 0.5 (14) Eager Execution (7) Edward (17) FLUX.1 (16) Gemini (20) HuggingFace Transformers 4.5 (10) HuggingFace Transformers 4.6 (7) HuggingFace Transformers 4.29 (9) Keras 2 Examples (98) Keras 2 Guide (16) Keras 3 (10) Keras Release Note (17) Kubeflow 1.0 (10) LangChain (45) LangGraph (24) LangGraph 0.5 (8) MediaPipe 0.8 (11) Model Context Protocol (16) NNI 1.5 (16) OpenAI Agents SDK (8) OpenAI Cookbook (13) OpenAI platform (10) OpenAI platform 1.x (10) OpenAI ヘルプ (8) TensorFlow 2.0 Advanced Tutorials (33) TensorFlow 2.0 Advanced Tutorials (Alpha) (15) TensorFlow 2.0 Advanced Tutorials (Beta) (16) TensorFlow 2.0 Guide (10) TensorFlow 2.0 Guide (Alpha) (16) TensorFlow 2.0 Guide (Beta) (9) TensorFlow 2.0 Release Note (12) TensorFlow 2.0 Tutorials (20) TensorFlow 2.0 Tutorials (Alpha) (14) TensorFlow 2.0 Tutorials (Beta) (12) TensorFlow 2.4 Guide (24) TensorFlow Deploy (8) TensorFlow Get Started (7) TensorFlow Probability (9) TensorFlow Programmer's Guide (22) TensorFlow Release Note (18) TensorFlow Tutorials (33) TF-Agents 0.4 (11)
2021年12月
月 火 水 木 金 土 日
 12345
6789101112
13141516171819
20212223242526
2728293031  
« 11月   3月 »
© 2025 ClasCat® AI Research | Powered by Minimalist Blog WordPress Theme