Skip to content

ClasCat® AI Research

クラスキャット – 生成 AI, AI エージェント, MCP

Menu
  • ホーム
    • ClassCat® AI Research ホーム
    • クラスキャット・ホーム
  • OpenAI API
    • OpenAI Python ライブラリ 1.x : 概要
    • OpenAI ブログ
      • GPT の紹介
      • GPT ストアの紹介
      • ChatGPT Team の紹介
    • OpenAI platform 1.x
      • Get Started : イントロダクション
      • Get Started : クイックスタート (Python)
      • Get Started : クイックスタート (Node.js)
      • Get Started : モデル
      • 機能 : 埋め込み
      • 機能 : 埋め込み (ユースケース)
      • ChatGPT : アクション – イントロダクション
      • ChatGPT : アクション – Getting started
      • ChatGPT : アクション – アクション認証
    • OpenAI ヘルプ : ChatGPT
      • ChatGPTとは何ですか?
      • ChatGPT は真実を語っていますか?
      • GPT の作成
      • GPT FAQ
      • GPT vs アシスタント
      • GPT ビルダー
    • OpenAI ヘルプ : ChatGPT > メモリ
      • FAQ
    • OpenAI ヘルプ : GPT ストア
      • 貴方の GPT をフィーチャーする
    • OpenAI Python ライブラリ 0.27 : 概要
    • OpenAI platform
      • Get Started : イントロダクション
      • Get Started : クイックスタート
      • Get Started : モデル
      • ガイド : GPT モデル
      • ガイド : 画像生成 (DALL·E)
      • ガイド : GPT-3.5 Turbo 対応 微調整
      • ガイド : 微調整 1.イントロダクション
      • ガイド : 微調整 2. データセットの準備 / ケーススタディ
      • ガイド : 埋め込み
      • ガイド : 音声テキスト変換
      • ガイド : モデレーション
      • ChatGPT プラグイン : イントロダクション
    • OpenAI Cookbook
      • 概要
      • API 使用方法 : レート制限の操作
      • API 使用方法 : tiktoken でトークンを数える方法
      • GPT : ChatGPT モデルへの入力をフォーマットする方法
      • GPT : 補完をストリームする方法
      • GPT : 大規模言語モデルを扱う方法
      • 埋め込み : 埋め込みの取得
      • GPT-3 の微調整 : 分類サンプルの微調整
      • DALL-E : DALL·E で 画像を生成して編集する方法
      • DALL·E と Segment Anything で動的マスクを作成する方法
      • Whisper プロンプティング・ガイド
  • Gemini API
    • Tutorials : クイックスタート with Python (1) テキスト-to-テキスト生成
    • (2) マルチモーダル入力 / 日本語チャット
    • (3) 埋め込みの使用
    • (4) 高度なユースケース
    • クイックスタート with Node.js
    • クイックスタート with Dart or Flutter (1) 日本語動作確認
    • Gemma
      • 概要 (README)
      • Tutorials : サンプリング
      • Tutorials : KerasNLP による Getting Started
  • Keras 3
    • 新しいマルチバックエンド Keras
    • Keras 3 について
    • Getting Started : エンジニアのための Keras 入門
    • Google Colab 上のインストールと Stable Diffusion デモ
    • コンピュータビジョン – ゼロからの画像分類
    • コンピュータビジョン – 単純な MNIST convnet
    • コンピュータビジョン – EfficientNet を使用した微調整による画像分類
    • コンピュータビジョン – Vision Transformer による画像分類
    • コンピュータビジョン – 最新の MLPモデルによる画像分類
    • コンピュータビジョン – コンパクトな畳込み Transformer
    • Keras Core
      • Keras Core 0.1
        • 新しいマルチバックエンド Keras (README)
        • Keras for TensorFlow, JAX, & PyTorch
        • 開発者ガイド : Getting started with Keras Core
        • 開発者ガイド : 関数型 API
        • 開発者ガイド : シーケンシャル・モデル
        • 開発者ガイド : サブクラス化で新しい層とモデルを作成する
        • 開発者ガイド : 独自のコールバックを書く
      • Keras Core 0.1.1 & 0.1.2 : リリースノート
      • 開発者ガイド
      • Code examples
      • Keras Stable Diffusion
        • 概要
        • 基本的な使い方 (テキスト-to-画像 / 画像-to-画像変換)
        • 混合精度のパフォーマンス
        • インペインティングの簡易アプリケーション
        • (参考) KerasCV – Stable Diffusion を使用した高性能画像生成
  • TensorFlow
    • TF 2 : 初級チュートリアル
    • TF 2 : 上級チュートリアル
    • TF 2 : ガイド
    • TF 1 : チュートリアル
    • TF 1 : ガイド
  • その他
    • 🦜️🔗 LangChain ドキュメント / ユースケース
    • Stable Diffusion WebUI
      • Google Colab で Stable Diffusion WebUI 入門
      • HuggingFace モデル / VAE の導入
      • LoRA の利用
    • Diffusion Models / 拡散モデル
  • クラスキャット
    • 会社案内
    • お問合せ
    • Facebook
    • ClassCat® Blog
Menu

Keras 2 : examples : コンピュータビジョン – EANet (外部注意 Transformer) で画像分類

Posted on 11/18/202111/23/2021 by Sales Information

Keras 2 : examples : EANet (外部注意 Transformer) で画像分類 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 11/18/2021 (keras 2.7.0)

* 本ページは、Keras の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

  • Code examples : Computer Vision : Image classification with EANet (External Attention Transformer) (Author: ZhiYong Chang)

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス ★ 無料 Web セミナー開催中 ★

