Keras 2 : examples : Timeseries – Timeseries classification with a Transformer model (translation/commentary)
Translation : ClassCat Co., Ltd., Sales Information
Created : 06/19/2022 (keras 2.9.0)
* This page is a translation of the following Keras documentation, with supplementary notes added where appropriate:
- Code examples : Timeseries : Timeseries classification with a Transformer model (Author: Theodoros Ntakouris)
* The sample code has been confirmed to run; it has been adjusted or extended where necessary.
Keras 2 : examples : Timeseries – Timeseries classification with a Transformer model
Description : This notebook demonstrates how to do timeseries classification using a Transformer model.
Introduction
This is the Transformer architecture from Attention Is All You Need, applied to timeseries instead of natural language.
This example requires TensorFlow 2.4 or higher.
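As a quick check before running the rest of the example, you can confirm the installed TensorFlow version (a minimal sketch added for this translation; not part of the original notebook):

import tensorflow as tf

# The example assumes TensorFlow 2.4+; print the installed version to confirm.
print(tf.__version__)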
Load the dataset
We are going to use the same dataset and preprocessing as the Timeseries classification from scratch example.
import numpy as np
def readucr(filename):
    data = np.loadtxt(filename, delimiter="\t")
    y = data[:, 0]
    x = data[:, 1:]
    return x, y.astype(int)
root_url = "https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA/"
x_train, y_train = readucr(root_url + "FordA_TRAIN.tsv")
x_test, y_test = readucr(root_url + "FordA_TEST.tsv")
x_train = x_train.reshape((x_train.shape[0], x_train.shape[1], 1))
x_test = x_test.reshape((x_test.shape[0], x_test.shape[1], 1))
n_classes = len(np.unique(y_train))
idx = np.random.permutation(len(x_train))
x_train = x_train[idx]
y_train = y_train[idx]
y_train[y_train == -1] = 0
y_test[y_test == -1] = 0
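If you want to verify the result of this preprocessing, a quick check of the array shapes and the remapped labels can be run right after the block above (a small sketch added for this translation; for FordA you should see roughly 3601 training and 1320 test series of length 500):

# Sanity check: the labels should now be in {0, 1} and each series should have shape (500, 1).
print(x_train.shape, x_test.shape)
print("n_classes:", n_classes, "labels:", np.unique(y_train))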
Build the model
Our model processes a tensor of shape (batch size, sequence length, features), where sequence length is the number of time steps and features is each input timeseries.
You can replace your classification RNN layers with this one: the inputs are fully compatible!
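To make that compatibility concrete, here is a minimal sketch (added for this translation, with hypothetical names) of an RNN-style classifier that consumes exactly the same (batch size, sequence length, features) input that the Transformer blocks below will consume:

from tensorflow import keras
from tensorflow.keras import layers

# A baseline LSTM classifier over the same (None, 500, 1) input; the Transformer
# encoder defined below is a drop-in replacement for the LSTM layer here.
rnn_inputs = keras.Input(shape=(500, 1))
h = layers.LSTM(64)(rnn_inputs)
rnn_outputs = layers.Dense(2, activation="softmax")(h)
rnn_baseline = keras.Model(rnn_inputs, rnn_outputs)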
from tensorflow import keras
from tensorflow.keras import layers
We include residual connections, layer normalization, and dropout. The resulting layer can be stacked multiple times.
The projection layers are implemented through keras.layers.Conv1D.
def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):
    # Attention and Normalization
    x = layers.MultiHeadAttention(
        key_dim=head_size, num_heads=num_heads, dropout=dropout
    )(inputs, inputs)
    x = layers.Dropout(dropout)(x)
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    res = x + inputs

    # Feed Forward Part
    x = layers.Conv1D(filters=ff_dim, kernel_size=1, activation="relu")(res)
    x = layers.Dropout(dropout)(x)
    x = layers.Conv1D(filters=inputs.shape[-1], kernel_size=1)(x)
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    return x + res
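Because the block ends with a residual addition back onto its input, its output shape must match its input shape, which is what allows it to be stacked; a quick symbolic check (an illustrative sketch added for this translation, not part of the original example) makes this visible:

# The encoder block preserves the input shape, so it can be stacked any number of times.
demo_in = keras.Input(shape=(500, 1))
demo_out = transformer_encoder(demo_in, head_size=256, num_heads=4, ff_dim=4)
print(demo_out.shape)  # (None, 500, 1) -- same as the input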
The main part of our model is now complete. We can stack multiple of those transformer_encoder blocks, and we can also proceed to add the final Multi-Layer Perceptron classification head. Apart from a stack of Dense layers, we need to reduce the output tensor of the TransformerEncoder part of our model down to a vector of features for each data point in the current batch. A common way to achieve this is to use a pooling layer. For this example, a GlobalAveragePooling1D layer is sufficient.
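One detail worth noting in the code below is the data_format="channels_first" argument passed to GlobalAveragePooling1D: with a (batch, 500, 1) tensor it averages over the single feature axis and keeps the 500 time steps as the per-sample feature vector, instead of collapsing the time dimension. A tiny sketch (added for this translation) shows the difference:

import numpy as np

demo = np.zeros((2, 500, 1), dtype="float32")
# channels_first interprets the input as (batch, features, steps) -> output shape (2, 500)
print(layers.GlobalAveragePooling1D(data_format="channels_first")(demo).shape)
# the default channels_last averages over the 500 time steps -> output shape (2, 1)
print(layers.GlobalAveragePooling1D()(demo).shape)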
def build_model(
    input_shape,
    head_size,
    num_heads,
    ff_dim,
    num_transformer_blocks,
    mlp_units,
    dropout=0,
    mlp_dropout=0,
):
    inputs = keras.Input(shape=input_shape)
    x = inputs
    for _ in range(num_transformer_blocks):
        x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout)

    x = layers.GlobalAveragePooling1D(data_format="channels_first")(x)
    for dim in mlp_units:
        x = layers.Dense(dim, activation="relu")(x)
        x = layers.Dropout(mlp_dropout)(x)
    outputs = layers.Dense(n_classes, activation="softmax")(x)
    return keras.Model(inputs, outputs)
Train and evaluate
input_shape = x_train.shape[1:]
model = build_model(
    input_shape,
    head_size=256,
    num_heads=4,
    ff_dim=4,
    num_transformer_blocks=4,
    mlp_units=[128],
    mlp_dropout=0.4,
    dropout=0.25,
)
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    metrics=["sparse_categorical_accuracy"],
)
model.summary()
callbacks = [keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)]
model.fit(
    x_train,
    y_train,
    validation_split=0.2,
    epochs=200,
    batch_size=64,
    callbacks=callbacks,
)
model.evaluate(x_test, y_test, verbose=1)
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape      Param #   Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 500, 1)]  0
layer_normalization (LayerNorma (None, 500, 1)    2         input_1[0][0]
multi_head_attention (MultiHead (None, 500, 1)    7169      layer_normalization[0][0], layer_normalization[0][0]
dropout (Dropout)               (None, 500, 1)    0         multi_head_attention[0][0]
tf.__operators__.add (TFOpLambd (None, 500, 1)    0         dropout[0][0], input_1[0][0]
layer_normalization_1 (LayerNor (None, 500, 1)    2         tf.__operators__.add[0][0]
conv1d (Conv1D)                 (None, 500, 4)    8         layer_normalization_1[0][0]
dropout_1 (Dropout)             (None, 500, 4)    0         conv1d[0][0]
conv1d_1 (Conv1D)               (None, 500, 1)    5         dropout_1[0][0]
tf.__operators__.add_1 (TFOpLam (None, 500, 1)    0         conv1d_1[0][0], tf.__operators__.add[0][0]
layer_normalization_2 (LayerNor (None, 500, 1)    2         tf.__operators__.add_1[0][0]
multi_head_attention_1 (MultiHe (None, 500, 1)    7169      layer_normalization_2[0][0], layer_normalization_2[0][0]
dropout_2 (Dropout)             (None, 500, 1)    0         multi_head_attention_1[0][0]
tf.__operators__.add_2 (TFOpLam (None, 500, 1)    0         dropout_2[0][0], tf.__operators__.add_1[0][0]
layer_normalization_3 (LayerNor (None, 500, 1)    2         tf.__operators__.add_2[0][0]
conv1d_2 (Conv1D)               (None, 500, 4)    8         layer_normalization_3[0][0]
dropout_3 (Dropout)             (None, 500, 4)    0         conv1d_2[0][0]
conv1d_3 (Conv1D)               (None, 500, 1)    5         dropout_3[0][0]
tf.__operators__.add_3 (TFOpLam (None, 500, 1)    0         conv1d_3[0][0], tf.__operators__.add_2[0][0]
layer_normalization_4 (LayerNor (None, 500, 1)    2         tf.__operators__.add_3[0][0]
multi_head_attention_2 (MultiHe (None, 500, 1)    7169      layer_normalization_4[0][0], layer_normalization_4[0][0]
dropout_4 (Dropout)             (None, 500, 1)    0         multi_head_attention_2[0][0]
tf.__operators__.add_4 (TFOpLam (None, 500, 1)    0         dropout_4[0][0], tf.__operators__.add_3[0][0]
layer_normalization_5 (LayerNor (None, 500, 1)    2         tf.__operators__.add_4[0][0]
conv1d_4 (Conv1D)               (None, 500, 4)    8         layer_normalization_5[0][0]
dropout_5 (Dropout)             (None, 500, 4)    0         conv1d_4[0][0]
conv1d_5 (Conv1D)               (None, 500, 1)    5         dropout_5[0][0]
tf.__operators__.add_5 (TFOpLam (None, 500, 1)    0         conv1d_5[0][0], tf.__operators__.add_4[0][0]
layer_normalization_6 (LayerNor (None, 500, 1)    2         tf.__operators__.add_5[0][0]
multi_head_attention_3 (MultiHe (None, 500, 1)    7169      layer_normalization_6[0][0], layer_normalization_6[0][0]
dropout_6 (Dropout)             (None, 500, 1)    0         multi_head_attention_3[0][0]
tf.__operators__.add_6 (TFOpLam (None, 500, 1)    0         dropout_6[0][0], tf.__operators__.add_5[0][0]
layer_normalization_7 (LayerNor (None, 500, 1)    2         tf.__operators__.add_6[0][0]
conv1d_6 (Conv1D)               (None, 500, 4)    8         layer_normalization_7[0][0]
dropout_7 (Dropout)             (None, 500, 4)    0         conv1d_6[0][0]
conv1d_7 (Conv1D)               (None, 500, 1)    5         dropout_7[0][0]
tf.__operators__.add_7 (TFOpLam (None, 500, 1)    0         conv1d_7[0][0], tf.__operators__.add_6[0][0]
global_average_pooling1d (Globa (None, 500)       0         tf.__operators__.add_7[0][0]
dense (Dense)                   (None, 128)       64128     global_average_pooling1d[0][0]
dropout_8 (Dropout)             (None, 128)       0         dense[0][0]
dense_1 (Dense)                 (None, 2)         258       dropout_8[0][0]
==================================================================================================
Total params: 93,130
Trainable params: 93,130
Non-trainable params: 0
__________________________________________________________________________________________________

Epoch 1/200 - 45/45 - 26s 499ms/step - loss: 1.0233 - sparse_categorical_accuracy: 0.5174 - val_loss: 0.7853 - val_sparse_categorical_accuracy: 0.5368
Epoch 2/200 - 45/45 - 22s 499ms/step - loss: 0.9108 - sparse_categorical_accuracy: 0.5507 - val_loss: 0.7169 - val_sparse_categorical_accuracy: 0.5659
Epoch 3/200 - 45/45 - 23s 509ms/step - loss: 0.8177 - sparse_categorical_accuracy: 0.5851 - val_loss: 0.6851 - val_sparse_categorical_accuracy: 0.5839
Epoch 4/200 - 45/45 - 24s 532ms/step - loss: 0.7494 - sparse_categorical_accuracy: 0.6160 - val_loss: 0.6554 - val_sparse_categorical_accuracy: 0.6214
Epoch 5/200 - 45/45 - 23s 520ms/step - loss: 0.7287 - sparse_categorical_accuracy: 0.6319 - val_loss: 0.6333 - val_sparse_categorical_accuracy: 0.6463
Epoch 6/200 - 45/45 - 23s 509ms/step - loss: 0.7108 - sparse_categorical_accuracy: 0.6424 - val_loss: 0.6185 - val_sparse_categorical_accuracy: 0.6546
Epoch 7/200 - 45/45 - 23s 512ms/step - loss: 0.6624 - sparse_categorical_accuracy: 0.6667 - val_loss: 0.6023 - val_sparse_categorical_accuracy: 0.6657
Epoch 8/200 - 45/45 - 23s 518ms/step - loss: 0.6392 - sparse_categorical_accuracy: 0.6774 - val_loss: 0.5935 - val_sparse_categorical_accuracy: 0.6796
Epoch 9/200 - 45/45 - 23s 513ms/step - loss: 0.5978 - sparse_categorical_accuracy: 0.6955 - val_loss: 0.5778 - val_sparse_categorical_accuracy: 0.6907
Epoch 10/200 - 45/45 - 23s 511ms/step - loss: 0.5909 - sparse_categorical_accuracy: 0.6948 - val_loss: 0.5687 - val_sparse_categorical_accuracy: 0.6935
Epoch 11/200 - 45/45 - 23s 513ms/step - loss: 0.5785 - sparse_categorical_accuracy: 0.7021 - val_loss: 0.5628 - val_sparse_categorical_accuracy: 0.6990
Epoch 12/200 - 45/45 - 23s 514ms/step - loss: 0.5547 - sparse_categorical_accuracy: 0.7247 - val_loss: 0.5545 - val_sparse_categorical_accuracy: 0.7101
Epoch 13/200 - 45/45 - 24s 535ms/step - loss: 0.5705 - sparse_categorical_accuracy: 0.7240 - val_loss: 0.5461 - val_sparse_categorical_accuracy: 0.7240
Epoch 14/200 - 45/45 - 23s 517ms/step - loss: 0.5538 - sparse_categorical_accuracy: 0.7250 - val_loss: 0.5403 - val_sparse_categorical_accuracy: 0.7212
Epoch 15/200 - 45/45 - 23s 515ms/step - loss: 0.5144 - sparse_categorical_accuracy: 0.7500 - val_loss: 0.5318 - val_sparse_categorical_accuracy: 0.7295
Epoch 16/200 - 45/45 - 23s 512ms/step - loss: 0.5200 - sparse_categorical_accuracy: 0.7521 - val_loss: 0.5286 - val_sparse_categorical_accuracy: 0.7379
Epoch 17/200 - 45/45 - 23s 515ms/step - loss: 0.4910 - sparse_categorical_accuracy: 0.7590 - val_loss: 0.5229 - val_sparse_categorical_accuracy: 0.7393
Epoch 18/200 - 45/45 - 23s 514ms/step - loss: 0.5013 - sparse_categorical_accuracy: 0.7427 - val_loss: 0.5157 - val_sparse_categorical_accuracy: 0.7462
Epoch 19/200 - 45/45 - 23s 511ms/step - loss: 0.4883 - sparse_categorical_accuracy: 0.7712 - val_loss: 0.5123 - val_sparse_categorical_accuracy: 0.7490
Epoch 20/200 - 45/45 - 23s 514ms/step - loss: 0.4935 - sparse_categorical_accuracy: 0.7667 - val_loss: 0.5032 - val_sparse_categorical_accuracy: 0.7545
Epoch 21/200 - 45/45 - 23s 514ms/step - loss: 0.4551 - sparse_categorical_accuracy: 0.7799 - val_loss: 0.4978 - val_sparse_categorical_accuracy: 0.7573
Epoch 22/200 - 45/45 - 23s 516ms/step - loss: 0.4477 - sparse_categorical_accuracy: 0.7948 - val_loss: 0.4941 - val_sparse_categorical_accuracy: 0.7531
Epoch 23/200 - 45/45 - 23s 518ms/step - loss: 0.4549 - sparse_categorical_accuracy: 0.7858 - val_loss: 0.4893 - val_sparse_categorical_accuracy: 0.7656
Epoch 24/200 - 45/45 - 23s 516ms/step - loss: 0.4426 - sparse_categorical_accuracy: 0.7948 - val_loss: 0.4842 - val_sparse_categorical_accuracy: 0.7712
Epoch 25/200 - 45/45 - 23s 520ms/step - loss: 0.4360 - sparse_categorical_accuracy: 0.8035 - val_loss: 0.4798 - val_sparse_categorical_accuracy: 0.7809
Epoch 26/200 - 45/45 - 23s 515ms/step - loss: 0.4316 - sparse_categorical_accuracy: 0.8035 - val_loss: 0.4715 - val_sparse_categorical_accuracy: 0.7809
Epoch 27/200 - 45/45 - 23s 518ms/step - loss: 0.4084 - sparse_categorical_accuracy: 0.8146 - val_loss: 0.4676 - val_sparse_categorical_accuracy: 0.7878
Epoch 28/200 - 45/45 - 23s 515ms/step - loss: 0.3998 - sparse_categorical_accuracy: 0.8240 - val_loss: 0.4667 - val_sparse_categorical_accuracy: 0.7933
Epoch 29/200 - 45/45 - 23s 514ms/step - loss: 0.3993 - sparse_categorical_accuracy: 0.8198 - val_loss: 0.4603 - val_sparse_categorical_accuracy: 0.7892
Epoch 30/200 - 45/45 - 23s 515ms/step - loss: 0.4031 - sparse_categorical_accuracy: 0.8243 - val_loss: 0.4562 - val_sparse_categorical_accuracy: 0.7920
Epoch 31/200 - 45/45 - 23s 511ms/step - loss: 0.3891 - sparse_categorical_accuracy: 0.8184 - val_loss: 0.4528 - val_sparse_categorical_accuracy: 0.7920
Epoch 32/200 - 45/45 - 23s 516ms/step - loss: 0.3922 - sparse_categorical_accuracy: 0.8292 - val_loss: 0.4485 - val_sparse_categorical_accuracy: 0.7892
Epoch 33/200 - 45/45 - 23s 516ms/step - loss: 0.3802 - sparse_categorical_accuracy: 0.8309 - val_loss: 0.4463 - val_sparse_categorical_accuracy: 0.8003
Epoch 34/200 - 45/45 - 23s 514ms/step - loss: 0.3711 - sparse_categorical_accuracy: 0.8372 - val_loss: 0.4427 - val_sparse_categorical_accuracy: 0.7975
Epoch 35/200 - 45/45 - 23s 512ms/step - loss: 0.3744 - sparse_categorical_accuracy: 0.8378 - val_loss: 0.4366 - val_sparse_categorical_accuracy: 0.8072
Epoch 36/200 - 45/45 - 23s 511ms/step - loss: 0.3653 - sparse_categorical_accuracy: 0.8372 - val_loss: 0.4338 - val_sparse_categorical_accuracy: 0.8072
Epoch 37/200 - 45/45 - 23s 512ms/step - loss: 0.3681 - sparse_categorical_accuracy: 0.8382 - val_loss: 0.4337 - val_sparse_categorical_accuracy: 0.8058
Epoch 38/200 - 45/45 - 23s 512ms/step - loss: 0.3634 - sparse_categorical_accuracy: 0.8514 - val_loss: 0.4264 - val_sparse_categorical_accuracy: 0.8128
Epoch 39/200 - 45/45 - 23s 512ms/step - loss: 0.3498 - sparse_categorical_accuracy: 0.8535 - val_loss: 0.4211 - val_sparse_categorical_accuracy: 0.8225
Epoch 40/200 - 45/45 - 23s 514ms/step - loss: 0.3358 - sparse_categorical_accuracy: 0.8663 - val_loss: 0.4161 - val_sparse_categorical_accuracy: 0.8197
Epoch 41/200 - 45/45 - 23s 512ms/step - loss: 0.3448 - sparse_categorical_accuracy: 0.8573 - val_loss: 0.4161 - val_sparse_categorical_accuracy: 0.8169
Epoch 42/200 - 45/45 - 23s 512ms/step - loss: 0.3439 - sparse_categorical_accuracy: 0.8552 - val_loss: 0.4119 - val_sparse_categorical_accuracy: 0.8211
Epoch 43/200 - 45/45 - 23s 510ms/step - loss: 0.3335 - sparse_categorical_accuracy: 0.8660 - val_loss: 0.4101 - val_sparse_categorical_accuracy: 0.8266
Epoch 44/200 - 45/45 - 23s 510ms/step - loss: 0.3235 - sparse_categorical_accuracy: 0.8660 - val_loss: 0.4067 - val_sparse_categorical_accuracy: 0.8294
Epoch 45/200 - 45/45 - 23s 510ms/step - loss: 0.3273 - sparse_categorical_accuracy: 0.8656 - val_loss: 0.4033 - val_sparse_categorical_accuracy: 0.8350
Epoch 46/200 - 45/45 - 23s 513ms/step - loss: 0.3277 - sparse_categorical_accuracy: 0.8608 - val_loss: 0.3994 - val_sparse_categorical_accuracy: 0.8336
Epoch 47/200 - 45/45 - 23s 519ms/step - loss: 0.3136 - sparse_categorical_accuracy: 0.8708 - val_loss: 0.3945 - val_sparse_categorical_accuracy: 0.8363
Epoch 48/200 - 45/45 - 23s 518ms/step - loss: 0.3122 - sparse_categorical_accuracy: 0.8764 - val_loss: 0.3925 - val_sparse_categorical_accuracy: 0.8350
Epoch 49/200 - 45/45 - 23s 519ms/step - loss: 0.3035 - sparse_categorical_accuracy: 0.8826 - val_loss: 0.3906 - val_sparse_categorical_accuracy: 0.8308
Epoch 50/200 - 45/45 - 23s 512ms/step - loss: 0.2994 - sparse_categorical_accuracy: 0.8823 - val_loss: 0.3888 - val_sparse_categorical_accuracy: 0.8377
Epoch 51/200 - 45/45 - 23s 514ms/step - loss: 0.3023 - sparse_categorical_accuracy: 0.8781 - val_loss: 0.3862 - val_sparse_categorical_accuracy: 0.8391
Epoch 52/200 - 45/45 - 23s 515ms/step - loss: 0.3012 - sparse_categorical_accuracy: 0.8833 - val_loss: 0.3854 - val_sparse_categorical_accuracy: 0.8350
Epoch 53/200 - 45/45 - 23s 513ms/step - loss: 0.2890 - sparse_categorical_accuracy: 0.8837 - val_loss: 0.3837 - val_sparse_categorical_accuracy: 0.8363
Epoch 54/200 - 45/45 - 23s 513ms/step - loss: 0.2931 - sparse_categorical_accuracy: 0.8858 - val_loss: 0.3809 - val_sparse_categorical_accuracy: 0.8433
Epoch 55/200 - 45/45 - 23s 515ms/step - loss: 0.2867 - sparse_categorical_accuracy: 0.8885 - val_loss: 0.3784 - val_sparse_categorical_accuracy: 0.8447
Epoch 56/200 - 45/45 - 23s 511ms/step - loss: 0.2731 - sparse_categorical_accuracy: 0.8986 - val_loss: 0.3756 - val_sparse_categorical_accuracy: 0.8488
Epoch 57/200 - 45/45 - 23s 515ms/step - loss: 0.2754 - sparse_categorical_accuracy: 0.8955 - val_loss: 0.3759 - val_sparse_categorical_accuracy: 0.8474
Epoch 58/200 - 45/45 - 23s 511ms/step - loss: 0.2775 - sparse_categorical_accuracy: 0.8976 - val_loss: 0.3704 - val_sparse_categorical_accuracy: 0.8474
Epoch 59/200 - 45/45 - 23s 513ms/step - loss: 0.2770 - sparse_categorical_accuracy: 0.9000 - val_loss: 0.3698 - val_sparse_categorical_accuracy: 0.8558
Epoch 60/200 - 45/45 - 23s 516ms/step - loss: 0.2688 - sparse_categorical_accuracy: 0.8965 - val_loss: 0.3697 - val_sparse_categorical_accuracy: 0.8502
Epoch 61/200 - 45/45 - 23s 518ms/step - loss: 0.2716 - sparse_categorical_accuracy: 0.8972 - val_loss: 0.3710 - val_sparse_categorical_accuracy: 0.8405
Epoch 62/200 - 45/45 - 23s 515ms/step - loss: 0.2635 - sparse_categorical_accuracy: 0.9087 - val_loss: 0.3656 - val_sparse_categorical_accuracy: 0.8488
Epoch 63/200 - 45/45 - 23s 520ms/step - loss: 0.2596 - sparse_categorical_accuracy: 0.8979 - val_loss: 0.3654 - val_sparse_categorical_accuracy: 0.8488
Epoch 64/200 - 45/45 - 23s 518ms/step - loss: 0.2586 - sparse_categorical_accuracy: 0.9062 - val_loss: 0.3634 - val_sparse_categorical_accuracy: 0.8530
Epoch 65/200 - 45/45 - 23s 516ms/step - loss: 0.2491 - sparse_categorical_accuracy: 0.9139 - val_loss: 0.3591 - val_sparse_categorical_accuracy: 0.8530
Epoch 66/200 - 45/45 - 23s 519ms/step - loss: 0.2600 - sparse_categorical_accuracy: 0.9017 - val_loss: 0.3621 - val_sparse_categorical_accuracy: 0.8516
Epoch 67/200 - 45/45 - 23s 518ms/step - loss: 0.2465 - sparse_categorical_accuracy: 0.9156 - val_loss: 0.3608 - val_sparse_categorical_accuracy: 0.8488
Epoch 68/200 - 45/45 - 23s 518ms/step - loss: 0.2502 - sparse_categorical_accuracy: 0.9101 - val_loss: 0.3557 - val_sparse_categorical_accuracy: 0.8627
Epoch 69/200 - 45/45 - 23s 517ms/step - loss: 0.2418 - sparse_categorical_accuracy: 0.9104 - val_loss: 0.3561 - val_sparse_categorical_accuracy: 0.8502
Epoch 70/200 - 45/45 - 23s 516ms/step - loss: 0.2463 - sparse_categorical_accuracy: 0.9049 - val_loss: 0.3554 - val_sparse_categorical_accuracy: 0.8613
Epoch 71/200 - 45/45 - 23s 520ms/step - loss: 0.2372 - sparse_categorical_accuracy: 0.9177 - val_loss: 0.3548 - val_sparse_categorical_accuracy: 0.8627
Epoch 72/200 - 45/45 - 23s 515ms/step - loss: 0.2365 - sparse_categorical_accuracy: 0.9118 - val_loss: 0.3528 - val_sparse_categorical_accuracy: 0.8655
Epoch 73/200 - 45/45 - 23s 518ms/step - loss: 0.2420 - sparse_categorical_accuracy: 0.9083 - val_loss: 0.3510 - val_sparse_categorical_accuracy: 0.8655
Epoch 74/200 - 45/45 - 23s 518ms/step - loss: 0.2342 - sparse_categorical_accuracy: 0.9205 - val_loss: 0.3478 - val_sparse_categorical_accuracy: 0.8669
Epoch 75/200 - 45/45 - 23s 515ms/step - loss: 0.2337 - sparse_categorical_accuracy: 0.9062 - val_loss: 0.3484 - val_sparse_categorical_accuracy: 0.8655
Epoch 76/200 - 45/45 - 23s 516ms/step - loss: 0.2298 - sparse_categorical_accuracy: 0.9153 - val_loss: 0.3478 - val_sparse_categorical_accuracy: 0.8585
Epoch 77/200 - 45/45 - 23s 516ms/step - loss: 0.2218 - sparse_categorical_accuracy: 0.9243 - val_loss: 0.3467 - val_sparse_categorical_accuracy: 0.8613
Epoch 78/200 - 45/45 - 23s 518ms/step - loss: 0.2352 - sparse_categorical_accuracy: 0.9083 - val_loss: 0.3431 - val_sparse_categorical_accuracy: 0.8641
Epoch 79/200 - 45/45 - 23s 515ms/step - loss: 0.2218 - sparse_categorical_accuracy: 0.9194 - val_loss: 0.3448 - val_sparse_categorical_accuracy: 0.8613
Epoch 80/200 - 45/45 - 23s 515ms/step - loss: 0.2246 - sparse_categorical_accuracy: 0.9198 - val_loss: 0.3417 - val_sparse_categorical_accuracy: 0.8682
Epoch 81/200 - 45/45 - 23s 518ms/step - loss: 0.2168 - sparse_categorical_accuracy: 0.9201 - val_loss: 0.3397 - val_sparse_categorical_accuracy: 0.8641
Epoch 82/200 - 45/45 - 23s 517ms/step - loss: 0.2254 - sparse_categorical_accuracy: 0.9153 - val_loss: 0.3373 - val_sparse_categorical_accuracy: 0.8682
Epoch 83/200 - 45/45 - 23s 518ms/step - loss: 0.2230 - sparse_categorical_accuracy: 0.9194 - val_loss: 0.3391 - val_sparse_categorical_accuracy: 0.8655
Epoch 84/200 - 45/45 - 23s 518ms/step - loss: 0.2124 - sparse_categorical_accuracy: 0.9240 - val_loss: 0.3370 - val_sparse_categorical_accuracy: 0.8682
Epoch 85/200 - 45/45 - 23s 515ms/step - loss: 0.2123 - sparse_categorical_accuracy: 0.9278 - val_loss: 0.3394 - val_sparse_categorical_accuracy: 0.8571
Epoch 86/200 - 45/45 - 23s 520ms/step - loss: 0.2119 - sparse_categorical_accuracy: 0.9260 - val_loss: 0.3355 - val_sparse_categorical_accuracy: 0.8627
Epoch 87/200 - 45/45 - 23s 517ms/step - loss: 0.2052 - sparse_categorical_accuracy: 0.9247 - val_loss: 0.3353 - val_sparse_categorical_accuracy: 0.8738
Epoch 88/200 - 45/45 - 23s 518ms/step - loss: 0.2089 - sparse_categorical_accuracy: 0.9299 - val_loss: 0.3342 - val_sparse_categorical_accuracy: 0.8779
Epoch 89/200 - 45/45 - 23s 519ms/step - loss: 0.2027 - sparse_categorical_accuracy: 0.9250 - val_loss: 0.3353 - val_sparse_categorical_accuracy: 0.8793
Epoch 90/200 - 45/45 - 23s 517ms/step - loss: 0.2110 - sparse_categorical_accuracy: 0.9264 - val_loss: 0.3320 - val_sparse_categorical_accuracy: 0.8752
Epoch 91/200 - 45/45 - 23s 516ms/step - loss: 0.1965 - sparse_categorical_accuracy: 0.9292 - val_loss: 0.3339 - val_sparse_categorical_accuracy: 0.8710
Epoch 92/200 - 45/45 - 23s 520ms/step - loss: 0.2030 - sparse_categorical_accuracy: 0.9253 - val_loss: 0.3296 - val_sparse_categorical_accuracy: 0.8752
Epoch 93/200 - 45/45 - 23s 519ms/step - loss: 0.1969 - sparse_categorical_accuracy: 0.9347 - val_loss: 0.3298 - val_sparse_categorical_accuracy: 0.8807
Epoch 94/200 - 45/45 - 23s 518ms/step - loss: 0.1939 - sparse_categorical_accuracy: 0.9295 - val_loss: 0.3300 - val_sparse_categorical_accuracy: 0.8779
Epoch 95/200 - 45/45 - 23s 517ms/step - loss: 0.1930 - sparse_categorical_accuracy: 0.9330 - val_loss: 0.3305 - val_sparse_categorical_accuracy: 0.8766
Epoch 96/200 - 45/45 - 23s 518ms/step - loss: 0.1946 - sparse_categorical_accuracy: 0.9288 - val_loss: 0.3288 - val_sparse_categorical_accuracy: 0.8669
Epoch 97/200 - 45/45 - 23s 518ms/step - loss: 0.1951 - sparse_categorical_accuracy: 0.9264 - val_loss: 0.3281 - val_sparse_categorical_accuracy: 0.8682
Epoch 98/200 - 45/45 - 23s 516ms/step - loss: 0.1899 - sparse_categorical_accuracy: 0.9354 - val_loss: 0.3307 - val_sparse_categorical_accuracy: 0.8696
Epoch 99/200 - 45/45 - 23s 519ms/step - loss: 0.1901 - sparse_categorical_accuracy: 0.9250 - val_loss: 0.3307 - val_sparse_categorical_accuracy: 0.8710
Epoch 100/200 - 45/45 - 23s 516ms/step - loss: 0.1902 - sparse_categorical_accuracy: 0.9319 - val_loss: 0.3259 - val_sparse_categorical_accuracy: 0.8696
Epoch 101/200 - 45/45 - 23s 518ms/step - loss: 0.1868 - sparse_categorical_accuracy: 0.9358 - val_loss: 0.3262 - val_sparse_categorical_accuracy: 0.8724
Epoch 102/200 - 45/45 - 23s 518ms/step - loss: 0.1779 - sparse_categorical_accuracy: 0.9431 - val_loss: 0.3250 - val_sparse_categorical_accuracy: 0.8710
Epoch 103/200 - 45/45 - 23s 520ms/step - loss: 0.1870 - sparse_categorical_accuracy: 0.9351 - val_loss: 0.3260 - val_sparse_categorical_accuracy: 0.8724
Epoch 104/200 - 45/45 - 23s 521ms/step - loss: 0.1826 - sparse_categorical_accuracy: 0.9344 - val_loss: 0.3232 - val_sparse_categorical_accuracy: 0.8766
Epoch 105/200 - 45/45 - 23s 519ms/step - loss: 0.1731 - sparse_categorical_accuracy: 0.9399 - val_loss: 0.3245 - val_sparse_categorical_accuracy: 0.8724
Epoch 106/200 - 45/45 - 23s 518ms/step - loss: 0.1766 - sparse_categorical_accuracy: 0.9361 - val_loss: 0.3254 - val_sparse_categorical_accuracy: 0.8682
Epoch 107/200
Conclusions
In about 110-120 epochs (25s each on Colab), the model reaches a training accuracy of ~0.95, a validation accuracy of ~84%, and a test accuracy of ~85%, without any hyperparameter tuning. And that is for a model with fewer than 100k parameters. Of course, both parameter count and accuracy could be improved by a hyperparameter search and a more sophisticated learning rate schedule, or a different optimizer.
You can use the trained model hosted on Hugging Face Hub and try the demo on Hugging Face Spaces.
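If you would rather try the model trained above locally instead of the hosted demo, class predictions can be obtained in the usual Keras way (a small illustrative sketch added for this translation):

# Predicted class indices for the first few test samples, alongside the true labels.
probs = model.predict(x_test[:5])
print(probs.argmax(axis=-1), y_test[:5])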