Keras 2 : examples : Vision Transformer による画像分類 (翻訳/解説)
翻訳 : クラスキャット セールスインフォメーション
作成日時 : 12/10/2023
* 本ページは、以下のドキュメントを翻訳した上で適宜、補足説明したものです:
- Code examples : Computer Vision : Image classification with Vision Transformer (Author: Khalid Salama ; 2021/01/18)
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Website: www.classcat.com ; ClassCatJP
Keras 3 : examples : Vision Transformer による画像分類
説明: 画像分類のための Vision Transformer (ViT) モデルの実装。
イントロダクション
このサンプルは画像分類のための Alexey Dosovitskiy et al. による Vision Transformer (ViT) モデルを実装し、そしてそれを CIFAR-100 データセットで実演します。ViT モデルは、畳込み層を使用することなく、画像パッチのシークエンスに自己アテンションを持つ Transformer アーキテクチャを適用します。
セットアップ
import os
os.environ["KERAS_BACKEND"] = "jax" # @param ["tensorflow", "jax", "torch"]
import keras
from keras import layers
from keras import ops
import numpy as np
import matplotlib.pyplot as plt
データの準備
num_classes = 100
input_shape = (32, 32, 3)
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1) x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)
ハイパーパラメータの設定
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 10 # For real training, use num_epochs=100. 10 is a test value
image_size = 72 # We'll resize input images to this size
patch_size = 6 # Size of the patches to be extract from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [
projection_dim * 2,
projection_dim,
] # Size of the transformer layers
transformer_layers = 8
mlp_head_units = [
2048,
1024,
] # Size of the dense layers of the final classifier
データ増強の使用
data_augmentation = keras.Sequential(
[
layers.Normalization(),
layers.Resizing(image_size, image_size),
layers.RandomFlip("horizontal"),
layers.RandomRotation(factor=0.02),
layers.RandomZoom(height_factor=0.2, width_factor=0.2),
],
name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)
多層パーセプトロン (MLP) の実装
def mlp(x, hidden_units, dropout_rate):
for units in hidden_units:
x = layers.Dense(units, activation=keras.activations.gelu)(x)
x = layers.Dropout(dropout_rate)(x)
return x
パッチ作成を層として実装する
class Patches(layers.Layer):
def __init__(self, patch_size):
super().__init__()
self.patch_size = patch_size
def call(self, images):
input_shape = ops.shape(images)
batch_size = input_shape[0]
height = input_shape[1]
width = input_shape[2]
channels = input_shape[3]
num_patches_h = height // self.patch_size
num_patches_w = width // self.patch_size
patches = keras.ops.image.extract_patches(images, size=self.patch_size)
patches = ops.reshape(
patches,
(
batch_size,
num_patches_h * num_patches_w,
self.patch_size * self.patch_size * channels,
),
)
return patches
def get_config(self):
config = super().get_config()
config.update({"patch_size": self.patch_size})
return config
サンプル画像のためにパッチを表示しましょう
plt.figure(figsize=(4, 4))
image = x_train[np.random.choice(range(x_train.shape[0]))]
plt.imshow(image.astype("uint8"))
plt.axis("off")
resized_image = ops.image.resize(
ops.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)
print(f"Image size: {image_size} X {image_size}")
print(f"Patch size: {patch_size} X {patch_size}")
print(f"Patches per image: {patches.shape[1]}")
print(f"Elements per patch: {patches.shape[-1]}")
n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
ax = plt.subplot(n, n, i + 1)
patch_img = ops.reshape(patch, (patch_size, patch_size, 3))
plt.imshow(ops.convert_to_numpy(patch_img).astype("uint8"))
plt.axis("off")
Image size: 72 X 72 Patch size: 6 X 6 Patches per image: 144 Elements per patch: 108
パッチエンコーディング層の実装
PatchEncoder 層はパッチをサイズ projection_dim のベクトルに射影することで線形に変換します。更に、射影されたベクトルに学習可能な位置埋め込みを追加します。
class PatchEncoder(layers.Layer):
def __init__(self, num_patches, projection_dim):
super().__init__()
self.num_patches = num_patches
self.projection = layers.Dense(units=projection_dim)
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
def call(self, patch):
positions = ops.expand_dims(
ops.arange(start=0, stop=self.num_patches, step=1), axis=0
)
projected_patches = self.projection(patch)
encoded = projected_patches + self.position_embedding(positions)
return encoded
def get_config(self):
config = super().get_config()
config.update({"num_patches": self.num_patches})
return config
ViT モデルの構築
ViT モデルは複数の Transformer ブロックから成り、これはパッチのシークエンスに適用される自己アテンション・メカニズムとして layers.MultiHeadAttention 層を使用します。Transformer ブロックは [batch_size, num_patches, projection_dim] テンソルを生成します、これは最終的なクラス確率出力を生成するために softmax を持つ分類器ヘッドを通して処理されます。
画像表現として役立つように学習可能な埋め込みをエンコードされたパッチのシークエンスの先頭に追加する、論文 で記述されているテクニックとは違い、最後の Transformer ブロックの総ての出力は layers.Flatten() で reshape されて分類器ヘッドへの画像表現入力として使用されます。layers.GlobalAveragePooling1D 層はまた Transformer ブロックの出力を集約するために代わりに使用できることにも注意してください、特にパッチ数と射影次元が大きい場合です。
def create_vit_classifier():
inputs = keras.Input(shape=input_shape)
# Augment data.
augmented = data_augmentation(inputs)
# Create patches.
patches = Patches(patch_size)(augmented)
# Encode patches.
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
# Create multiple layers of the Transformer block.
for _ in range(transformer_layers):
# Layer normalization 1.
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
# Create a multi-head attention layer.
attention_output = layers.MultiHeadAttention(
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
)(x1, x1)
# Skip connection 1.
x2 = layers.Add()([attention_output, encoded_patches])
# Layer normalization 2.
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
# MLP.
x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
# Skip connection 2.
encoded_patches = layers.Add()([x3, x2])
# Create a [batch_size, projection_dim] tensor.
representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
representation = layers.Flatten()(representation)
representation = layers.Dropout(0.5)(representation)
# Add MLP.
features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
# Classify outputs.
logits = layers.Dense(num_classes)(features)
# Create the Keras model.
model = keras.Model(inputs=inputs, outputs=logits)
return model
モデルをコンパイル, 訓練, そして評価する
def run_experiment(model):
optimizer = keras.optimizers.AdamW(
learning_rate=learning_rate, weight_decay=weight_decay
)
model.compile(
optimizer=optimizer,
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[
keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
],
)
checkpoint_filepath = "/tmp/checkpoint.weights.h5"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
checkpoint_filepath,
monitor="val_accuracy",
save_best_only=True,
save_weights_only=True,
)
history = model.fit(
x=x_train,
y=y_train,
batch_size=batch_size,
epochs=num_epochs,
validation_split=0.1,
callbacks=[checkpoint_callback],
)
model.load_weights(checkpoint_filepath)
_, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
return history
vit_classifier = create_vit_classifier()
history = run_experiment(vit_classifier)
def plot_history(item):
plt.plot(history.history[item], label=item)
plt.plot(history.history["val_" + item], label="val_" + item)
plt.xlabel("Epochs")
plt.ylabel(item)
plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
plt.legend()
plt.grid()
plt.show()
plot_history("loss")
plot_history("top-5-accuracy")
Epoch 1/10 ... Epoch 10/10 176/176 ━━━━━━━━━━━━━━━━━━━━ 449s 3s/step - accuracy: 0.0790 - loss: 3.9468 - top-5-accuracy: 0.2711 - val_accuracy: 0.0986 - val_loss: 3.8537 - val_top-5-accuracy: 0.3052 313/313 ━━━━━━━━━━━━━━━━━━━━ 66s 198ms/step - accuracy: 0.1001 - loss: 3.8428 - top-5-accuracy: 0.3107 Test accuracy: 10.61% Test top 5 accuracy: 31.51%
100 エポック後、ViT モデルはテストデータ上でおよそ 55% 精度と 82% top-5 精度を獲得しました。これは CIFAR-100 データセット上で競争力のある結果ではありません、同じデータ上でスクラッチから訓練された ResNet50V2 は 67% 精度を獲得できるからです。
この 論文 で報告されている最先端の結果は JFT-300M データセットを使用して ViT モデルを事前訓練してからそれをターゲットデータセット上で再調整することで獲得されたことに注意してください。事前訓練なしにモデル品質を向上するために、より多いエポックの間モデルを訓練し、より大きい数の Transformer 層を使用し、入力画像をリサイズし、パッチサイズを変更し、あるいは射影次元を増やすことを試すことができます。また、論文で述べられているように、モデルの品質はアーキテクチャの選択だけではなく、学習率スケジュール, optimizer, 重み減衰 etc. のようなパラメータにも影響されます。実際には、大規模で高解像度のデータセットを使用して事前訓練された ViT モデルを再調整することが推奨されます。
以上