◆ クラスキャットは人工知能・テレワークに関する各種サービスを提供しております。お気軽にご相談ください :

  • 人工知能研究開発支援
    1. 人工知能研修サービス(経営者層向けオンサイト研修)
    2. テクニカルコンサルティングサービス
    3. 実証実験(プロトタイプ構築)
    4. アプリケーションへの実装

  • 人工知能研修サービス

  • PoC(概念実証)を失敗させないための支援

  • テレワーク & オンライン授業を支援
◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。

◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

  • 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
  • E-Mail:sales-info@classcat.com  ;  WebSite: www.classcat.com  ;  Facebook

 

 

Keras 2 : examples : EANet (外部注意 Transformer) で画像分類

Description: 外部注意を活用する Transformer による画像分類。

 

イントロダクション

このサンプルは画像分類のための EANet モデルを実装し、それを CIFAR-100 データセット上で実演します。2 つの外部の (= external), 小さく, 学習可能でそして共有メモリに基づいた、外部注意 (= external attention) と呼ばれる、EANet は新規の注意 (= attention) メカニズムを導入します、これは 2 つのカスケード線形層と 2 つの正規化層を単純に使用して簡単に実装できます。それは既存のアーキテクチャで使用されていた自己注意 (= self-attention) を都合よく置き換えます。外部注意は、総てのサンプル間の相関関係を暗黙的に考えるだけなので、線形複雑度を持ちます。

このサンプルは TensorFlow 2.5 またはそれ以上、そして TensorFlow Addons パッケージを必要とします、これは次のコマンドを使用してインストールできます :

pip install -U tensorflow-addons

 

セットアップ

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import matplotlib.pyplot as plt

 

データの準備

num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 100)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 100)

 

ハイパーパラメータの設定

weight_decay = 0.0001
learning_rate = 0.001
label_smoothing = 0.1
validation_split = 0.2
batch_size = 128
num_epochs = 50
patch_size = 2  # Size of the patches to be extracted from the input images.
num_patches = (input_shape[0] // patch_size) ** 2  # Number of patch
embedding_dim = 64  # Number of hidden units.
mlp_dim = 64
dim_coefficient = 4
num_heads = 4
attention_dropout = 0.2
projection_dropout = 0.2
num_transformer_blocks = 8  # Number of repetitions of the transformer layer

print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
print(f"Patches per image: {num_patches}")
Patch size: 2 X 2 = 4 
Patches per image: 256

 

データ増強の利用

data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.1),
        layers.RandomContrast(factor=0.1),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)

 

パッチ抽出とエンコーディング層の実装

class PatchExtract(layers.Layer):
    def __init__(self, patch_size, **kwargs):
        super(PatchExtract, self).__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=(1, self.patch_size, self.patch_size, 1),
            strides=(1, self.patch_size, self.patch_size, 1),
            rates=(1, 1, 1, 1),
            padding="VALID",
        )
        patch_dim = patches.shape[-1]
        patch_num = patches.shape[1]
        return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))


class PatchEmbedding(layers.Layer):
    def __init__(self, num_patch, embed_dim, **kwargs):
        super(PatchEmbedding, self).__init__(**kwargs)
        self.num_patch = num_patch
        self.proj = layers.Dense(embed_dim)
        self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)

    def call(self, patch):
        pos = tf.range(start=0, limit=self.num_patch, delta=1)
        return self.proj(patch) + self.pos_embed(pos)

 

外部注意ブロックの実装

def external_attention(
    x, dim, num_heads, dim_coefficient=4, attention_dropout=0, projection_dropout=0
):
    _, num_patch, channel = x.shape
    assert dim % num_heads == 0
    num_heads = num_heads * dim_coefficient

    x = layers.Dense(dim * dim_coefficient)(x)
    # create tensor [batch_size, num_patches, num_heads, dim*dim_coefficient//num_heads]
    x = tf.reshape(
        x, shape=(-1, num_patch, num_heads, dim * dim_coefficient // num_heads)
    )
    x = tf.transpose(x, perm=[0, 2, 1, 3])
    # a linear layer M_k
    attn = layers.Dense(dim // dim_coefficient)(x)
    # normalize attention map
    attn = layers.Softmax(axis=2)(attn)
    # dobule-normalization
    attn = attn / (1e-9 + tf.reduce_sum(attn, axis=-1, keepdims=True))
    attn = layers.Dropout(attention_dropout)(attn)
    # a linear layer M_v
    x = layers.Dense(dim * dim_coefficient // num_heads)(attn)
    x = tf.transpose(x, perm=[0, 2, 1, 3])
    x = tf.reshape(x, [-1, num_patch, dim * dim_coefficient])
    # a linear layer to project original dim
    x = layers.Dense(dim)(x)
    x = layers.Dropout(projection_dropout)(x)
    return x

 

MLP ブロックの実装

def mlp(x, embedding_dim, mlp_dim, drop_rate=0.2):
    x = layers.Dense(mlp_dim, activation=tf.nn.gelu)(x)
    x = layers.Dropout(drop_rate)(x)
    x = layers.Dense(embedding_dim)(x)
    x = layers.Dropout(drop_rate)(x)
    return x

 

Transformer ブロックの実装

def transformer_encoder(
    x,
    embedding_dim,
    mlp_dim,
    num_heads,
    dim_coefficient,
    attention_dropout,
    projection_dropout,
    attention_type="external_attention",
):
    residual_1 = x
    x = layers.LayerNormalization(epsilon=1e-5)(x)
    if attention_type == "external_attention":
        x = external_attention(
            x,
            embedding_dim,
            num_heads,
            dim_coefficient,
            attention_dropout,
            projection_dropout,
        )
    elif attention_type == "self_attention":
        x = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embedding_dim, dropout=attention_dropout
        )(x, x)
    x = layers.add([x, residual_1])
    residual_2 = x
    x = layers.LayerNormalization(epsilon=1e-5)(x)
    x = mlp(x, embedding_dim, mlp_dim)
    x = layers.add([x, residual_2])
    return x

 

EANet モデルの実装

EANet モデルは外部アテンションを活用しています。従来の自己注意の計算複雑度は O(d * N ** 2) です、ここで d は埋め込みサイズで、N はパッチの数です。著者らは殆どのピクセルは幾つかの他のピクセルだけに密接に関係していて、N 対 N の注意行列は冗長であるかもしれないことを見出しました。そこで、彼らは代替として外部注意モジュールを提案しました、ここで外部注意の計算複雑度は O(d * S * N) です。d と S はハイパーパラメータですので、提案されたアルゴリズムはピクセル数内で線形です。実際には、これは drop パッチ演算に等値です、何故ならば画像内のパッチに含まれる多くの情報は冗長で重要ではないからです。

def get_model(attention_type="external_attention"):
    inputs = layers.Input(shape=input_shape)
    # Image augment
    x = data_augmentation(inputs)
    # Extract patches.
    x = PatchExtract(patch_size)(x)
    # Create patch embedding.
    x = PatchEmbedding(num_patches, embedding_dim)(x)
    # Create Transformer block.
    for _ in range(num_transformer_blocks):
        x = transformer_encoder(
            x,
            embedding_dim,
            mlp_dim,
            num_heads,
            dim_coefficient,
            attention_dropout,
            projection_dropout,
            attention_type,
        )

    x = layers.GlobalAvgPool1D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

 

CIFAR-100 上の訓練

model = get_model(attention_type="external_attention")

model.compile(
    loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),
    optimizer=tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    ),
    metrics=[
        keras.metrics.CategoricalAccuracy(name="accuracy"),
        keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
    ],
)

