
Keras 2 : examples : Computer Vision – MobileViT : A mobile-friendly Transformer-based model for image classification

Posted on 12/01/2021 (updated 12/05/2021) by Sales Information

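Keras 2 : examples : A mobile-friendly Transformer-based model for image classification (translation / commentary)

Translation : ClassCat Co., Ltd. Sales Information
Created : 12/01/2021 (keras 2.7.0)

* This page is a translation, with supplementary explanations added where appropriate, of the following Keras document:

  • Code examples : Computer Vision : MobileViT: A mobile-friendly Transformer-based model for image classification (Author: Sayak Paul)

* The sample code has been verified to run, and has been extended or modified where necessary.
* You are welcome to link to this page; we would appreciate a note to sales-info@classcat.com if you do.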

 


 

 

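Keras 2 : examples : A mobile-friendly Transformer-based model for image classification

Description: MobileViT for image classification, combining the benefits of convolutions and Transformers.

 

Introduction

In this example, we implement the MobileViT architecture (Mehta et al.), which combines the benefits of Transformers (Vaswani et al.) and convolutions. With Transformers, we can capture long-range dependencies that result in global representations. With convolutions, we can capture spatial relationships that model locality.

Besides combining the properties of Transformers and convolutions, the authors introduce MobileViT as a general-purpose, mobile-friendly backbone for different image recognition tasks. Their findings suggest that, performance-wise, MobileViT is better than other models with the same or higher complexity (MobileNetV3, for example), while remaining efficient on mobile devices.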

 

Imports

import tensorflow as tf

from keras.applications import imagenet_utils
from tensorflow.keras import layers
from tensorflow import keras

import tensorflow_datasets as tfds
import tensorflow_addons as tfa

tfds.disable_progress_bar()

 

Hyperparameters

# Values are from table 4.
patch_size = 4  # 2x2, for the Transformer blocks.
image_size = 256
expansion_factor = 2  # expansion factor for the MobileNetV2 blocks.
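
To make these values concrete, the sketch below estimates the shape of the unfolded tensor that each MobileViT block hands to its Transformer layers. It assumes that the three MobileViT blocks run at 1/8, 1/16 and 1/32 of the input resolution; this is read off the model defined later in this example, not stated in Table 4. `unfolded_shapes` and `downsample_factors` are hypothetical names introduced only for illustration.

```python
# Hyperparameters copied from this example.
image_size = 256
patch_size = 4


def unfolded_shapes(image_size, patch_size, downsample_factors=(8, 16, 32)):
    """Return {factor: (patch_size, num_patches)} for each assumed stage.

    downsample_factors is an assumption read off the model built later:
    each MobileViT block sees a square feature map of side image_size // factor.
    """
    shapes = {}
    for factor in downsample_factors:
        side = image_size // factor
        # Same arithmetic as mobilevit_block: (h * w) / patch_size.
        num_patches = (side * side) // patch_size
        shapes[factor] = (patch_size, num_patches)
    return shapes


stage_shapes = unfolded_shapes(image_size, patch_size)
print(stage_shapes)  # {8: (4, 256), 16: (4, 64), 32: (4, 16)}
```

The first two entries correspond to the `(None, 4, 256, 64)` and `(None, 4, 64, 80)` reshape outputs reported by the model summary later in this example, with the channel axis left out.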

 

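MobileViT utilities

The MobileViT architecture is comprised of the following blocks:

  • Strided 3×3 convolutions that process the input image.
  • MobileNetV2-style inverted residual blocks for downsampling the resolution of the intermediate feature maps.
  • MobileViT blocks that combine the benefits of Transformers and convolutions. A figure in the original paper depicts this block.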

def conv_block(x, filters=16, kernel_size=3, strides=2):
    conv_layer = layers.Conv2D(
        filters, kernel_size, strides=strides, activation=tf.nn.swish, padding="same"
    )
    return conv_layer(x)


# Reference: https://git.io/JKgtC


def inverted_residual_block(x, expanded_channels, output_channels, strides=1):
    m = layers.Conv2D(expanded_channels, 1, padding="same", use_bias=False)(x)
    m = layers.BatchNormalization()(m)
    m = tf.nn.swish(m)

    if strides == 2:
        m = layers.ZeroPadding2D(padding=imagenet_utils.correct_pad(m, 3))(m)
    m = layers.DepthwiseConv2D(
        3, strides=strides, padding="same" if strides == 1 else "valid", use_bias=False
    )(m)
    m = layers.BatchNormalization()(m)
    m = tf.nn.swish(m)

    m = layers.Conv2D(output_channels, 1, padding="same", use_bias=False)(m)
    m = layers.BatchNormalization()(m)

    if tf.math.equal(x.shape[-1], output_channels) and strides == 1:
        return layers.Add()([m, x])
    return m


# Reference:
# https://keras.io/examples/vision/image_classification_with_vision_transformer/


def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.swish)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


def transformer_block(x, transformer_layers, projection_dim, num_heads=2):
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, x])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=[x.shape[-1] * 2, x.shape[-1]], dropout_rate=0.1)
        # Skip connection 2.
        x = layers.Add()([x3, x2])

    return x


def mobilevit_block(x, num_blocks, projection_dim, strides=1):
    # Local projection with convolutions.
    local_features = conv_block(x, filters=projection_dim, strides=strides)
    local_features = conv_block(
        local_features, filters=projection_dim, kernel_size=1, strides=strides
    )

    # Unfold into patches and then pass through Transformers.
    num_patches = int((local_features.shape[1] * local_features.shape[2]) / patch_size)
    non_overlapping_patches = layers.Reshape((patch_size, num_patches, projection_dim))(
        local_features
    )
    global_features = transformer_block(
        non_overlapping_patches, num_blocks, projection_dim
    )

    # Fold into conv-like feature-maps.
    folded_feature_map = layers.Reshape((*local_features.shape[1:-1], projection_dim))(
        global_features
    )

    # Apply point-wise conv -> concatenate with the input features.
    folded_feature_map = conv_block(
        folded_feature_map, filters=x.shape[-1], kernel_size=1, strides=strides
    )
    local_global_features = layers.Concatenate(axis=-1)([x, folded_feature_map])

    # Fuse the local and global features using a convolution layer.
    local_global_features = conv_block(
        local_global_features, filters=projection_dim, strides=strides
    )

    return local_global_features
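
As a rough sanity check on the helpers above, the following sketch counts trainable weights for a few of the layers they create. The formulas assume the usual Keras parameterization: one bias per output channel for `Conv2D` unless `use_bias=False`, and query, key, value and output projections with biases in `MultiHeadAttention`, with `value_dim` equal to `key_dim`. The function names are hypothetical and not part of Keras.

```python
def conv2d_params(kernel_size, in_channels, out_channels, use_bias=True):
    """Weights of a square-kernel Conv2D layer, plus one bias per output channel."""
    weights = kernel_size * kernel_size * in_channels * out_channels
    return weights + (out_channels if use_bias else 0)


def mha_params(input_dim, num_heads, key_dim):
    """Weights of a self-attention MultiHeadAttention layer (value_dim == key_dim assumed)."""
    inner_dim = num_heads * key_dim
    # Query, key and value projections: input_dim -> inner_dim, each with a bias.
    qkv = 3 * (input_dim * inner_dim + inner_dim)
    # Output projection: inner_dim -> input_dim, with a bias.
    output = inner_dim * input_dim + input_dim
    return qkv + output


stem = conv2d_params(3, 3, 16)                      # conv_block on the RGB input
expand = conv2d_params(1, 16, 32, use_bias=False)   # 1x1 expansion in the first MV2 block
attn_64 = mha_params(64, num_heads=2, key_dim=64)   # projection_dim = 64
attn_80 = mha_params(80, num_heads=2, key_dim=80)   # projection_dim = 80
print(stem, expand, attn_64, attn_80)  # 448 512 33216 51760
```

These figures agree with the `Param #` column that `mobilevit_xxs.summary()` reports for `conv2d`, `conv2d_1`, `multi_head_attention` and `multi_head_attention_2` later in this example.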

 
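More on the MobileViT block:

  • First, the feature representations (A) go through convolution blocks that capture local relationships. The expected shape of a single entry here is (h, w, num_channels).

  • Then they are unfolded into another vector with shape (p, n, num_channels), where p is the area of a small patch and n is (h * w) / p. So we end up with n non-overlapping patches.

  • This unfolded vector is then passed through a Transformer block that captures global relationships between the patches.

  • The output vector (B) is folded back into a vector of shape (h, w, num_channels), resembling a feature map coming out of convolutions.

Vectors A and B are then passed through two more convolutional layers to fuse the local and global representations. Notice that the spatial resolution of the final vector remains unchanged at this point. The authors also explain how the MobileViT block resembles a convolution block of a CNN. For more details, please refer to the original paper.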
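
The unfold and fold steps in `mobilevit_block` are plain `Reshape` layers, so they only reinterpret the row-major layout of the feature map. The pure-Python sketch below, using hypothetical helper names, maps a spatial position (i, j) to its position in the (p, n) layout and back, assuming the row-major ordering that a reshape uses. It illustrates that folding exactly inverts unfolding and that, at least under this particular reshape, each of the p slices is a contiguous run of the flattened map rather than a 2×2 spatial neighborhood.

```python
def unfold_index(i, j, w, num_patches):
    """Map spatial position (i, j) to (slice, token) under a row-major reshape."""
    flat = i * w + j
    return divmod(flat, num_patches)


def fold_index(a, b, w, num_patches):
    """Inverse of unfold_index: map (slice, token) back to spatial (i, j)."""
    flat = a * num_patches + b
    return divmod(flat, w)


# Feature map size seen by the first MobileViT block in this example.
h, w, patch_size = 32, 32, 4
num_patches = (h * w) // patch_size  # 256

# Folding after unfolding returns every position to where it started.
round_trip_ok = all(
    fold_index(*unfold_index(i, j, w, num_patches), w, num_patches) == (i, j)
    for i in range(h)
    for j in range(w)
)
print(round_trip_ok)                         # True
print(unfold_index(0, 0, w, num_patches))    # (0, 0)
print(unfold_index(8, 0, w, num_patches))    # (1, 0): row 8 starts the second slice
print(unfold_index(31, 31, w, num_patches))  # (3, 255)
```

With a 32×32 map and `patch_size = 4`, each slice therefore covers eight consecutive rows of the feature map.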

Next, we combine these blocks and implement the MobileViT architecture (XXS variant). A figure in the original paper presents a schematic overview of the architecture.

def create_mobilevit(num_classes=5):
    inputs = keras.Input((image_size, image_size, 3))
    x = layers.Rescaling(scale=1.0 / 255)(inputs)

    # Initial conv-stem -> MV2 block.
    x = conv_block(x, filters=16)
    x = inverted_residual_block(
        x, expanded_channels=16 * expansion_factor, output_channels=16
    )

    # Downsampling with MV2 block.
    x = inverted_residual_block(
        x, expanded_channels=16 * expansion_factor, output_channels=24, strides=2
    )
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=24
    )
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=24
    )

    # First MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=48, strides=2
    )
    x = mobilevit_block(x, num_blocks=2, projection_dim=64)

    # Second MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=64 * expansion_factor, output_channels=64, strides=2
    )
    x = mobilevit_block(x, num_blocks=4, projection_dim=80)

    # Third MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=80 * expansion_factor, output_channels=80, strides=2
    )
    x = mobilevit_block(x, num_blocks=3, projection_dim=96)
    x = conv_block(x, filters=320, kernel_size=1, strides=1)

    # Classification head.
    x = layers.GlobalAvgPool2D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    return keras.Model(inputs, outputs)


mobilevit_xxs = create_mobilevit()
mobilevit_xxs.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
rescaling (Rescaling)           (None, 256, 256, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 128, 128, 16) 448         rescaling[0][0]                  
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 128, 128, 32) 512         conv2d[0][0]                     
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 128, 128, 32) 128         conv2d_1[0][0]                   
__________________________________________________________________________________________________
tf.nn.silu (TFOpLambda)         (None, 128, 128, 32) 0           batch_normalization[0][0]        
__________________________________________________________________________________________________
depthwise_conv2d (DepthwiseConv (None, 128, 128, 32) 288         tf.nn.silu[0][0]                 
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 128, 128, 32) 128         depthwise_conv2d[0][0]           
__________________________________________________________________________________________________
tf.nn.silu_1 (TFOpLambda)       (None, 128, 128, 32) 0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 128, 128, 16) 512         tf.nn.silu_1[0][0]               
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 128, 128, 16) 64          conv2d_2[0][0]                   
__________________________________________________________________________________________________
add (Add)                       (None, 128, 128, 16) 0           batch_normalization_2[0][0]      
                                                                 conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 128, 128, 32) 512         add[0][0]                        
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 128, 128, 32) 128         conv2d_3[0][0]                   
__________________________________________________________________________________________________
tf.nn.silu_2 (TFOpLambda)       (None, 128, 128, 32) 0           batch_normalization_3[0][0]      
__________________________________________________________________________________________________
zero_padding2d (ZeroPadding2D)  (None, 129, 129, 32) 0           tf.nn.silu_2[0][0]               
__________________________________________________________________________________________________
depthwise_conv2d_1 (DepthwiseCo (None, 64, 64, 32)   288         zero_padding2d[0][0]             
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 64, 64, 32)   128         depthwise_conv2d_1[0][0]         
__________________________________________________________________________________________________
tf.nn.silu_3 (TFOpLambda)       (None, 64, 64, 32)   0           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 64, 64, 24)   768         tf.nn.silu_3[0][0]               
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 64, 64, 24)   96          conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 64, 64, 48)   1152        batch_normalization_5[0][0]      
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 64, 64, 48)   192         conv2d_5[0][0]                   
__________________________________________________________________________________________________
tf.nn.silu_4 (TFOpLambda)       (None, 64, 64, 48)   0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
depthwise_conv2d_2 (DepthwiseCo (None, 64, 64, 48)   432         tf.nn.silu_4[0][0]               
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 64, 64, 48)   192         depthwise_conv2d_2[0][0]         
__________________________________________________________________________________________________
tf.nn.silu_5 (TFOpLambda)       (None, 64, 64, 48)   0           batch_normalization_7[0][0]      
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 64, 64, 24)   1152        tf.nn.silu_5[0][0]               
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 64, 64, 24)   96          conv2d_6[0][0]                   
__________________________________________________________________________________________________
add_1 (Add)                     (None, 64, 64, 24)   0           batch_normalization_8[0][0]      
                                                                 batch_normalization_5[0][0]      
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 64, 64, 48)   1152        add_1[0][0]                      
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 64, 64, 48)   192         conv2d_7[0][0]                   
__________________________________________________________________________________________________
tf.nn.silu_6 (TFOpLambda)       (None, 64, 64, 48)   0           batch_normalization_9[0][0]      
__________________________________________________________________________________________________
depthwise_conv2d_3 (DepthwiseCo (None, 64, 64, 48)   432         tf.nn.silu_6[0][0]               
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 64, 64, 48)   192         depthwise_conv2d_3[0][0]         
__________________________________________________________________________________________________
tf.nn.silu_7 (TFOpLambda)       (None, 64, 64, 48)   0           batch_normalization_10[0][0]     
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 64, 64, 24)   1152        tf.nn.silu_7[0][0]               
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 64, 64, 24)   96          conv2d_8[0][0]                   
__________________________________________________________________________________________________
add_2 (Add)                     (None, 64, 64, 24)   0           batch_normalization_11[0][0]     
                                                                 add_1[0][0]                      
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 64, 64, 48)   1152        add_2[0][0]                      
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 64, 64, 48)   192         conv2d_9[0][0]                   
__________________________________________________________________________________________________
tf.nn.silu_8 (TFOpLambda)       (None, 64, 64, 48)   0           batch_normalization_12[0][0]     
__________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D (None, 65, 65, 48)   0           tf.nn.silu_8[0][0]               
__________________________________________________________________________________________________
depthwise_conv2d_4 (DepthwiseCo (None, 32, 32, 48)   432         zero_padding2d_1[0][0]           
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 32, 32, 48)   192         depthwise_conv2d_4[0][0]         
__________________________________________________________________________________________________
tf.nn.silu_9 (TFOpLambda)       (None, 32, 32, 48)   0           batch_normalization_13[0][0]     
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 32, 32, 48)   2304        tf.nn.silu_9[0][0]               
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 32, 32, 48)   192         conv2d_10[0][0]                  
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 32, 32, 64)   27712       batch_normalization_14[0][0]     
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 32, 32, 64)   4160        conv2d_11[0][0]                  
__________________________________________________________________________________________________
reshape (Reshape)               (None, 4, 256, 64)   0           conv2d_12[0][0]                  
__________________________________________________________________________________________________
layer_normalization (LayerNorma (None, 4, 256, 64)   128         reshape[0][0]                    
__________________________________________________________________________________________________
multi_head_attention (MultiHead (None, 4, 256, 64)   33216       layer_normalization[0][0]        
                                                                 layer_normalization[0][0]        
__________________________________________________________________________________________________
add_3 (Add)                     (None, 4, 256, 64)   0           multi_head_attention[0][0]       
                                                                 reshape[0][0]                    
__________________________________________________________________________________________________
layer_normalization_1 (LayerNor (None, 4, 256, 64)   128         add_3[0][0]                      
__________________________________________________________________________________________________
dense (Dense)                   (None, 4, 256, 128)  8320        layer_normalization_1[0][0]      
__________________________________________________________________________________________________
dropout (Dropout)               (None, 4, 256, 128)  0           dense[0][0]                      
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 4, 256, 64)   8256        dropout[0][0]                    
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 4, 256, 64)   0           dense_1[0][0]                    
__________________________________________________________________________________________________
add_4 (Add)                     (None, 4, 256, 64)   0           dropout_1[0][0]                  
                                                                 add_3[0][0]                      
__________________________________________________________________________________________________
layer_normalization_2 (LayerNor (None, 4, 256, 64)   128         add_4[0][0]                      
__________________________________________________________________________________________________
multi_head_attention_1 (MultiHe (None, 4, 256, 64)   33216       layer_normalization_2[0][0]      
                                                                 layer_normalization_2[0][0]      
__________________________________________________________________________________________________
add_5 (Add)                     (None, 4, 256, 64)   0           multi_head_attention_1[0][0]     
                                                                 add_4[0][0]                      
__________________________________________________________________________________________________
layer_normalization_3 (LayerNor (None, 4, 256, 64)   128         add_5[0][0]                      
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 4, 256, 128)  8320        layer_normalization_3[0][0]      
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 4, 256, 128)  0           dense_2[0][0]                    
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 4, 256, 64)   8256        dropout_2[0][0]                  
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, 4, 256, 64)   0           dense_3[0][0]                    
__________________________________________________________________________________________________
add_6 (Add)                     (None, 4, 256, 64)   0           dropout_3[0][0]                  
                                                                 add_5[0][0]                      
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 32, 32, 64)   0           add_6[0][0]                      
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 32, 32, 48)   3120        reshape_1[0][0]                  
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 32, 32, 96)   0           batch_normalization_14[0][0]     
                                                                 conv2d_13[0][0]                  
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 32, 32, 64)   55360       concatenate[0][0]                
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 32, 32, 128)  8192        conv2d_14[0][0]                  
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 32, 32, 128)  512         conv2d_15[0][0]                  
__________________________________________________________________________________________________
tf.nn.silu_10 (TFOpLambda)      (None, 32, 32, 128)  0           batch_normalization_15[0][0]     
__________________________________________________________________________________________________
zero_padding2d_2 (ZeroPadding2D (None, 33, 33, 128)  0           tf.nn.silu_10[0][0]              
__________________________________________________________________________________________________
depthwise_conv2d_5 (DepthwiseCo (None, 16, 16, 128)  1152        zero_padding2d_2[0][0]           
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 16, 16, 128)  512         depthwise_conv2d_5[0][0]         
__________________________________________________________________________________________________
tf.nn.silu_11 (TFOpLambda)      (None, 16, 16, 128)  0           batch_normalization_16[0][0]     
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 16, 16, 64)   8192        tf.nn.silu_11[0][0]              
__________________________________________________________________________________________________
batch_normalization_17 (BatchNo (None, 16, 16, 64)   256         conv2d_16[0][0]                  
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 16, 16, 80)   46160       batch_normalization_17[0][0]     
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 16, 16, 80)   6480        conv2d_17[0][0]                  
__________________________________________________________________________________________________
reshape_2 (Reshape)             (None, 4, 64, 80)    0           conv2d_18[0][0]                  
__________________________________________________________________________________________________
layer_normalization_4 (LayerNor (None, 4, 64, 80)    160         reshape_2[0][0]                  
__________________________________________________________________________________________________
multi_head_attention_2 (MultiHe (None, 4, 64, 80)    51760       layer_normalization_4[0][0]      
                                                                 layer_normalization_4[0][0]      
__________________________________________________________________________________________________
add_7 (Add)                     (None, 4, 64, 80)    0           multi_head_attention_2[0][0]     
                                                                 reshape_2[0][0]                  
__________________________________________________________________________________________________
layer_normalization_5 (LayerNor (None, 4, 64, 80)    160         add_7[0][0]                      
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 4, 64, 160)   12960       layer_normalization_5[0][0]      
__________________________________________________________________________________________________
dropout_4 (Dropout)             (None, 4, 64, 160)   0           dense_4[0][0]                    
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 4, 64, 80)    12880       dropout_4[0][0]                  
__________________________________________________________________________________________________
dropout_5 (Dropout)             (None, 4, 64, 80)    0           dense_5[0][0]                    
__________________________________________________________________________________________________
add_8 (Add)                     (None, 4, 64, 80)    0           dropout_5[0][0]                  
                                                                 add_7[0][0]                      
__________________________________________________________________________________________________
layer_normalization_6 (LayerNor (None, 4, 64, 80)    160         add_8[0][0]                      
__________________________________________________________________________________________________
multi_head_attention_3 (MultiHe (None, 4, 64, 80)    51760       layer_normalization_6[0][0]      
                                                                 layer_normalization_6[0][0]      
__________________________________________________________________________________________________
add_9 (Add)                     (None, 4, 64, 80)    0           multi_head_attention_3[0][0]     
                                                                 add_8[0][0]                      
__________________________________________________________________________________________________
layer_normalization_7 (LayerNor (None, 4, 64, 80)    160         add_9[0][0]                      
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 4, 64, 160)   12960       layer_normalization_7[0][0]      
__________________________________________________________________________________________________
dropout_6 (Dropout)             (None, 4, 64, 160)   0           dense_6[0][0]                    
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 4, 64, 80)    12880       dropout_6[0][0]                  
__________________________________________________________________________________________________
dropout_7 (Dropout)             (None, 4, 64, 80)    0           dense_7[0][0]                    
__________________________________________________________________________________________________
add_10 (Add)                    (None, 4, 64, 80)    0           dropout_7[0][0]                  
                                                                 add_9[0][0]                      
__________________________________________________________________________________________________
layer_normalization_8 (LayerNor (None, 4, 64, 80)    160         add_10[0][0]                     
__________________________________________________________________________________________________
multi_head_attention_4 (MultiHe (None, 4, 64, 80)    51760       layer_normalization_8[0][0]      
                                                                 layer_normalization_8[0][0]      
__________________________________________________________________________________________________
add_11 (Add)                    (None, 4, 64, 80)    0           multi_head_attention_4[0][0]     
                                                                 add_10[0][0]                     
__________________________________________________________________________________________________
layer_normalization_9 (LayerNor (None, 4, 64, 80)    160         add_11[0][0]                     
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 4, 64, 160)   12960       layer_normalization_9[0][0]      
__________________________________________________________________________________________________
dropout_8 (Dropout)             (None, 4, 64, 160)   0           dense_8[0][0]                    
__________________________________________________________________________________________________
dense_9 (Dense)                 (None, 4, 64, 80)    12880       dropout_8[0][0]                  
__________________________________________________________________________________________________
dropout_9 (Dropout)             (None, 4, 64, 80)    0           dense_9[0][0]                    
__________________________________________________________________________________________________
add_12 (Add)                    (None, 4, 64, 80)    0           dropout_9[0][0]                  
                                                                 add_11[0][0]                     
__________________________________________________________________________________________________
layer_normalization_10 (LayerNo (None, 4, 64, 80)    160         add_12[0][0]                     
__________________________________________________________________________________________________
multi_head_attention_5 (MultiHe (None, 4, 64, 80)    51760       layer_normalization_10[0][0]     
                                                                 layer_normalization_10[0][0]     
__________________________________________________________________________________________________
add_13 (Add)                    (None, 4, 64, 80)    0           multi_head_attention_5[0][0]     
                                                                 add_12[0][0]                     
__________________________________________________________________________________________________
layer_normalization_11 (LayerNo (None, 4, 64, 80)    160         add_13[0][0]                     
__________________________________________________________________________________________________
dense_10 (Dense)                (None, 4, 64, 160)   12960       layer_normalization_11[0][0]     
__________________________________________________________________________________________________
dropout_10 (Dropout)            (None, 4, 64, 160)   0           dense_10[0][0]                   
__________________________________________________________________________________________________
dense_11 (Dense)                (None, 4, 64, 80)    12880       dropout_10[0][0]                 
__________________________________________________________________________________________________
dropout_11 (Dropout)            (None, 4, 64, 80)    0           dense_11[0][0]                   
__________________________________________________________________________________________________
add_14 (Add)                    (None, 4, 64, 80)    0           dropout_11[0][0]                 
                                                                 add_13[0][0]                     
__________________________________________________________________________________________________
reshape_3 (Reshape)             (None, 16, 16, 80)   0           add_14[0][0]                     
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 16, 16, 64)   5184        reshape_3[0][0]                  
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 16, 16, 128)  0           batch_normalization_17[0][0]     
                                                                 conv2d_19[0][0]                  
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 16, 16, 80)   92240       concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 16, 16, 160)  12800       conv2d_20[0][0]                  
__________________________________________________________________________________________________
batch_normalization_18 (BatchNo (None, 16, 16, 160)  640         conv2d_21[0][0]                  
__________________________________________________________________________________________________
tf.nn.silu_12 (TFOpLambda)      (None, 16, 16, 160)  0           batch_normalization_18[0][0]     
__________________________________________________________________________________________________
zero_padding2d_3 (ZeroPadding2D (None, 17, 17, 160)  0           tf.nn.silu_12[0][0]              
__________________________________________________________________________________________________
depthwise_conv2d_6 (DepthwiseCo (None, 8, 8, 160)    1440        zero_padding2d_3[0][0]           
__________________________________________________________________________________________________
batch_normalization_19 (BatchNo (None, 8, 8, 160)    640         depthwise_conv2d_6[0][0]         
__________________________________________________________________________________________________
tf.nn.silu_13 (TFOpLambda)      (None, 8, 8, 160)    0           batch_normalization_19[0][0]     
__________________________________________________________________________________________________
conv2d_22 (Conv2D)              (None, 8, 8, 80)     12800       tf.nn.silu_13[0][0]              
__________________________________________________________________________________________________
batch_normalization_20 (BatchNo (None, 8, 8, 80)     320         conv2d_22[0][0]                  
__________________________________________________________________________________________________
conv2d_23 (Conv2D)              (None, 8, 8, 96)     69216       batch_normalization_20[0][0]     
__________________________________________________________________________________________________
conv2d_24 (Conv2D)              (None, 8, 8, 96)     9312        conv2d_23[0][0]                  
__________________________________________________________________________________________________
reshape_4 (Reshape)             (None, 4, 16, 96)    0           conv2d_24[0][0]                  
__________________________________________________________________________________________________
layer_normalization_12 (LayerNo (None, 4, 16, 96)    192         reshape_4[0][0]                  
__________________________________________________________________________________________________
multi_head_attention_6 (MultiHe (None, 4, 16, 96)    74400       layer_normalization_12[0][0]     
                                                                 layer_normalization_12[0][0]     
__________________________________________________________________________________________________
add_15 (Add)                    (None, 4, 16, 96)    0           multi_head_attention_6[0][0]     
                                                                 reshape_4[0][0]                  
__________________________________________________________________________________________________
layer_normalization_13 (LayerNo (None, 4, 16, 96)    192         add_15[0][0]                     
__________________________________________________________________________________________________
dense_12 (Dense)                (None, 4, 16, 192)   18624       layer_normalization_13[0][0]     
__________________________________________________________________________________________________
dropout_12 (Dropout)            (None, 4, 16, 192)   0           dense_12[0][0]                   
__________________________________________________________________________________________________
dense_13 (Dense)                (None, 4, 16, 96)    18528       dropout_12[0][0]                 
__________________________________________________________________________________________________
dropout_13 (Dropout)            (None, 4, 16, 96)    0           dense_13[0][0]                   
__________________________________________________________________________________________________
add_16 (Add)                    (None, 4, 16, 96)    0           dropout_13[0][0]                 
                                                                 add_15[0][0]                     
__________________________________________________________________________________________________
layer_normalization_14 (LayerNo (None, 4, 16, 96)    192         add_16[0][0]                     
__________________________________________________________________________________________________
multi_head_attention_7 (MultiHe (None, 4, 16, 96)    74400       layer_normalization_14[0][0]     
                                                                 layer_normalization_14[0][0]     
__________________________________________________________________________________________________
add_17 (Add)                    (None, 4, 16, 96)    0           multi_head_attention_7[0][0]     
                                                                 add_16[0][0]                     
__________________________________________________________________________________________________
layer_normalization_15 (LayerNo (None, 4, 16, 96)    192         add_17[0][0]                     
__________________________________________________________________________________________________
dense_14 (Dense)                (None, 4, 16, 192)   18624       layer_normalization_15[0][0]     
__________________________________________________________________________________________________
dropout_14 (Dropout)            (None, 4, 16, 192)   0           dense_14[0][0]                   
__________________________________________________________________________________________________
dense_15 (Dense)                (None, 4, 16, 96)    18528       dropout_14[0][0]                 
__________________________________________________________________________________________________
dropout_15 (Dropout)            (None, 4, 16, 96)    0           dense_15[0][0]                   
__________________________________________________________________________________________________
add_18 (Add)                    (None, 4, 16, 96)    0           dropout_15[0][0]                 
                                                                 add_17[0][0]                     
__________________________________________________________________________________________________
layer_normalization_16 (LayerNo (None, 4, 16, 96)    192         add_18[0][0]                     
__________________________________________________________________________________________________
multi_head_attention_8 (MultiHe (None, 4, 16, 96)    74400       layer_normalization_16[0][0]     
                                                                 layer_normalization_16[0][0]     
__________________________________________________________________________________________________
add_19 (Add)                    (None, 4, 16, 96)    0           multi_head_attention_8[0][0]     
                                                                 add_18[0][0]                     
__________________________________________________________________________________________________
layer_normalization_17 (LayerNo (None, 4, 16, 96)    192         add_19[0][0]                     
__________________________________________________________________________________________________
dense_16 (Dense)                (None, 4, 16, 192)   18624       layer_normalization_17[0][0]     
__________________________________________________________________________________________________
dropout_16 (Dropout)            (None, 4, 16, 192)   0           dense_16[0][0]                   
__________________________________________________________________________________________________
dense_17 (Dense)                (None, 4, 16, 96)    18528       dropout_16[0][0]                 
__________________________________________________________________________________________________
dropout_17 (Dropout)            (None, 4, 16, 96)    0           dense_17[0][0]                   
__________________________________________________________________________________________________
add_20 (Add)                    (None, 4, 16, 96)    0           dropout_17[0][0]                 
                                                                 add_19[0][0]                     
__________________________________________________________________________________________________
reshape_5 (Reshape)             (None, 8, 8, 96)     0           add_20[0][0]                     
__________________________________________________________________________________________________
conv2d_25 (Conv2D)              (None, 8, 8, 80)     7760        reshape_5[0][0]                  
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 8, 8, 160)    0           batch_normalization_20[0][0]     
                                                                 conv2d_25[0][0]                  
__________________________________________________________________________________________________
conv2d_26 (Conv2D)              (None, 8, 8, 96)     138336      concatenate_2[0][0]              
__________________________________________________________________________________________________
conv2d_27 (Conv2D)              (None, 8, 8, 320)    31040       conv2d_26[0][0]                  
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 320)          0           conv2d_27[0][0]                  
__________________________________________________________________________________________________
dense_18 (Dense)                (None, 5)            1605        global_average_pooling2d[0][0]   
==================================================================================================
Total params: 1,307,621
Trainable params: 1,305,077
Non-trainable params: 2,544
______________________________________________________________________________________________
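
The parameter counts in the summary can be checked by hand. The sketch below recomputes a few of them in plain Python. The kernel sizes, the missing bias in conv2d_21, and the attention settings (2 heads, key_dim equal to the feature dimension) are assumptions inferred from the counts themselves; they are not shown in the summary.

```python
def dense_params(in_dim, out_dim, use_bias=True):
    """Weights plus optional bias of a fully connected layer."""
    return in_dim * out_dim + (out_dim if use_bias else 0)


def conv2d_params(kernel, in_ch, out_ch, use_bias=True):
    """Parameters of a square-kernel Conv2D layer."""
    return kernel * kernel * in_ch * out_ch + (out_ch if use_bias else 0)


def layer_norm_params(dim):
    """One scale (gamma) and one offset (beta) per feature."""
    return 2 * dim


def mha_params(dim, num_heads, key_dim):
    """Query, key and value projections plus the output projection."""
    inner = num_heads * key_dim
    qkv = 3 * dense_params(dim, inner)
    out = dense_params(inner, dim)
    return qkv + out


# Assumed settings: num_heads=2, key_dim=dim, 3x3 kernel for conv2d_20,
# 1x1 kernel without bias for conv2d_21.
print(dense_params(80, 160))                      # dense_4: 12960
print(dense_params(160, 80))                      # dense_5: 12880
print(layer_norm_params(80))                      # layer_normalization_5: 160
print(mha_params(80, 2, 80))                      # multi_head_attention_3: 51760
print(mha_params(96, 2, 96))                      # multi_head_attention_6: 74400
print(conv2d_params(3, 128, 80))                  # conv2d_20: 92240
print(conv2d_params(1, 80, 160, use_bias=False))  # conv2d_21: 12800
print(dense_params(320, 5))                       # dense_18: 1605
```

That these assumed settings reproduce the listed numbers is a consistency check only; the layer definitions in the model code remain the authoritative source.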

 

Dataset preparation

We will use the tf_flowers dataset to demonstrate the model. Unlike other Transformer-based architectures, MobileViT uses a simple augmentation pipeline, primarily because it has the properties of a CNN.

batch_size = 64
auto = tf.data.AUTOTUNE
resize_bigger = 280
num_classes = 5


def preprocess_dataset(is_training=True):
    def _pp(image, label):
        if is_training:
            # Resize to a bigger spatial resolution and take the random
            # crops.
            image = tf.image.resize(image, (resize_bigger, resize_bigger))
            image = tf.image.random_crop(image, (image_size, image_size, 3))
            image = tf.image.random_flip_left_right(image)
        else:
            image = tf.image.resize(image, (image_size, image_size))
        label = tf.one_hot(label, depth=num_classes)
        return image, label

    return _pp


def prepare_dataset(dataset, is_training=True):
    if is_training:
        dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=auto)
    return dataset.batch(batch_size).prefetch(auto)
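
To make the shapes in the training branch concrete, here is a NumPy stand-in for the crop, flip, and one-hot steps. It is only an illustration, not the TensorFlow pipeline itself: the helper names are hypothetical, and image_size = 256 is assumed from the 256×256 resolution quoted in the results.

```python
import numpy as np

resize_bigger = 280
image_size = 256  # assumed; the real value is defined earlier in the example
num_classes = 5


def random_crop_and_flip(image, rng):
    """Mimic the training branch of _pp on an already-resized square image."""
    max_offset = image.shape[0] - image_size  # 280 - 256 = 24
    top = rng.integers(0, max_offset + 1)
    left = rng.integers(0, max_offset + 1)
    crop = image[top:top + image_size, left:left + image_size, :]
    if rng.random() < 0.5:
        crop = crop[:, ::-1, :]  # horizontal flip
    return crop


def one_hot(label, depth):
    """Return a float vector with a single 1.0 at position `label`."""
    vec = np.zeros(depth, dtype=np.float32)
    vec[label] = 1.0
    return vec


rng = np.random.default_rng(0)
dummy = rng.random((resize_bigger, resize_bigger, 3), dtype=np.float32)
crop = random_crop_and_flip(dummy, rng)
target = one_hot(3, num_classes)
print(crop.shape, target)
```

The crop window can start at any of 25 offsets along each axis, so resizing to 280 first gives the model slightly different views of each image across epochs.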

The authors use a multi-scale data sampler to help the model learn representations at varied scales. In this example, we leave this part out.
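
For readers curious what a multi-scale sampler might look like, the following is a rough sketch and not the authors' implementation. It assumes each batch draws one resolution at random and scales the batch size so that the number of pixels per batch stays roughly constant. The function name and the candidate resolutions are illustrative.

```python
import random


def multiscale_batches(num_examples, base_size=256, base_batch=64,
                       sizes=(160, 192, 256, 288, 320), seed=0):
    """Yield (resolution, batch_size) pairs until the examples are used up.

    Smaller resolutions get larger batches so that each batch holds roughly
    the same number of pixels as `base_batch` images of `base_size`.
    """
    rng = random.Random(seed)
    remaining = num_examples
    while remaining > 0:
        size = rng.choice(sizes)
        batch = max(1, (base_size * base_size * base_batch) // (size * size))
        batch = min(batch, remaining)  # the final batch may be smaller
        remaining -= batch
        yield size, batch


for size, batch in list(multiscale_batches(3303))[:5]:
    print(f"resolution {size}x{size}, batch size {batch}")
```

A real sampler would also have to resize the images of each batch to the chosen resolution, which is omitted here.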

 

Load and prepare the dataset

train_dataset, val_dataset = tfds.load(
    "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)

num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")

train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)
Number of training examples: 3303
Number of validation examples: 367
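
These counts, together with batch_size = 64, determine the number of steps per epoch, which is where the "52/52" and "6/6" progress counters in the training logs come from. A quick arithmetic check:

```python
import math

num_train = 3303
num_val = 367
batch_size = 64

# The last batch is kept even when it is not full, hence the ceiling.
train_steps = math.ceil(num_train / batch_size)
val_steps = math.ceil(num_val / batch_size)

# Size of the final, partially filled batch in each split.
last_train_batch = num_train - (train_steps - 1) * batch_size
last_val_batch = num_val - (val_steps - 1) * batch_size

print(train_steps, last_train_batch)  # 52 39
print(val_steps, last_val_batch)      # 6 47
```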

 

Train a MobileViT (XXS) model

learning_rate = 0.002
label_smoothing_factor = 0.1
epochs = 30

optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
loss_fn = keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing_factor)


def run_experiment(epochs=epochs):
    mobilevit_xxs = create_mobilevit(num_classes=num_classes)
    mobilevit_xxs.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

    checkpoint_filepath = "/tmp/checkpoint"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    mobilevit_xxs.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=[checkpoint_callback],
    )
    mobilevit_xxs.load_weights(checkpoint_filepath)
    _, accuracy = mobilevit_xxs.evaluate(val_dataset)
    print(f"Validation accuracy: {round(accuracy * 100, 2)}%")
    return mobilevit_xxs


mobilevit_xxs = run_experiment()
Epoch 1/30
52/52 [==============================] - 47s 459ms/step - loss: 1.3397 - accuracy: 0.4832 - val_loss: 1.7250 - val_accuracy: 0.1662
Epoch 2/30
52/52 [==============================] - 21s 404ms/step - loss: 1.1167 - accuracy: 0.6210 - val_loss: 1.9844 - val_accuracy: 0.1907
Epoch 3/30
52/52 [==============================] - 21s 403ms/step - loss: 1.0217 - accuracy: 0.6709 - val_loss: 1.8187 - val_accuracy: 0.1907
Epoch 4/30
52/52 [==============================] - 21s 409ms/step - loss: 0.9682 - accuracy: 0.7048 - val_loss: 2.0329 - val_accuracy: 0.1907
Epoch 5/30
52/52 [==============================] - 21s 408ms/step - loss: 0.9552 - accuracy: 0.7196 - val_loss: 2.1150 - val_accuracy: 0.1907
Epoch 6/30
52/52 [==============================] - 21s 407ms/step - loss: 0.9186 - accuracy: 0.7318 - val_loss: 2.9713 - val_accuracy: 0.1907
Epoch 7/30
52/52 [==============================] - 21s 407ms/step - loss: 0.8986 - accuracy: 0.7457 - val_loss: 3.2062 - val_accuracy: 0.1907
Epoch 8/30
52/52 [==============================] - 21s 408ms/step - loss: 0.8831 - accuracy: 0.7542 - val_loss: 3.8631 - val_accuracy: 0.1907
Epoch 9/30
52/52 [==============================] - 21s 408ms/step - loss: 0.8433 - accuracy: 0.7714 - val_loss: 1.8029 - val_accuracy: 0.3542
Epoch 10/30
52/52 [==============================] - 21s 408ms/step - loss: 0.8489 - accuracy: 0.7763 - val_loss: 1.7920 - val_accuracy: 0.4796
Epoch 11/30
52/52 [==============================] - 21s 409ms/step - loss: 0.8256 - accuracy: 0.7884 - val_loss: 1.4992 - val_accuracy: 0.5477
Epoch 12/30
52/52 [==============================] - 21s 407ms/step - loss: 0.7859 - accuracy: 0.8123 - val_loss: 0.9236 - val_accuracy: 0.7330
Epoch 13/30
52/52 [==============================] - 21s 409ms/step - loss: 0.7702 - accuracy: 0.8159 - val_loss: 0.8059 - val_accuracy: 0.8011
Epoch 14/30
52/52 [==============================] - 21s 403ms/step - loss: 0.7670 - accuracy: 0.8153 - val_loss: 1.1535 - val_accuracy: 0.7084
Epoch 15/30
52/52 [==============================] - 21s 408ms/step - loss: 0.7332 - accuracy: 0.8344 - val_loss: 0.7746 - val_accuracy: 0.8147
Epoch 16/30
52/52 [==============================] - 21s 404ms/step - loss: 0.7284 - accuracy: 0.8335 - val_loss: 1.0342 - val_accuracy: 0.7330
Epoch 17/30
52/52 [==============================] - 21s 409ms/step - loss: 0.7484 - accuracy: 0.8262 - val_loss: 1.0523 - val_accuracy: 0.7112
Epoch 18/30
52/52 [==============================] - 21s 408ms/step - loss: 0.7209 - accuracy: 0.8450 - val_loss: 0.8146 - val_accuracy: 0.8174
Epoch 19/30
52/52 [==============================] - 21s 409ms/step - loss: 0.7141 - accuracy: 0.8435 - val_loss: 0.8016 - val_accuracy: 0.7875
Epoch 20/30
52/52 [==============================] - 21s 410ms/step - loss: 0.7075 - accuracy: 0.8435 - val_loss: 0.9352 - val_accuracy: 0.7439
Epoch 21/30
52/52 [==============================] - 21s 406ms/step - loss: 0.7066 - accuracy: 0.8504 - val_loss: 1.0171 - val_accuracy: 0.7139
Epoch 22/30
52/52 [==============================] - 21s 405ms/step - loss: 0.6913 - accuracy: 0.8532 - val_loss: 0.7059 - val_accuracy: 0.8610
Epoch 23/30
52/52 [==============================] - 21s 408ms/step - loss: 0.6681 - accuracy: 0.8671 - val_loss: 0.8007 - val_accuracy: 0.8147
Epoch 24/30
52/52 [==============================] - 21s 409ms/step - loss: 0.6636 - accuracy: 0.8747 - val_loss: 0.9490 - val_accuracy: 0.7302
Epoch 25/30
52/52 [==============================] - 21s 408ms/step - loss: 0.6637 - accuracy: 0.8722 - val_loss: 0.6913 - val_accuracy: 0.8556
Epoch 26/30
52/52 [==============================] - 21s 406ms/step - loss: 0.6443 - accuracy: 0.8837 - val_loss: 1.0483 - val_accuracy: 0.7139
Epoch 27/30
52/52 [==============================] - 21s 407ms/step - loss: 0.6555 - accuracy: 0.8695 - val_loss: 0.9448 - val_accuracy: 0.7602
Epoch 28/30
52/52 [==============================] - 21s 409ms/step - loss: 0.6409 - accuracy: 0.8807 - val_loss: 0.9337 - val_accuracy: 0.7302
Epoch 29/30
52/52 [==============================] - 21s 408ms/step - loss: 0.6300 - accuracy: 0.8910 - val_loss: 0.7461 - val_accuracy: 0.8256
Epoch 30/30
52/52 [==============================] - 21s 408ms/step - loss: 0.6093 - accuracy: 0.8968 - val_loss: 0.8651 - val_accuracy: 0.7766
6/6 [==============================] - 0s 65ms/step - loss: 0.7059 - accuracy: 0.8610
Validation accuracy: 86.1%
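
The training loss stays well above zero even at high accuracy. One reason is label smoothing: with label_smoothing=0.1 and 5 classes, the target is no longer one-hot, so the cross-entropy has a nonzero floor. The NumPy sketch below illustrates the smoothing formula used by Keras, y * (1 - factor) + factor / num_classes. The helper names are hypothetical.

```python
import numpy as np

label_smoothing_factor = 0.1
num_classes = 5


def smooth_labels(one_hot, factor):
    """Blend a one-hot target with the uniform distribution."""
    return one_hot * (1.0 - factor) + factor / one_hot.shape[-1]


def cross_entropy(target, probs):
    """Categorical cross-entropy between a target and predicted probabilities."""
    return float(-np.sum(target * np.log(probs)))


one_hot = np.array([0.0, 0.0, 1.0, 0.0, 0.0])
smoothed = smooth_labels(one_hot, label_smoothing_factor)
print(smoothed)  # 0.92 for the true class, 0.02 for each other class

# The lowest achievable loss is reached when the prediction equals the
# smoothed target exactly; it is the entropy of that target.
floor = cross_entropy(smoothed, smoothed)
print(round(floor, 4))
```

The floor comes out at roughly 0.39, so a training loss around 0.6 is closer to its minimum than it first appears.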

(Translator's note: results from our own run)

Epoch 1/30
52/52 [==============================] - 46s 480ms/step - loss: 1.2939 - accuracy: 0.5098 - val_loss: 1.8419 - val_accuracy: 0.1907
Epoch 2/30
52/52 [==============================] - 22s 414ms/step - loss: 1.0890 - accuracy: 0.6322 - val_loss: 1.7239 - val_accuracy: 0.1907
Epoch 3/30
52/52 [==============================] - 22s 416ms/step - loss: 1.0121 - accuracy: 0.6791 - val_loss: 1.9058 - val_accuracy: 0.1907
Epoch 4/30
52/52 [==============================] - 22s 418ms/step - loss: 0.9721 - accuracy: 0.7021 - val_loss: 2.0354 - val_accuracy: 0.1907
Epoch 5/30
52/52 [==============================] - 22s 422ms/step - loss: 0.9360 - accuracy: 0.7342 - val_loss: 2.7073 - val_accuracy: 0.1907
Epoch 6/30
52/52 [==============================] - 22s 424ms/step - loss: 0.9060 - accuracy: 0.7433 - val_loss: 3.2771 - val_accuracy: 0.1907
Epoch 7/30
52/52 [==============================] - 22s 421ms/step - loss: 0.8940 - accuracy: 0.7469 - val_loss: 2.7802 - val_accuracy: 0.1907
Epoch 8/30
52/52 [==============================] - 22s 416ms/step - loss: 0.8648 - accuracy: 0.7602 - val_loss: 3.1462 - val_accuracy: 0.1907
Epoch 9/30
52/52 [==============================] - 22s 427ms/step - loss: 0.8474 - accuracy: 0.7666 - val_loss: 1.8916 - val_accuracy: 0.3815
Epoch 10/30
52/52 [==============================] - 22s 419ms/step - loss: 0.8161 - accuracy: 0.7938 - val_loss: 1.9075 - val_accuracy: 0.4005
Epoch 11/30
52/52 [==============================] - 22s 419ms/step - loss: 0.8048 - accuracy: 0.7959 - val_loss: 2.2018 - val_accuracy: 0.3515
Epoch 12/30
52/52 [==============================] - 22s 426ms/step - loss: 0.8074 - accuracy: 0.7908 - val_loss: 0.9702 - val_accuracy: 0.7248
Epoch 13/30
52/52 [==============================] - 22s 415ms/step - loss: 0.7950 - accuracy: 0.8071 - val_loss: 1.1303 - val_accuracy: 0.6730
Epoch 14/30
52/52 [==============================] - 22s 417ms/step - loss: 0.7678 - accuracy: 0.8147 - val_loss: 1.1896 - val_accuracy: 0.6349
Epoch 15/30
52/52 [==============================] - 22s 415ms/step - loss: 0.7501 - accuracy: 0.8253 - val_loss: 1.2600 - val_accuracy: 0.6676
Epoch 16/30
52/52 [==============================] - 22s 421ms/step - loss: 0.7655 - accuracy: 0.8256 - val_loss: 0.8396 - val_accuracy: 0.7738
Epoch 17/30
52/52 [==============================] - 22s 416ms/step - loss: 0.7248 - accuracy: 0.8414 - val_loss: 1.1027 - val_accuracy: 0.6649
Epoch 18/30
52/52 [==============================] - 22s 426ms/step - loss: 0.7272 - accuracy: 0.8365 - val_loss: 0.8282 - val_accuracy: 0.7929
Epoch 19/30
52/52 [==============================] - 22s 415ms/step - loss: 0.7322 - accuracy: 0.8341 - val_loss: 0.9421 - val_accuracy: 0.7439
Epoch 20/30
52/52 [==============================] - 22s 416ms/step - loss: 0.7054 - accuracy: 0.8507 - val_loss: 0.8501 - val_accuracy: 0.7902
Epoch 21/30
52/52 [==============================] - 22s 416ms/step - loss: 0.6965 - accuracy: 0.8520 - val_loss: 0.9330 - val_accuracy: 0.7384
Epoch 22/30
52/52 [==============================] - 22s 417ms/step - loss: 0.6903 - accuracy: 0.8607 - val_loss: 1.0051 - val_accuracy: 0.7384
Epoch 23/30
52/52 [==============================] - 22s 418ms/step - loss: 0.6924 - accuracy: 0.8544 - val_loss: 0.8628 - val_accuracy: 0.7711
Epoch 24/30
52/52 [==============================] - 22s 415ms/step - loss: 0.6659 - accuracy: 0.8722 - val_loss: 1.0199 - val_accuracy: 0.7357
Epoch 25/30
52/52 [==============================] - 22s 414ms/step - loss: 0.6664 - accuracy: 0.8662 - val_loss: 0.9770 - val_accuracy: 0.7248
Epoch 26/30
52/52 [==============================] - 22s 412ms/step - loss: 0.6540 - accuracy: 0.8756 - val_loss: 0.8168 - val_accuracy: 0.7820
Epoch 27/30
52/52 [==============================] - 22s 425ms/step - loss: 0.6643 - accuracy: 0.8713 - val_loss: 0.7484 - val_accuracy: 0.8420
Epoch 28/30
52/52 [==============================] - 22s 415ms/step - loss: 0.6364 - accuracy: 0.8850 - val_loss: 1.0473 - val_accuracy: 0.7384
Epoch 29/30
52/52 [==============================] - 22s 412ms/step - loss: 0.6495 - accuracy: 0.8747 - val_loss: 1.0393 - val_accuracy: 0.7302
Epoch 30/30
52/52 [==============================] - 22s 416ms/step - loss: 0.6397 - accuracy: 0.8801 - val_loss: 0.9585 - val_accuracy: 0.7493
6/6 [==============================] - 1s 76ms/step - loss: 0.7484 - accuracy: 0.8420
Validation accuracy: 84.2%
CPU times: user 12min 52s, sys: 22.1 s, total: 13min 15s
Wall time: 19min 21s
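
In both runs the reported validation accuracy is higher than the value logged for epoch 30. That is because ModelCheckpoint is configured with save_best_only=True and monitor="val_accuracy", and the best weights are reloaded before the final evaluation. The sketch below picks the best epoch from the val_accuracy values copied from the two logs above. The helper name is hypothetical.

```python
def best_epoch(val_accuracies):
    """Return (1-based epoch, value) of the highest validation accuracy.

    On ties the earliest epoch wins, since a checkpoint is only replaced
    when the monitored value strictly improves.
    """
    best_index = max(range(len(val_accuracies)), key=val_accuracies.__getitem__)
    return best_index + 1, val_accuracies[best_index]


original_run = [
    0.1662, 0.1907, 0.1907, 0.1907, 0.1907, 0.1907, 0.1907, 0.1907, 0.3542, 0.4796,
    0.5477, 0.7330, 0.8011, 0.7084, 0.8147, 0.7330, 0.7112, 0.8174, 0.7875, 0.7439,
    0.7139, 0.8610, 0.8147, 0.7302, 0.8556, 0.7139, 0.7602, 0.7302, 0.8256, 0.7766,
]
translator_run = [
    0.1907, 0.1907, 0.1907, 0.1907, 0.1907, 0.1907, 0.1907, 0.1907, 0.3815, 0.4005,
    0.3515, 0.7248, 0.6730, 0.6349, 0.6676, 0.7738, 0.6649, 0.7929, 0.7439, 0.7902,
    0.7384, 0.7384, 0.7711, 0.7357, 0.7248, 0.7820, 0.8420, 0.7384, 0.7302, 0.7493,
]

print(best_epoch(original_run))    # (22, 0.861)
print(best_epoch(translator_run))  # (27, 0.842)
```

The validation accuracy swings by several points from epoch to epoch, so the checkpointed best value is an optimistic estimate on a validation split of only 367 images.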

 

Results and TFLite conversion

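With about 1.3 million parameters (1,307,621 in the summary above), reaching ~85% top-1 accuracy at 256×256 resolution is a strong result. This MobileViT model is also fully compatible with TensorFlow Lite (TFLite) and can be converted with the following code: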

# Serialize the model as a SavedModel.
mobilevit_xxs.save("mobilevit_xxs")

# Convert to TFLite. This form of quantization is called
# post-training dynamic-range quantization in TFLite.
converter = tf.lite.TFLiteConverter.from_saved_model("mobilevit_xxs")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # Enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS,  # Enable TensorFlow ops.
]
tflite_model = converter.convert()
# Write the converted model to disk; the context manager closes the file.
with open("mobilevit_xxs.tflite", "wb") as f:
    f.write(tflite_model)
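
To give a feel for what weight quantization does, the following NumPy sketch quantizes a float32 weight matrix to int8 with a single symmetric scale and measures the round-trip error. It is a simplified illustration under stated assumptions, not TFLite's actual scheme, which may differ in details such as per-channel scales. The random matrix merely stands in for a kernel with the shape of dense_4.

```python
import numpy as np


def quantize_int8(weights):
    """Symmetric per-tensor quantization of float32 weights to int8."""
    scale = np.max(np.abs(weights)) / 127.0
    q = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q, scale):
    """Map int8 values back to approximate float32 weights."""
    return q.astype(np.float32) * scale


rng = np.random.default_rng(0)
# Hypothetical stand-in for an (80, 160) dense kernel.
weights = rng.normal(0.0, 0.05, size=(80, 160)).astype(np.float32)

q, scale = quantize_int8(weights)
restored = dequantize(q, scale)

max_error = float(np.max(np.abs(weights - restored)))
print("max abs error:", max_error)
print("float32 bytes:", weights.nbytes, "int8 bytes:", q.nbytes)
```

Storing weights as int8 takes a quarter of the float32 space, at the cost of a rounding error of at most about half the scale per weight.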

To learn more about the quantization recipes available in TFLite and about running inference with TFLite models, see the official TensorFlow Lite documentation.

 