history = model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=num_epochs,
    validation_split=validation_split,
)
Epoch 1/50
313/313 [==============================] - 40s 95ms/step - loss: 4.2091 - accuracy: 0.0723 - top-5-accuracy: 0.2384 - val_loss: 3.9706 - val_accuracy: 0.1153 - val_top-5-accuracy: 0.3336
Epoch 2/50
313/313 [==============================] - 29s 91ms/step - loss: 3.8028 - accuracy: 0.1427 - top-5-accuracy: 0.3871 - val_loss: 3.6672 - val_accuracy: 0.1829 - val_top-5-accuracy: 0.4513
Epoch 3/50
313/313 [==============================] - 29s 93ms/step - loss: 3.5493 - accuracy: 0.1978 - top-5-accuracy: 0.4805 - val_loss: 3.5402 - val_accuracy: 0.2141 - val_top-5-accuracy: 0.5038
Epoch 4/50
313/313 [==============================] - 29s 93ms/step - loss: 3.4029 - accuracy: 0.2355 - top-5-accuracy: 0.5328 - val_loss: 3.4496 - val_accuracy: 0.2354 - val_top-5-accuracy: 0.5316
Epoch 5/50
313/313 [==============================] - 29s 92ms/step - loss: 3.2917 - accuracy: 0.2636 - top-5-accuracy: 0.5678 - val_loss: 3.3342 - val_accuracy: 0.2699 - val_top-5-accuracy: 0.5679
Epoch 6/50
313/313 [==============================] - 29s 92ms/step - loss: 3.2116 - accuracy: 0.2830 - top-5-accuracy: 0.5921 - val_loss: 3.2896 - val_accuracy: 0.2749 - val_top-5-accuracy: 0.5874
Epoch 7/50
313/313 [==============================] - 28s 90ms/step - loss: 3.1453 - accuracy: 0.2980 - top-5-accuracy: 0.6100 - val_loss: 3.3090 - val_accuracy: 0.2857 - val_top-5-accuracy: 0.5831
Epoch 8/50
313/313 [==============================] - 29s 94ms/step - loss: 3.0889 - accuracy: 0.3121 - top-5-accuracy: 0.6266 - val_loss: 3.1969 - val_accuracy: 0.2975 - val_top-5-accuracy: 0.6082
Epoch 9/50
313/313 [==============================] - 29s 92ms/step - loss: 3.0390 - accuracy: 0.3252 - top-5-accuracy: 0.6441 - val_loss: 3.1249 - val_accuracy: 0.3175 - val_top-5-accuracy: 0.6330
Epoch 10/50
313/313 [==============================] - 29s 92ms/step - loss: 2.9871 - accuracy: 0.3365 - top-5-accuracy: 0.6615 - val_loss: 3.1121 - val_accuracy: 0.3200 - val_top-5-accuracy: 0.6374
Epoch 11/50
313/313 [==============================] - 29s 92ms/step - loss: 2.9476 - accuracy: 0.3489 - top-5-accuracy: 0.6697 - val_loss: 3.1156 - val_accuracy: 0.3268 - val_top-5-accuracy: 0.6421
Epoch 12/50
313/313 [==============================] - 29s 91ms/step - loss: 2.9106 - accuracy: 0.3576 - top-5-accuracy: 0.6783 - val_loss: 3.1337 - val_accuracy: 0.3226 - val_top-5-accuracy: 0.6389
Epoch 13/50
313/313 [==============================] - 29s 92ms/step - loss: 2.8772 - accuracy: 0.3662 - top-5-accuracy: 0.6871 - val_loss: 3.0373 - val_accuracy: 0.3348 - val_top-5-accuracy: 0.6624
Epoch 14/50
313/313 [==============================] - 29s 92ms/step - loss: 2.8508 - accuracy: 0.3756 - top-5-accuracy: 0.6944 - val_loss: 3.0297 - val_accuracy: 0.3441 - val_top-5-accuracy: 0.6643
Epoch 15/50
313/313 [==============================] - 28s 90ms/step - loss: 2.8211 - accuracy: 0.3821 - top-5-accuracy: 0.7034 - val_loss: 2.9680 - val_accuracy: 0.3604 - val_top-5-accuracy: 0.6847
Epoch 16/50
313/313 [==============================] - 28s 90ms/step - loss: 2.8017 - accuracy: 0.3864 - top-5-accuracy: 0.7090 - val_loss: 2.9746 - val_accuracy: 0.3584 - val_top-5-accuracy: 0.6855
Epoch 17/50
313/313 [==============================] - 29s 91ms/step - loss: 2.7714 - accuracy: 0.3962 - top-5-accuracy: 0.7169 - val_loss: 2.9104 - val_accuracy: 0.3738 - val_top-5-accuracy: 0.6940
Epoch 18/50
313/313 [==============================] - 29s 92ms/step - loss: 2.7523 - accuracy: 0.4008 - top-5-accuracy: 0.7204 - val_loss: 2.8560 - val_accuracy: 0.3861 - val_top-5-accuracy: 0.7115
Epoch 19/50
313/313 [==============================] - 28s 91ms/step - loss: 2.7320 - accuracy: 0.4051 - top-5-accuracy: 0.7263 - val_loss: 2.8780 - val_accuracy: 0.3820 - val_top-5-accuracy: 0.7101
Epoch 20/50
313/313 [==============================] - 28s 90ms/step - loss: 2.7139 - accuracy: 0.4114 - top-5-accuracy: 0.7290 - val_loss: 2.9831 - val_accuracy: 0.3694 - val_top-5-accuracy: 0.6922
Epoch 21/50
313/313 [==============================] - 28s 91ms/step - loss: 2.6991 - accuracy: 0.4142 - top-5-accuracy: 0.7335 - val_loss: 2.8420 - val_accuracy: 0.3968 - val_top-5-accuracy: 0.7138
Epoch 22/50
313/313 [==============================] - 29s 91ms/step - loss: 2.6842 - accuracy: 0.4195 - top-5-accuracy: 0.7377 - val_loss: 2.7965 - val_accuracy: 0.4088 - val_top-5-accuracy: 0.7266
Epoch 23/50
313/313 [==============================] - 28s 91ms/step - loss: 2.6571 - accuracy: 0.4273 - top-5-accuracy: 0.7436 - val_loss: 2.8620 - val_accuracy: 0.3947 - val_top-5-accuracy: 0.7155
Epoch 24/50
313/313 [==============================] - 29s 91ms/step - loss: 2.6508 - accuracy: 0.4277 - top-5-accuracy: 0.7469 - val_loss: 2.8459 - val_accuracy: 0.3963 - val_top-5-accuracy: 0.7150
Epoch 25/50
313/313 [==============================] - 28s 90ms/step - loss: 2.6403 - accuracy: 0.4283 - top-5-accuracy: 0.7520 - val_loss: 2.7886 - val_accuracy: 0.4128 - val_top-5-accuracy: 0.7283
Epoch 26/50
313/313 [==============================] - 29s 92ms/step - loss: 2.6281 - accuracy: 0.4353 - top-5-accuracy: 0.7523 - val_loss: 2.8493 - val_accuracy: 0.4026 - val_top-5-accuracy: 0.7153
Epoch 27/50
313/313 [==============================] - 29s 92ms/step - loss: 2.6092 - accuracy: 0.4403 - top-5-accuracy: 0.7580 - val_loss: 2.7539 - val_accuracy: 0.4186 - val_top-5-accuracy: 0.7392
Epoch 28/50
313/313 [==============================] - 29s 91ms/step - loss: 2.5992 - accuracy: 0.4423 - top-5-accuracy: 0.7600 - val_loss: 2.8625 - val_accuracy: 0.3964 - val_top-5-accuracy: 0.7174
Epoch 29/50
313/313 [==============================] - 28s 90ms/step - loss: 2.5913 - accuracy: 0.4456 - top-5-accuracy: 0.7598 - val_loss: 2.7911 - val_accuracy: 0.4162 - val_top-5-accuracy: 0.7329
Epoch 30/50
313/313 [==============================] - 29s 92ms/step - loss: 2.5780 - accuracy: 0.4480 - top-5-accuracy: 0.7649 - val_loss: 2.8158 - val_accuracy: 0.4118 - val_top-5-accuracy: 0.7288
Epoch 31/50
313/313 [==============================] - 28s 91ms/step - loss: 2.5657 - accuracy: 0.4547 - top-5-accuracy: 0.7661 - val_loss: 2.8651 - val_accuracy: 0.4056 - val_top-5-accuracy: 0.7217
Epoch 32/50
313/313 [==============================] - 29s 91ms/step - loss: 2.5637 - accuracy: 0.4480 - top-5-accuracy: 0.7681 - val_loss: 2.8190 - val_accuracy: 0.4094 - val_top-5-accuracy: 0.7267
Epoch 33/50
313/313 [==============================] - 29s 92ms/step - loss: 2.5525 - accuracy: 0.4545 - top-5-accuracy: 0.7693 - val_loss: 2.7985 - val_accuracy: 0.4216 - val_top-5-accuracy: 0.7303
Epoch 34/50
313/313 [==============================] - 28s 91ms/step - loss: 2.5462 - accuracy: 0.4579 - top-5-accuracy: 0.7721 - val_loss: 2.8865 - val_accuracy: 0.4016 - val_top-5-accuracy: 0.7204
Epoch 35/50
313/313 [==============================] - 29s 92ms/step - loss: 2.5329 - accuracy: 0.4616 - top-5-accuracy: 0.7740 - val_loss: 2.7862 - val_accuracy: 0.4232 - val_top-5-accuracy: 0.7389
Epoch 36/50
313/313 [==============================] - 28s 90ms/step - loss: 2.5234 - accuracy: 0.4610 - top-5-accuracy: 0.7765 - val_loss: 2.8234 - val_accuracy: 0.4134 - val_top-5-accuracy: 0.7312
Epoch 37/50
313/313 [==============================] - 29s 91ms/step - loss: 2.5152 - accuracy: 0.4663 - top-5-accuracy: 0.7774 - val_loss: 2.7894 - val_accuracy: 0.4161 - val_top-5-accuracy: 0.7376
Epoch 38/50
313/313 [==============================] - 29s 92ms/step - loss: 2.5117 - accuracy: 0.4674 - top-5-accuracy: 0.7790 - val_loss: 2.8091 - val_accuracy: 0.4142 - val_top-5-accuracy: 0.7360
Epoch 39/50
313/313 [==============================] - 28s 90ms/step - loss: 2.5047 - accuracy: 0.4681 - top-5-accuracy: 0.7805 - val_loss: 2.8199 - val_accuracy: 0.4167 - val_top-5-accuracy: 0.7299
Epoch 40/50
313/313 [==============================] - 28s 90ms/step - loss: 2.4974 - accuracy: 0.4697 - top-5-accuracy: 0.7819 - val_loss: 2.7864 - val_accuracy: 0.4247 - val_top-5-accuracy: 0.7402
Epoch 41/50
313/313 [==============================] - 28s 90ms/step - loss: 2.4889 - accuracy: 0.4749 - top-5-accuracy: 0.7854 - val_loss: 2.8120 - val_accuracy: 0.4217 - val_top-5-accuracy: 0.7358
Epoch 42/50
313/313 [==============================] - 28s 90ms/step - loss: 2.4799 - accuracy: 0.4771 - top-5-accuracy: 0.7866 - val_loss: 2.9003 - val_accuracy: 0.4038 - val_top-5-accuracy: 0.7170
Epoch 43/50
313/313 [==============================] - 28s 90ms/step - loss: 2.4814 - accuracy: 0.4770 - top-5-accuracy: 0.7868 - val_loss: 2.7504 - val_accuracy: 0.4260 - val_top-5-accuracy: 0.7457
Epoch 44/50
313/313 [==============================] - 28s 91ms/step - loss: 2.4747 - accuracy: 0.4757 - top-5-accuracy: 0.7870 - val_loss: 2.8207 - val_accuracy: 0.4166 - val_top-5-accuracy: 0.7363
Epoch 45/50
313/313 [==============================] - 28s 90ms/step - loss: 2.4653 - accuracy: 0.4809 - top-5-accuracy: 0.7924 - val_loss: 2.8663 - val_accuracy: 0.4130 - val_top-5-accuracy: 0.7209
Epoch 46/50
313/313 [==============================] - 28s 90ms/step - loss: 2.4554 - accuracy: 0.4825 - top-5-accuracy: 0.7929 - val_loss: 2.8145 - val_accuracy: 0.4250 - val_top-5-accuracy: 0.7357
Epoch 47/50
313/313 [==============================] - 29s 91ms/step - loss: 2.4602 - accuracy: 0.4823 - top-5-accuracy: 0.7919 - val_loss: 2.8352 - val_accuracy: 0.4189 - val_top-5-accuracy: 0.7365
Epoch 48/50
313/313 [==============================] - 28s 91ms/step - loss: 2.4493 - accuracy: 0.4848 - top-5-accuracy: 0.7933 - val_loss: 2.8246 - val_accuracy: 0.4160 - val_top-5-accuracy: 0.7362
Epoch 49/50
313/313 [==============================] - 28s 91ms/step - loss: 2.4454 - accuracy: 0.4846 - top-5-accuracy: 0.7958 - val_loss: 2.7731 - val_accuracy: 0.4320 - val_top-5-accuracy: 0.7436
Epoch 50/50
313/313 [==============================] - 29s 92ms/step - loss: 2.4418 - accuracy: 0.4848 - top-5-accuracy: 0.7951 - val_loss: 2.7926 - val_accuracy: 0.4317 - val_top-5-accuracy: 0.7410

(訳注: 実験結果)

Epoch 1/50
313/313 [==============================] - 89s 235ms/step - loss: 4.2106 - accuracy: 0.0749 - top-5-accuracy: 0.2393 - val_loss: 3.9651 - val_accuracy: 0.1128 - val_top-5-accuracy: 0.3299
Epoch 2/50
313/313 [==============================] - 71s 228ms/step - loss: 3.8254 - accuracy: 0.1354 - top-5-accuracy: 0.3769 - val_loss: 3.7533 - val_accuracy: 0.1687 - val_top-5-accuracy: 0.4227
Epoch 3/50
313/313 [==============================] - 71s 228ms/step - loss: 3.5754 - accuracy: 0.1926 - top-5-accuracy: 0.4712 - val_loss: 3.5608 - val_accuracy: 0.2143 - val_top-5-accuracy: 0.4880
Epoch 4/50
313/313 [==============================] - 71s 228ms/step - loss: 3.4059 - accuracy: 0.2330 - top-5-accuracy: 0.5311 - val_loss: 3.4499 - val_accuracy: 0.2435 - val_top-5-accuracy: 0.5373
Epoch 5/50
313/313 [==============================] - 71s 228ms/step - loss: 3.2898 - accuracy: 0.2609 - top-5-accuracy: 0.5681 - val_loss: 3.4805 - val_accuracy: 0.2453 - val_top-5-accuracy: 0.5369
Epoch 6/50
313/313 [==============================] - 72s 229ms/step - loss: 3.2073 - accuracy: 0.2839 - top-5-accuracy: 0.5903 - val_loss: 3.2480 - val_accuracy: 0.2875 - val_top-5-accuracy: 0.5980
Epoch 7/50
313/313 [==============================] - 72s 229ms/step - loss: 3.1330 - accuracy: 0.3011 - top-5-accuracy: 0.6156 - val_loss: 3.2551 - val_accuracy: 0.2961 - val_top-5-accuracy: 0.6029
Epoch 8/50
313/313 [==============================] - 71s 228ms/step - loss: 3.0821 - accuracy: 0.3138 - top-5-accuracy: 0.6290 - val_loss: 3.1370 - val_accuracy: 0.3131 - val_top-5-accuracy: 0.6247
Epoch 9/50
313/313 [==============================] - 71s 228ms/step - loss: 3.0294 - accuracy: 0.3293 - top-5-accuracy: 0.6432 - val_loss: 3.1374 - val_accuracy: 0.3168 - val_top-5-accuracy: 0.6274
Epoch 10/50
313/313 [==============================] - 71s 228ms/step - loss: 2.9877 - accuracy: 0.3368 - top-5-accuracy: 0.6556 - val_loss: 3.0930 - val_accuracy: 0.3265 - val_top-5-accuracy: 0.6378
Epoch 11/50
313/313 [==============================] - 71s 228ms/step - loss: 2.9496 - accuracy: 0.3476 - top-5-accuracy: 0.6667 - val_loss: 3.1077 - val_accuracy: 0.3248 - val_top-5-accuracy: 0.6457
Epoch 12/50
313/313 [==============================] - 71s 228ms/step - loss: 2.9140 - accuracy: 0.3572 - top-5-accuracy: 0.6773 - val_loss: 3.1588 - val_accuracy: 0.3226 - val_top-5-accuracy: 0.6365
Epoch 13/50
313/313 [==============================] - 71s 228ms/step - loss: 2.8865 - accuracy: 0.3657 - top-5-accuracy: 0.6843 - val_loss: 3.0184 - val_accuracy: 0.3506 - val_top-5-accuracy: 0.6682
Epoch 14/50
313/313 [==============================] - 71s 228ms/step - loss: 2.8529 - accuracy: 0.3716 - top-5-accuracy: 0.6951 - val_loss: 3.0481 - val_accuracy: 0.3482 - val_top-5-accuracy: 0.6647
Epoch 15/50
313/313 [==============================] - 71s 228ms/step - loss: 2.8306 - accuracy: 0.3811 - top-5-accuracy: 0.6994 - val_loss: 2.9535 - val_accuracy: 0.3622 - val_top-5-accuracy: 0.6831
Epoch 16/50
313/313 [==============================] - 71s 228ms/step - loss: 2.8069 - accuracy: 0.3829 - top-5-accuracy: 0.7057 - val_loss: 2.9654 - val_accuracy: 0.3645 - val_top-5-accuracy: 0.6820
Epoch 17/50
313/313 [==============================] - 71s 228ms/step - loss: 2.7879 - accuracy: 0.3925 - top-5-accuracy: 0.7102 - val_loss: 2.9547 - val_accuracy: 0.3553 - val_top-5-accuracy: 0.6861
Epoch 18/50
313/313 [==============================] - 71s 228ms/step - loss: 2.7626 - accuracy: 0.3987 - top-5-accuracy: 0.7151 - val_loss: 2.9865 - val_accuracy: 0.3645 - val_top-5-accuracy: 0.6881
Epoch 19/50
313/313 [==============================] - 71s 228ms/step - loss: 2.7427 - accuracy: 0.4040 - top-5-accuracy: 0.7229 - val_loss: 2.9270 - val_accuracy: 0.3765 - val_top-5-accuracy: 0.6984
Epoch 20/50
313/313 [==============================] - 72s 229ms/step - loss: 2.7197 - accuracy: 0.4107 - top-5-accuracy: 0.7306 - val_loss: 2.9114 - val_accuracy: 0.3818 - val_top-5-accuracy: 0.6994
Epoch 21/50
313/313 [==============================] - 72s 229ms/step - loss: 2.7174 - accuracy: 0.4099 - top-5-accuracy: 0.7289 - val_loss: 2.9156 - val_accuracy: 0.3873 - val_top-5-accuracy: 0.7029
Epoch 22/50
313/313 [==============================] - 71s 228ms/step - loss: 2.6940 - accuracy: 0.4160 - top-5-accuracy: 0.7359 - val_loss: 2.9074 - val_accuracy: 0.3792 - val_top-5-accuracy: 0.6990
Epoch 23/50
313/313 [==============================] - 71s 228ms/step - loss: 2.6783 - accuracy: 0.4202 - top-5-accuracy: 0.7384 - val_loss: 2.9476 - val_accuracy: 0.3776 - val_top-5-accuracy: 0.6942
Epoch 24/50
313/313 [==============================] - 71s 227ms/step - loss: 2.6660 - accuracy: 0.4224 - top-5-accuracy: 0.7427 - val_loss: 2.8997 - val_accuracy: 0.3904 - val_top-5-accuracy: 0.7059
Epoch 25/50
313/313 [==============================] - 71s 228ms/step - loss: 2.6528 - accuracy: 0.4274 - top-5-accuracy: 0.7451 - val_loss: 2.8776 - val_accuracy: 0.3933 - val_top-5-accuracy: 0.7135
Epoch 26/50
313/313 [==============================] - 71s 228ms/step - loss: 2.6355 - accuracy: 0.4307 - top-5-accuracy: 0.7491 - val_loss: 2.8783 - val_accuracy: 0.3886 - val_top-5-accuracy: 0.7135
Epoch 27/50
313/313 [==============================] - 72s 229ms/step - loss: 2.6220 - accuracy: 0.4385 - top-5-accuracy: 0.7518 - val_loss: 2.8804 - val_accuracy: 0.3901 - val_top-5-accuracy: 0.7170
Epoch 28/50
313/313 [==============================] - 71s 227ms/step - loss: 2.6163 - accuracy: 0.4362 - top-5-accuracy: 0.7545 - val_loss: 2.9178 - val_accuracy: 0.3892 - val_top-5-accuracy: 0.7044
Epoch 29/50
313/313 [==============================] - 71s 228ms/step - loss: 2.6035 - accuracy: 0.4400 - top-5-accuracy: 0.7572 - val_loss: 2.9148 - val_accuracy: 0.3905 - val_top-5-accuracy: 0.7070
Epoch 30/50
313/313 [==============================] - 71s 228ms/step - loss: 2.5975 - accuracy: 0.4419 - top-5-accuracy: 0.7587 - val_loss: 2.8426 - val_accuracy: 0.4059 - val_top-5-accuracy: 0.7250
Epoch 31/50
313/313 [==============================] - 71s 228ms/step - loss: 2.5888 - accuracy: 0.4487 - top-5-accuracy: 0.7607 - val_loss: 2.8511 - val_accuracy: 0.4017 - val_top-5-accuracy: 0.7287
Epoch 32/50
313/313 [==============================] - 71s 228ms/step - loss: 2.5834 - accuracy: 0.4472 - top-5-accuracy: 0.7633 - val_loss: 2.8667 - val_accuracy: 0.3997 - val_top-5-accuracy: 0.7200
Epoch 33/50
313/313 [==============================] - 72s 229ms/step - loss: 2.5682 - accuracy: 0.4519 - top-5-accuracy: 0.7671 - val_loss: 2.9856 - val_accuracy: 0.3849 - val_top-5-accuracy: 0.7031
Epoch 34/50
313/313 [==============================] - 71s 228ms/step - loss: 2.5501 - accuracy: 0.4581 - top-5-accuracy: 0.7703 - val_loss: 2.8811 - val_accuracy: 0.4032 - val_top-5-accuracy: 0.7156
Epoch 35/50
313/313 [==============================] - 71s 228ms/step - loss: 2.5456 - accuracy: 0.4590 - top-5-accuracy: 0.7724 - val_loss: 2.8354 - val_accuracy: 0.4104 - val_top-5-accuracy: 0.7311
Epoch 36/50
313/313 [==============================] - 72s 230ms/step - loss: 2.5384 - accuracy: 0.4598 - top-5-accuracy: 0.7746 - val_loss: 2.7819 - val_accuracy: 0.4180 - val_top-5-accuracy: 0.7388
Epoch 37/50
313/313 [==============================] - 72s 230ms/step - loss: 2.5309 - accuracy: 0.4621 - top-5-accuracy: 0.7758 - val_loss: 2.7605 - val_accuracy: 0.4188 - val_top-5-accuracy: 0.7420
Epoch 38/50
313/313 [==============================] - 71s 228ms/step - loss: 2.5148 - accuracy: 0.4652 - top-5-accuracy: 0.7797 - val_loss: 2.8214 - val_accuracy: 0.4118 - val_top-5-accuracy: 0.7260
Epoch 39/50
313/313 [==============================] - 71s 228ms/step - loss: 2.5127 - accuracy: 0.4687 - top-5-accuracy: 0.7790 - val_loss: 2.7876 - val_accuracy: 0.4229 - val_top-5-accuracy: 0.7320
Epoch 40/50
313/313 [==============================] - 71s 228ms/step - loss: 2.5021 - accuracy: 0.4698 - top-5-accuracy: 0.7801 - val_loss: 2.7975 - val_accuracy: 0.4226 - val_top-5-accuracy: 0.7404
Epoch 41/50
313/313 [==============================] - 71s 228ms/step - loss: 2.5058 - accuracy: 0.4669 - top-5-accuracy: 0.7820 - val_loss: 2.7739 - val_accuracy: 0.4266 - val_top-5-accuracy: 0.7391
Epoch 42/50
313/313 [==============================] - 72s 229ms/step - loss: 2.4906 - accuracy: 0.4733 - top-5-accuracy: 0.7839 - val_loss: 2.8479 - val_accuracy: 0.4131 - val_top-5-accuracy: 0.7234
Epoch 43/50
313/313 [==============================] - 71s 228ms/step - loss: 2.4866 - accuracy: 0.4754 - top-5-accuracy: 0.7871 - val_loss: 2.8122 - val_accuracy: 0.4202 - val_top-5-accuracy: 0.7354
Epoch 44/50
313/313 [==============================] - 71s 228ms/step - loss: 2.4822 - accuracy: 0.4754 - top-5-accuracy: 0.7877 - val_loss: 2.8189 - val_accuracy: 0.4184 - val_top-5-accuracy: 0.7384
Epoch 45/50
313/313 [==============================] - 71s 228ms/step - loss: 2.4738 - accuracy: 0.4770 - top-5-accuracy: 0.7893 - val_loss: 2.9186 - val_accuracy: 0.3983 - val_top-5-accuracy: 0.7077
Epoch 46/50
313/313 [==============================] - 72s 229ms/step - loss: 2.4703 - accuracy: 0.4765 - top-5-accuracy: 0.7897 - val_loss: 2.7892 - val_accuracy: 0.4258 - val_top-5-accuracy: 0.7384
Epoch 47/50
313/313 [==============================] - 71s 228ms/step - loss: 2.4543 - accuracy: 0.4840 - top-5-accuracy: 0.7947 - val_loss: 2.9225 - val_accuracy: 0.4030 - val_top-5-accuracy: 0.7153
Epoch 48/50
313/313 [==============================] - 71s 228ms/step - loss: 2.4579 - accuracy: 0.4818 - top-5-accuracy: 0.7916 - val_loss: 2.8560 - val_accuracy: 0.4144 - val_top-5-accuracy: 0.7309
Epoch 49/50
313/313 [==============================] - 71s 228ms/step - loss: 2.4485 - accuracy: 0.4852 - top-5-accuracy: 0.7945 - val_loss: 2.8210 - val_accuracy: 0.4231 - val_top-5-accuracy: 0.7341
Epoch 50/50
313/313 [==============================] - 71s 228ms/step - loss: 2.4425 - accuracy: 0.4889 - top-5-accuracy: 0.7948 - val_loss: 2.8463 - val_accuracy: 0.4224 - val_top-5-accuracy: 0.7264
CPU times: user 50min 54s, sys: 3min 40s, total: 54min 35s
Wall time: 1h 38s

 

モデルの訓練進捗を可視化しましょう

plt.plot(history.history["loss"], label="train_loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Train and Validation Losses Over Epochs", fontsize=14)
plt.legend()
plt.grid()
plt.show()

 

CIFAR-100 上のテストの最終的な結果を表示しましょう

loss, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
print(f"Test loss: {round(loss, 2)}")
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
313/313 [==============================] - 6s 21ms/step - loss: 2.7574 - accuracy: 0.4391 - top-5-accuracy: 0.7471
Test loss: 2.76
Test accuracy: 43.91%
Test top 5 accuracy: 74.71%
313/313 [==============================] - 13s 42ms/step - loss: 2.8036 - accuracy: 0.4333 - top-5-accuracy: 0.7403
Test loss: 2.8
Test accuracy: 43.33%
Test top 5 accuracy: 74.03%

EANet は Vit の自己注意を外部注意で置き換えるだけです。従来の Vit は 50 エポックの訓練後に ~73% テスト top-5 精度と ~41 top-1 精度を得ていますが、60 万パラメータを使用しています。同じ実験環境と同じハイパーパラメータのもとで、ちょうど訓練した EANet モデルは 30 万パラメータだけを持ち、それは ~73% テスト top-5 精度と ~43% top-1 精度に導きます。これは外部注意の有効性を十分に実演しています。私達は EANet の訓練プロセスを示しただけです、同じ実験条件のもとで Vit を訓練してテスト結果を観察することができます。

 

以上



クラスキャット

最近の投稿

  • Agno : コンセプト : エージェント – エージェントの実行
  • Agno : コンセプト : エージェント – 概要
  • Agno : イントロダクション : Playground / モニタリング & デバッグ
  • Agno : イントロダクション : マルチエージェント・システム
  • Agno : イントロダクション : エージェントとは ? / Colab 実行例

タグ

AutoGen (13) ClassCat Press Release (20) ClassCat TF/ONNX Hub (11) DGL 0.5 (14) Eager Execution (7) Edward (17) FLUX.1 (16) Gemini (20) HuggingFace Transformers 4.5 (10) HuggingFace Transformers 4.6 (7) HuggingFace Transformers 4.29 (9) Keras 2 Examples (98) Keras 2 Guide (16) Keras 3 (10) Keras Release Note (17) Kubeflow 1.0 (10) LangChain (45) LangGraph (24) LangGraph 0.5 (9) MediaPipe 0.8 (11) Model Context Protocol (16) NNI 1.5 (16) OpenAI Agents SDK (8) OpenAI Cookbook (13) OpenAI platform (10) OpenAI platform 1.x (10) OpenAI ヘルプ (8) TensorFlow 2.0 Advanced Tutorials (33) TensorFlow 2.0 Advanced Tutorials (Alpha) (15) TensorFlow 2.0 Advanced Tutorials (Beta) (16) TensorFlow 2.0 Guide (10) TensorFlow 2.0 Guide (Alpha) (16) TensorFlow 2.0 Guide (Beta) (9) TensorFlow 2.0 Release Note (12) TensorFlow 2.0 Tutorials (20) TensorFlow 2.0 Tutorials (Alpha) (14) TensorFlow 2.0 Tutorials (Beta) (12) TensorFlow 2.4 Guide (24) TensorFlow Deploy (8) TensorFlow Get Started (7) TensorFlow Probability (9) TensorFlow Programmer's Guide (22) TensorFlow Release Note (18) TensorFlow Tutorials (33) TF-Agents 0.4 (11)
2021年11月
月 火 水 木 金 土 日
1234567
891011121314
15161718192021
22232425262728
2930  
« 10月   12月 »
© 2025 ClasCat® AI Research | Powered by Minimalist Blog WordPress Theme