Keras 2 : examples : FixRes: Fixing train-test resolution discrepancy (translation/commentary)
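Translation: ClassCat Co., Ltd. Sales Information
Created: 11/19/2021 (keras 2.7.0)
* This page translates the following Keras documentation and adds supplementary explanations where appropriate:
- Code examples : Computer Vision : FixRes: Fixing train-test resolution discrepancy (Author: Sayak Paul)
* The sample code has been verified to run, and has been modified or extended where necessary.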
Description : Mitigating the resolution discrepancy between the training and test sets.
Introduction
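It is common practice to use the same input image resolution when training and testing vision models. However, as investigated in Fixing the train-test resolution discrepancy (Touvron et al.), this practice leads to suboptimal performance. Data augmentation is an indispensable part of training deep neural networks. For vision models, we typically use random resized crops during training and center crops during inference. This introduces a discrepancy between the object sizes seen during training and those seen during inference. As shown by Touvron et al., fixing this discrepancy can significantly boost model performance.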
In this example, we implement the FixRes technique introduced by Touvron et al. to fix this discrepancy.
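The following back-of-the-envelope sketch makes the object-size discrepancy concrete. It estimates how much larger objects look in random-resized training crops than in full-image validation inputs at the same resolution. The estimate rests on two simplifying assumptions:

- Crops are square.
- The crop area fraction s is drawn uniformly from the `area_range=(0.05, 1.0)` that `preprocess_initial()` uses later in this example.

`tf.image.sample_distorted_bounding_box` also varies the aspect ratio and does not sample the area exactly this way, so treat the numbers as an approximation only. Under these assumptions, a crop covering a fraction s of the image, resized to the training resolution, enlarges objects by 1/sqrt(s). The comparison point is resizing the whole image, which is what the validation branch of `preprocess_initial()` does.

```python
import math
import random

# Simplifying assumptions (not the exact TensorFlow sampler):
# - crops are square, so a crop with area fraction s spans sqrt(s) of each side;
# - s ~ Uniform(AREA_RANGE), the area_range used by preprocess_initial().
AREA_RANGE = (0.05, 1.0)


def expected_magnification(lo, hi):
    """Closed form of E[1 / sqrt(s)] for s ~ Uniform(lo, hi).

    Resizing a crop of area fraction s to the same output size as the
    full image enlarges every object by a factor of 1 / sqrt(s).
    """
    # The integral of s**(-1/2) is 2 * sqrt(s); divide by the interval length.
    return 2 * (math.sqrt(hi) - math.sqrt(lo)) / (hi - lo)


def monte_carlo_magnification(lo, hi, n=200_000, seed=0):
    """Estimate the same expectation by sampling crop area fractions."""
    rng = random.Random(seed)
    return sum(1 / math.sqrt(rng.uniform(lo, hi)) for _ in range(n)) / n


analytic = expected_magnification(*AREA_RANGE)
estimate = monte_carlo_magnification(*AREA_RANGE)
print(f"analytic:    {analytic:.3f}")
print(f"monte carlo: {estimate:.3f}")
print(f"224 / 128:   {224 / 128:.3f}")
```

Under these assumptions, objects look about 1.63 times larger on average during training than at validation time. This is the kind of gap FixRes compensates for by raising the input resolution (here from 128 to 224, a factor of 1.75) and re-tuning the model. The match between the two factors is only rough.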
Imports
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
import matplotlib.pyplot as plt
Load the tf_flowers dataset
train_dataset, val_dataset = tfds.load(
    "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)
num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")
Number of training examples: 3303
Number of validation examples: 367
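The split sizes printed above also explain the "26/26" and "3/3" progress bars that appear in the training logs later on. The following small arithmetic check assumes the batch size of 128 used throughout this example. `tf.data`'s `batch()` keeps the final partial batch by default, hence the ceiling division.

```python
import math

num_train, num_val = 3303, 367  # values printed above
batch_size = 128  # batch size used by make_dataset() in this example

total = num_train + num_val
train_steps = math.ceil(num_train / batch_size)  # the last batch is partial
val_steps = math.ceil(num_val / batch_size)

print(total, total * 9 // 10, train_steps, val_steps)  # -> 3670 3303 26 3
```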
Data preprocessing utilities
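We create three datasets:
- A dataset with a smaller resolution: 128×128.
- Two datasets with a larger resolution: 224×224.
We apply different augmentation transforms to the two larger-resolution datasets. One uses the FixRes fine-tuning transforms, and the other reuses the transforms of the smaller-resolution dataset.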
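The idea of FixRes is to first train a model on a smaller-resolution dataset and then fine-tune it on a larger-resolution dataset. This simple yet effective recipe leads to non-trivial performance improvements. Please refer to the original paper for results.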
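The fine-tuning preprocessing below works in two steps:

1. It resizes images to `size_for_resizing = int((bigger_size / smaller_size) * bigger_size)`.
2. It center-crops the result to `bigger_size`.

This plain-Python arithmetic (no TensorFlow needed) shows what those numbers work out to:

- The center crop keeps 4/7 (about 57%) of each side.
- Objects end up 1.75 times larger than after a plain resize to 224×224, the same factor as 224/128.

```python
smaller_size = 128
bigger_size = 224

# Same formula as in this example's preprocessing utilities.
size_for_resizing = int((bigger_size / smaller_size) * bigger_size)

# Fraction of each side (and of the area) of the resized image that the
# bigger_size x bigger_size center crop keeps.
side_kept = bigger_size / size_for_resizing
area_kept = side_kept**2

# How much larger objects appear than after a plain resize to bigger_size.
zoom_vs_plain_resize = size_for_resizing / bigger_size

print(size_for_resizing, round(side_kept, 4), round(area_kept, 4), zoom_vs_plain_resize)
# -> 392 0.5714 0.3265 1.75
```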
# Reference: https://github.com/facebookresearch/FixRes/blob/main/transforms_v2.py.
batch_size = 128
auto = tf.data.AUTOTUNE
smaller_size = 128
bigger_size = 224
size_for_resizing = int((bigger_size / smaller_size) * bigger_size)
central_crop_layer = layers.CenterCrop(bigger_size, bigger_size)
def preprocess_initial(train, image_size):
    """Initial preprocessing function for training on smaller resolution.

    For training, do random_crop -> resize -> random_horizontal_flip.
    For validation, just resize.
    No color-jittering has been used.
    """

    def _pp(image, label, train):
        if train:
            channels = image.shape[-1]
            # Sample a random crop covering 5%-100% of the image area.
            begin, size, _ = tf.image.sample_distorted_bounding_box(
                tf.shape(image),
                tf.zeros([0, 0, 4], tf.float32),
                area_range=(0.05, 1.0),
                min_object_covered=0,
                use_image_if_no_bounding_boxes=True,
            )
            image = tf.slice(image, begin, size)
            image.set_shape([None, None, channels])
            image = tf.image.resize(image, [image_size, image_size])
            image = tf.image.random_flip_left_right(image)
        else:
            image = tf.image.resize(image, [image_size, image_size])
        return image, label

    return _pp
def preprocess_finetune(image, label, train):
    """Preprocessing function for fine-tuning on a higher resolution.

    For training, resize to a bigger resolution to maintain the ratio ->
    random_horizontal_flip -> center_crop.
    For validation, do the same without any horizontal flipping.
    No color-jittering has been used.
    """
    image = tf.image.resize(image, [size_for_resizing, size_for_resizing])
    if train:
        image = tf.image.random_flip_left_right(image)
    image = central_crop_layer(image[None, ...])[0]
    return image, label
def make_dataset(
    dataset: tf.data.Dataset,
    train: bool,
    image_size: int = smaller_size,
    fixres: bool = True,
    num_parallel_calls=auto,
):
    if image_size not in [smaller_size, bigger_size]:
        raise ValueError(f"{image_size} resolution is not supported.")

    # Determine which preprocessing function we are using.
    if image_size == smaller_size:
        preprocess_func = preprocess_initial(train, image_size)
    elif not fixres and image_size == bigger_size:
        preprocess_func = preprocess_initial(train, image_size)
    else:
        preprocess_func = preprocess_finetune

    if train:
        dataset = dataset.shuffle(batch_size * 10)

    return (
        dataset.map(
            lambda x, y: preprocess_func(x, y, train),
            num_parallel_calls=num_parallel_calls,
        )
        .batch(batch_size)
        .prefetch(num_parallel_calls)
    )
Notice how the augmentation transforms differ depending on the kind of dataset being prepared.
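The following sketch shows at a glance which preprocessing path each dataset takes. It is a hypothetical pure-Python mirror of the branching in `make_dataset()`. The name `pick_preprocessing` is ours and is not part of the example. It returns the name of the function that `make_dataset()` would use.

```python
SMALLER_SIZE = 128
BIGGER_SIZE = 224


def pick_preprocessing(image_size, fixres=True):
    """Hypothetical mirror of make_dataset()'s dispatch logic."""
    if image_size not in (SMALLER_SIZE, BIGGER_SIZE):
        raise ValueError(f"{image_size} resolution is not supported.")
    if image_size == SMALLER_SIZE:
        return "preprocess_initial"
    if not fixres:
        # Larger resolution, but with the smaller-resolution transforms.
        return "preprocess_initial"
    return "preprocess_finetune"


for size, fixres in [(128, True), (224, True), (224, False)]:
    print(size, fixres, pick_preprocessing(size, fixres))
```

The `initial_*` and `vanilla_*` datasets therefore share the random-crop-and-flip transforms of `preprocess_initial()`. Only the `finetune_*` datasets use the resize-then-center-crop transforms of `preprocess_finetune()`.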
Prepare the datasets
initial_train_dataset = make_dataset(train_dataset, train=True, image_size=smaller_size)
initial_val_dataset = make_dataset(val_dataset, train=False, image_size=smaller_size)
finetune_train_dataset = make_dataset(train_dataset, train=True, image_size=bigger_size)
finetune_val_dataset = make_dataset(val_dataset, train=False, image_size=bigger_size)
vanilla_train_dataset = make_dataset(
    train_dataset, train=True, image_size=bigger_size, fixres=False
)
vanilla_val_dataset = make_dataset(
    val_dataset, train=False, image_size=bigger_size, fixres=False
)
Visualize the datasets
def visualize_dataset(batch_images):
    plt.figure(figsize=(10, 10))
    for n in range(25):
        ax = plt.subplot(5, 5, n + 1)
        plt.imshow(batch_images[n].numpy().astype("int"))
        plt.axis("off")
    plt.show()
    print(f"Batch shape: {batch_images.shape}.")
# Smaller resolution.
initial_sample_images, _ = next(iter(initial_train_dataset))
visualize_dataset(initial_sample_images)
# Bigger resolution, only for fine-tuning.
finetune_sample_images, _ = next(iter(finetune_train_dataset))
visualize_dataset(finetune_sample_images)
# Bigger resolution, with the same augmentation transforms as
# the smaller resolution dataset.
vanilla_sample_images, _ = next(iter(vanilla_train_dataset))
visualize_dataset(vanilla_sample_images)
Batch shape: (128, 128, 128, 3).
Batch shape: (128, 224, 224, 3).
Batch shape: (128, 224, 224, 3).
Model training utilities
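We train multiple variants of ResNet50V2 (He et al.):
1. One on the smaller-resolution (128×128) dataset, trained from scratch.
2. The model from step 1, fine-tuned on the larger-resolution (224×224) dataset.
3. Another ResNet50V2, trained from scratch on the larger-resolution dataset.
As a reminder, the two larger-resolution datasets differ in their augmentation transforms.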
def get_training_model(num_classes=5):
    inputs = layers.Input((None, None, 3))
    resnet_base = keras.applications.ResNet50V2(
        include_top=False, weights=None, pooling="avg"
    )
    resnet_base.trainable = True

    # Map pixel values from [0, 255] to [-1, 1].
    x = layers.Rescaling(scale=1.0 / 127.5, offset=-1)(inputs)
    x = resnet_base(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    return keras.Model(inputs, outputs)
def train_and_evaluate(
    model, train_ds, val_ds, epochs, learning_rate=1e-3, use_early_stopping=False
):
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    if use_early_stopping:
        es_callback = keras.callbacks.EarlyStopping(patience=5)
        callbacks = [es_callback]
    else:
        callbacks = None

    model.fit(
        train_ds, validation_data=val_ds, epochs=epochs, callbacks=callbacks,
    )

    _, accuracy = model.evaluate(val_ds)
    print(f"Top-1 accuracy on the validation set: {accuracy*100:.2f}%.")
    return model
Experiment 1: Train on 128×128 and then fine-tune on 224×224
epochs = 30
smaller_res_model = get_training_model()
smaller_res_model = train_and_evaluate(
    smaller_res_model, initial_train_dataset, initial_val_dataset, epochs
)
Epoch 1/30
26/26 [==============================] - 14s 226ms/step - loss: 1.6476 - accuracy: 0.4345 - val_loss: 9.8213 - val_accuracy: 0.2044
Epoch 2/30
26/26 [==============================] - 3s 123ms/step - loss: 1.1561 - accuracy: 0.5495 - val_loss: 6.5521 - val_accuracy: 0.2071
Epoch 3/30
26/26 [==============================] - 3s 123ms/step - loss: 1.0989 - accuracy: 0.5722 - val_loss: 2.6216 - val_accuracy: 0.1935
Epoch 4/30
26/26 [==============================] - 3s 122ms/step - loss: 1.0373 - accuracy: 0.5895 - val_loss: 1.9918 - val_accuracy: 0.2125
Epoch 5/30
26/26 [==============================] - 3s 122ms/step - loss: 0.9960 - accuracy: 0.6119 - val_loss: 2.8505 - val_accuracy: 0.2262
Epoch 6/30
26/26 [==============================] - 3s 122ms/step - loss: 0.9458 - accuracy: 0.6331 - val_loss: 1.8974 - val_accuracy: 0.2834
Epoch 7/30
26/26 [==============================] - 3s 122ms/step - loss: 0.8949 - accuracy: 0.6606 - val_loss: 2.1164 - val_accuracy: 0.2834
Epoch 8/30
26/26 [==============================] - 3s 122ms/step - loss: 0.8581 - accuracy: 0.6709 - val_loss: 1.8858 - val_accuracy: 0.3815
Epoch 9/30
26/26 [==============================] - 3s 123ms/step - loss: 0.8436 - accuracy: 0.6776 - val_loss: 1.5671 - val_accuracy: 0.4687
Epoch 10/30
26/26 [==============================] - 3s 123ms/step - loss: 0.8632 - accuracy: 0.6685 - val_loss: 1.5005 - val_accuracy: 0.5504
Epoch 11/30
26/26 [==============================] - 3s 123ms/step - loss: 0.8316 - accuracy: 0.6918 - val_loss: 1.1421 - val_accuracy: 0.6594
Epoch 12/30
26/26 [==============================] - 3s 123ms/step - loss: 0.7981 - accuracy: 0.6951 - val_loss: 1.2036 - val_accuracy: 0.6403
Epoch 13/30
26/26 [==============================] - 3s 122ms/step - loss: 0.8275 - accuracy: 0.6806 - val_loss: 2.2632 - val_accuracy: 0.5177
Epoch 14/30
26/26 [==============================] - 3s 122ms/step - loss: 0.8156 - accuracy: 0.6994 - val_loss: 1.1023 - val_accuracy: 0.6649
Epoch 15/30
26/26 [==============================] - 3s 122ms/step - loss: 0.7572 - accuracy: 0.7091 - val_loss: 1.6248 - val_accuracy: 0.6049
Epoch 16/30
26/26 [==============================] - 3s 123ms/step - loss: 0.7757 - accuracy: 0.7024 - val_loss: 2.0600 - val_accuracy: 0.6294
Epoch 17/30
26/26 [==============================] - 3s 122ms/step - loss: 0.7600 - accuracy: 0.7087 - val_loss: 1.5731 - val_accuracy: 0.6131
Epoch 18/30
26/26 [==============================] - 3s 122ms/step - loss: 0.7385 - accuracy: 0.7215 - val_loss: 1.8312 - val_accuracy: 0.5749
Epoch 19/30
26/26 [==============================] - 3s 122ms/step - loss: 0.7493 - accuracy: 0.7224 - val_loss: 3.0382 - val_accuracy: 0.4986
Epoch 20/30
26/26 [==============================] - 3s 122ms/step - loss: 0.7746 - accuracy: 0.7048 - val_loss: 7.8191 - val_accuracy: 0.5123
Epoch 21/30
26/26 [==============================] - 3s 123ms/step - loss: 0.7367 - accuracy: 0.7405 - val_loss: 1.9607 - val_accuracy: 0.6676
Epoch 22/30
26/26 [==============================] - 3s 122ms/step - loss: 0.6970 - accuracy: 0.7357 - val_loss: 3.1944 - val_accuracy: 0.4496
Epoch 23/30
26/26 [==============================] - 3s 122ms/step - loss: 0.7299 - accuracy: 0.7212 - val_loss: 1.4012 - val_accuracy: 0.6567
Epoch 24/30
26/26 [==============================] - 3s 122ms/step - loss: 0.6965 - accuracy: 0.7315 - val_loss: 1.9781 - val_accuracy: 0.6403
Epoch 25/30
26/26 [==============================] - 3s 124ms/step - loss: 0.6811 - accuracy: 0.7408 - val_loss: 0.9287 - val_accuracy: 0.6839
Epoch 26/30
26/26 [==============================] - 3s 123ms/step - loss: 0.6732 - accuracy: 0.7487 - val_loss: 2.9406 - val_accuracy: 0.5504
Epoch 27/30
26/26 [==============================] - 3s 122ms/step - loss: 0.6571 - accuracy: 0.7560 - val_loss: 1.6268 - val_accuracy: 0.5804
Epoch 28/30
26/26 [==============================] - 3s 122ms/step - loss: 0.6662 - accuracy: 0.7548 - val_loss: 0.9067 - val_accuracy: 0.7357
Epoch 29/30
26/26 [==============================] - 3s 122ms/step - loss: 0.6443 - accuracy: 0.7520 - val_loss: 0.7760 - val_accuracy: 0.7520
Epoch 30/30
26/26 [==============================] - 3s 122ms/step - loss: 0.6617 - accuracy: 0.7539 - val_loss: 0.6026 - val_accuracy: 0.7766
3/3 [==============================] - 0s 37ms/step - loss: 0.6026 - accuracy: 0.7766
Top-1 accuracy on the validation set: 77.66%.
(Translator's note: results from our own run)
Epoch 1/30
26/26 [==============================] - 35s 353ms/step - loss: 1.4942 - accuracy: 0.4329 - val_loss: 8.4472 - val_accuracy: 0.1907
Epoch 2/30
26/26 [==============================] - 6s 213ms/step - loss: 1.1686 - accuracy: 0.5289 - val_loss: 1.7882 - val_accuracy: 0.2670
Epoch 3/30
26/26 [==============================] - 6s 212ms/step - loss: 1.1276 - accuracy: 0.5637 - val_loss: 4.7523 - val_accuracy: 0.2207
Epoch 4/30
26/26 [==============================] - 6s 211ms/step - loss: 1.1139 - accuracy: 0.5837 - val_loss: 1.5831 - val_accuracy: 0.3025
Epoch 5/30
26/26 [==============================] - 6s 213ms/step - loss: 1.0299 - accuracy: 0.6040 - val_loss: 1.9543 - val_accuracy: 0.3651
Epoch 6/30
26/26 [==============================] - 6s 212ms/step - loss: 0.9603 - accuracy: 0.6300 - val_loss: 2.1586 - val_accuracy: 0.3324
Epoch 7/30
26/26 [==============================] - 6s 213ms/step - loss: 0.9598 - accuracy: 0.6391 - val_loss: 1.3156 - val_accuracy: 0.5068
Epoch 8/30
26/26 [==============================] - 6s 216ms/step - loss: 0.9725 - accuracy: 0.6264 - val_loss: 1.9583 - val_accuracy: 0.4251
Epoch 9/30
26/26 [==============================] - 6s 213ms/step - loss: 0.9030 - accuracy: 0.6533 - val_loss: 1.6112 - val_accuracy: 0.4360
Epoch 10/30
26/26 [==============================] - 6s 213ms/step - loss: 0.8695 - accuracy: 0.6721 - val_loss: 1.2253 - val_accuracy: 0.5831
Epoch 11/30
26/26 [==============================] - 6s 211ms/step - loss: 0.8614 - accuracy: 0.6776 - val_loss: 1.0915 - val_accuracy: 0.6403
Epoch 12/30
26/26 [==============================] - 6s 212ms/step - loss: 0.8376 - accuracy: 0.6806 - val_loss: 0.8847 - val_accuracy: 0.6948
Epoch 13/30
26/26 [==============================] - 6s 211ms/step - loss: 0.8417 - accuracy: 0.6866 - val_loss: 1.6050 - val_accuracy: 0.5858
Epoch 14/30
26/26 [==============================] - 6s 213ms/step - loss: 0.8175 - accuracy: 0.6942 - val_loss: 1.4623 - val_accuracy: 0.6076
Epoch 15/30
26/26 [==============================] - 6s 213ms/step - loss: 0.8092 - accuracy: 0.6915 - val_loss: 4.5990 - val_accuracy: 0.5668
Epoch 16/30
26/26 [==============================] - 6s 213ms/step - loss: 0.7692 - accuracy: 0.7060 - val_loss: 1.5812 - val_accuracy: 0.5450
Epoch 17/30
26/26 [==============================] - 6s 212ms/step - loss: 0.7705 - accuracy: 0.7048 - val_loss: 1.1033 - val_accuracy: 0.6485
Epoch 18/30
26/26 [==============================] - 6s 211ms/step - loss: 0.7548 - accuracy: 0.7103 - val_loss: 1.2926 - val_accuracy: 0.6049
Epoch 19/30
26/26 [==============================] - 6s 211ms/step - loss: 0.7411 - accuracy: 0.7154 - val_loss: 0.9187 - val_accuracy: 0.7166
Epoch 20/30
26/26 [==============================] - 6s 213ms/step - loss: 0.7481 - accuracy: 0.7151 - val_loss: 1.5360 - val_accuracy: 0.5886
Epoch 21/30
26/26 [==============================] - 6s 211ms/step - loss: 0.7353 - accuracy: 0.7257 - val_loss: 0.8701 - val_accuracy: 0.7057
Epoch 22/30
26/26 [==============================] - 6s 212ms/step - loss: 0.7363 - accuracy: 0.7296 - val_loss: 1.3043 - val_accuracy: 0.6621
Epoch 23/30
26/26 [==============================] - 6s 212ms/step - loss: 0.7109 - accuracy: 0.7321 - val_loss: 2.6653 - val_accuracy: 0.6785
Epoch 24/30
26/26 [==============================] - 6s 211ms/step - loss: 0.7008 - accuracy: 0.7411 - val_loss: 0.6036 - val_accuracy: 0.7520
Epoch 25/30
26/26 [==============================] - 6s 211ms/step - loss: 0.6908 - accuracy: 0.7442 - val_loss: 0.9055 - val_accuracy: 0.7330
Epoch 26/30
26/26 [==============================] - 6s 211ms/step - loss: 0.6726 - accuracy: 0.7484 - val_loss: 0.9800 - val_accuracy: 0.7411
Epoch 27/30
26/26 [==============================] - 6s 215ms/step - loss: 0.6654 - accuracy: 0.7502 - val_loss: 1.0846 - val_accuracy: 0.6921
Epoch 28/30
26/26 [==============================] - 6s 212ms/step - loss: 0.6531 - accuracy: 0.7566 - val_loss: 1.0444 - val_accuracy: 0.7275
Epoch 29/30
26/26 [==============================] - 6s 216ms/step - loss: 0.6620 - accuracy: 0.7502 - val_loss: 1.2927 - val_accuracy: 0.6649
Epoch 30/30
26/26 [==============================] - 6s 214ms/step - loss: 0.6618 - accuracy: 0.7587 - val_loss: 1.0450 - val_accuracy: 0.6894
3/3 [==============================] - 0s 67ms/step - loss: 1.0450 - accuracy: 0.6894
Top-1 accuracy on the validation set: 68.94%.
CPU times: user 3min 36s, sys: 9.5 s, total: 3min 45s
Wall time: 5min 5s
Freeze all layers except the last Batch Normalization layer
For fine-tuning, we train only two layers:
- The last Batch Normalization (Ioffe et al.) layer.
- The classification layer.
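We unfreeze the last Batch Normalization layer, which sits just before the global average pooling layer, to compensate for the change in activation statistics caused by the higher input resolution. As shown in the paper, unfreezing the last Batch Normalization layer is sufficient.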
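One reason a single Batch Normalization layer can absorb the resolution change is that Keras `BatchNormalization` layers keep moving statistics. Per the Keras documentation, these are updated during training as `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)`, with `momentum=0.99` by default. The trainable scale and offset of the layer are learned as well.

The plain-Python sketch below uses a hypothetical helper and rests on two assumptions:

- `post_bn` uses the default momentum.
- The activations at the new resolution have a constant batch mean.

Under these assumptions, it shows how far the moving mean travels toward a new value over the 10 fine-tuning epochs of 26 steps each seen in the logs.

```python
def simulate_moving_mean(old_mean, new_mean, steps, momentum=0.99):
    """Hypothetical helper: apply the Keras-style BN update rule
    moving = moving * momentum + batch_mean * (1 - momentum)
    for `steps` batches whose mean is always `new_mean`."""
    moving = old_mean
    for _ in range(steps):
        moving = moving * momentum + new_mean * (1 - momentum)
    return moving


steps = 10 * 26  # fine-tuning epochs x training steps per epoch (from the logs)
moving = simulate_moving_mean(old_mean=0.0, new_mean=1.0, steps=steps)
closed_form = 1 - 0.99**steps  # same result without the loop
print(f"{steps} updates close {moving:.1%} of the gap (closed form: {closed_form:.1%})")
```

Under these assumptions, roughly 93% of the gap between the old and new statistics is closed during fine-tuning.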
For a comprehensive guide on fine-tuning models in Keras, refer to this tutorial.
# layers[2] is the ResNet50V2 base (the model's layers are Input, Rescaling, ResNet50V2, Dense).
for layer in smaller_res_model.layers[2].layers:
    layer.trainable = False
smaller_res_model.layers[2].get_layer("post_bn").trainable = True
epochs = 10
# Use a lower learning rate during fine-tuning.
bigger_res_model = train_and_evaluate(
    smaller_res_model,
    finetune_train_dataset,
    finetune_val_dataset,
    epochs,
    learning_rate=1e-4,
)
Epoch 1/10
26/26 [==============================] - 9s 201ms/step - loss: 0.7912 - accuracy: 0.7856 - val_loss: 0.6808 - val_accuracy: 0.7575
Epoch 2/10
26/26 [==============================] - 3s 115ms/step - loss: 0.7732 - accuracy: 0.7938 - val_loss: 0.7028 - val_accuracy: 0.7684
Epoch 3/10
26/26 [==============================] - 3s 115ms/step - loss: 0.7658 - accuracy: 0.7923 - val_loss: 0.7136 - val_accuracy: 0.7629
Epoch 4/10
26/26 [==============================] - 3s 115ms/step - loss: 0.7536 - accuracy: 0.7872 - val_loss: 0.7161 - val_accuracy: 0.7684
Epoch 5/10
26/26 [==============================] - 3s 115ms/step - loss: 0.7346 - accuracy: 0.7947 - val_loss: 0.7154 - val_accuracy: 0.7711
Epoch 6/10
26/26 [==============================] - 3s 115ms/step - loss: 0.7183 - accuracy: 0.7990 - val_loss: 0.7139 - val_accuracy: 0.7684
Epoch 7/10
26/26 [==============================] - 3s 116ms/step - loss: 0.7059 - accuracy: 0.7962 - val_loss: 0.7071 - val_accuracy: 0.7738
Epoch 8/10
26/26 [==============================] - 3s 115ms/step - loss: 0.6959 - accuracy: 0.7923 - val_loss: 0.7002 - val_accuracy: 0.7738
Epoch 9/10
26/26 [==============================] - 3s 116ms/step - loss: 0.6871 - accuracy: 0.8011 - val_loss: 0.6967 - val_accuracy: 0.7711
Epoch 10/10
26/26 [==============================] - 3s 116ms/step - loss: 0.6761 - accuracy: 0.8044 - val_loss: 0.6887 - val_accuracy: 0.7738
3/3 [==============================] - 0s 95ms/step - loss: 0.6887 - accuracy: 0.7738
Top-1 accuracy on the validation set: 77.38%.
(Translator's note: results from our own run)
Epoch 1/10
26/26 [==============================] - 15s 405ms/step - loss: 0.7293 - accuracy: 0.7663 - val_loss: 0.7613 - val_accuracy: 0.7466
Epoch 2/10
26/26 [==============================] - 9s 341ms/step - loss: 0.6958 - accuracy: 0.7654 - val_loss: 0.7010 - val_accuracy: 0.7657
Epoch 3/10
26/26 [==============================] - 9s 341ms/step - loss: 0.6887 - accuracy: 0.7729 - val_loss: 0.6589 - val_accuracy: 0.7847
Epoch 4/10
26/26 [==============================] - 9s 340ms/step - loss: 0.6240 - accuracy: 0.7781 - val_loss: 0.6232 - val_accuracy: 0.7847
Epoch 5/10
26/26 [==============================] - 9s 340ms/step - loss: 0.6314 - accuracy: 0.7829 - val_loss: 0.6033 - val_accuracy: 0.7929
Epoch 6/10
26/26 [==============================] - 9s 344ms/step - loss: 0.6280 - accuracy: 0.7944 - val_loss: 0.5856 - val_accuracy: 0.7956
Epoch 7/10
26/26 [==============================] - 9s 339ms/step - loss: 0.6094 - accuracy: 0.7911 - val_loss: 0.5780 - val_accuracy: 0.7929
Epoch 8/10
26/26 [==============================] - 9s 338ms/step - loss: 0.6000 - accuracy: 0.7884 - val_loss: 0.5683 - val_accuracy: 0.7956
Epoch 9/10
26/26 [==============================] - 9s 339ms/step - loss: 0.6183 - accuracy: 0.7953 - val_loss: 0.5645 - val_accuracy: 0.7847
Epoch 10/10
26/26 [==============================] - 9s 337ms/step - loss: 0.5987 - accuracy: 0.7978 - val_loss: 0.5567 - val_accuracy: 0.7875
3/3 [==============================] - 1s 266ms/step - loss: 0.5567 - accuracy: 0.7875
Top-1 accuracy on the validation set: 78.75%.
Experiment 2: Train a model on 224×224 resolution from scratch
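Now, we train another model from scratch on the larger-resolution dataset. Recall that this dataset uses the same augmentation transforms as the smaller-resolution dataset, which differ from those of the fine-tuning dataset.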
epochs = 30
vanilla_bigger_res_model = get_training_model()
vanilla_bigger_res_model = train_and_evaluate(
    vanilla_bigger_res_model, vanilla_train_dataset, vanilla_val_dataset, epochs
)
Epoch 1/30
26/26 [==============================] - 15s 389ms/step - loss: 1.5339 - accuracy: 0.4569 - val_loss: 177.5233 - val_accuracy: 0.1907
Epoch 2/30
26/26 [==============================] - 8s 314ms/step - loss: 1.1472 - accuracy: 0.5483 - val_loss: 17.5804 - val_accuracy: 0.1907
Epoch 3/30
26/26 [==============================] - 8s 315ms/step - loss: 1.0708 - accuracy: 0.5792 - val_loss: 2.2719 - val_accuracy: 0.2480
Epoch 4/30
26/26 [==============================] - 8s 315ms/step - loss: 1.0225 - accuracy: 0.6170 - val_loss: 2.1274 - val_accuracy: 0.2398
Epoch 5/30
26/26 [==============================] - 8s 316ms/step - loss: 1.0001 - accuracy: 0.6206 - val_loss: 2.0375 - val_accuracy: 0.2834
Epoch 6/30
26/26 [==============================] - 8s 315ms/step - loss: 0.9602 - accuracy: 0.6355 - val_loss: 1.4412 - val_accuracy: 0.3978
Epoch 7/30
26/26 [==============================] - 8s 316ms/step - loss: 0.9418 - accuracy: 0.6461 - val_loss: 1.5257 - val_accuracy: 0.4305
Epoch 8/30
26/26 [==============================] - 8s 316ms/step - loss: 0.8911 - accuracy: 0.6649 - val_loss: 1.1530 - val_accuracy: 0.5858
Epoch 9/30
26/26 [==============================] - 8s 316ms/step - loss: 0.8834 - accuracy: 0.6694 - val_loss: 1.2026 - val_accuracy: 0.5531
Epoch 10/30
26/26 [==============================] - 8s 316ms/step - loss: 0.8752 - accuracy: 0.6724 - val_loss: 1.4917 - val_accuracy: 0.5695
Epoch 11/30
26/26 [==============================] - 8s 316ms/step - loss: 0.8690 - accuracy: 0.6594 - val_loss: 1.4115 - val_accuracy: 0.6022
Epoch 12/30
26/26 [==============================] - 8s 314ms/step - loss: 0.8586 - accuracy: 0.6761 - val_loss: 1.0692 - val_accuracy: 0.6349
Epoch 13/30
26/26 [==============================] - 8s 315ms/step - loss: 0.8120 - accuracy: 0.6894 - val_loss: 1.5233 - val_accuracy: 0.6567
Epoch 14/30
26/26 [==============================] - 8s 316ms/step - loss: 0.8275 - accuracy: 0.6857 - val_loss: 1.9079 - val_accuracy: 0.5804
Epoch 15/30
26/26 [==============================] - 8s 316ms/step - loss: 0.7624 - accuracy: 0.7127 - val_loss: 0.9543 - val_accuracy: 0.6540
Epoch 16/30
26/26 [==============================] - 8s 315ms/step - loss: 0.7595 - accuracy: 0.7266 - val_loss: 4.5757 - val_accuracy: 0.4877
Epoch 17/30
26/26 [==============================] - 8s 315ms/step - loss: 0.7577 - accuracy: 0.7154 - val_loss: 1.8411 - val_accuracy: 0.5749
Epoch 18/30
26/26 [==============================] - 8s 316ms/step - loss: 0.7596 - accuracy: 0.7163 - val_loss: 1.0660 - val_accuracy: 0.6703
Epoch 19/30
26/26 [==============================] - 8s 315ms/step - loss: 0.7492 - accuracy: 0.7160 - val_loss: 1.2462 - val_accuracy: 0.6485
Epoch 20/30
26/26 [==============================] - 8s 315ms/step - loss: 0.7269 - accuracy: 0.7330 - val_loss: 5.8287 - val_accuracy: 0.3379
Epoch 21/30
26/26 [==============================] - 8s 315ms/step - loss: 0.7193 - accuracy: 0.7275 - val_loss: 4.7058 - val_accuracy: 0.6049
Epoch 22/30
26/26 [==============================] - 8s 316ms/step - loss: 0.7251 - accuracy: 0.7318 - val_loss: 1.5608 - val_accuracy: 0.6485
Epoch 23/30
26/26 [==============================] - 8s 314ms/step - loss: 0.6888 - accuracy: 0.7466 - val_loss: 1.7914 - val_accuracy: 0.6240
Epoch 24/30
26/26 [==============================] - 8s 314ms/step - loss: 0.7051 - accuracy: 0.7339 - val_loss: 2.0918 - val_accuracy: 0.6158
Epoch 25/30
26/26 [==============================] - 8s 315ms/step - loss: 0.6920 - accuracy: 0.7454 - val_loss: 0.7284 - val_accuracy: 0.7575
Epoch 26/30
26/26 [==============================] - 8s 316ms/step - loss: 0.6502 - accuracy: 0.7523 - val_loss: 2.5474 - val_accuracy: 0.5313
Epoch 27/30
26/26 [==============================] - 8s 315ms/step - loss: 0.7101 - accuracy: 0.7330 - val_loss: 26.8117 - val_accuracy: 0.3297
Epoch 28/30
26/26 [==============================] - 8s 315ms/step - loss: 0.6632 - accuracy: 0.7548 - val_loss: 20.1011 - val_accuracy: 0.3243
Epoch 29/30
26/26 [==============================] - 8s 315ms/step - loss: 0.6682 - accuracy: 0.7505 - val_loss: 11.5872 - val_accuracy: 0.3297
Epoch 30/30
26/26 [==============================] - 8s 315ms/step - loss: 0.6758 - accuracy: 0.7514 - val_loss: 5.7229 - val_accuracy: 0.4305
3/3 [==============================] - 0s 95ms/step - loss: 5.7229 - accuracy: 0.4305
Top-1 accuracy on the validation set: 43.05%.
(Translator's note: results from our own run)
Epoch 1/30
26/26 [==============================] - 24s 615ms/step - loss: 1.4819 - accuracy: 0.4541 - val_loss: 204.4962 - val_accuracy: 0.2371
Epoch 2/30
26/26 [==============================] - 14s 534ms/step - loss: 1.1513 - accuracy: 0.5537 - val_loss: 8.4281 - val_accuracy: 0.2289
Epoch 3/30
26/26 [==============================] - 14s 534ms/step - loss: 1.0496 - accuracy: 0.5982 - val_loss: 3.8557 - val_accuracy: 0.1853
Epoch 4/30
26/26 [==============================] - 14s 534ms/step - loss: 1.0238 - accuracy: 0.6019 - val_loss: 2.6034 - val_accuracy: 0.2589
Epoch 5/30
26/26 [==============================] - 14s 535ms/step - loss: 0.9981 - accuracy: 0.6243 - val_loss: 2.1512 - val_accuracy: 0.3270
Epoch 6/30
26/26 [==============================] - 14s 535ms/step - loss: 0.9837 - accuracy: 0.6231 - val_loss: 2.5076 - val_accuracy: 0.2098
Epoch 7/30
26/26 [==============================] - 14s 536ms/step - loss: 0.9453 - accuracy: 0.6461 - val_loss: 2.0778 - val_accuracy: 0.3106
Epoch 8/30
26/26 [==============================] - 14s 534ms/step - loss: 0.9254 - accuracy: 0.6421 - val_loss: 1.4707 - val_accuracy: 0.5586
Epoch 9/30
26/26 [==============================] - 14s 535ms/step - loss: 0.8721 - accuracy: 0.6639 - val_loss: 1.3752 - val_accuracy: 0.5777
Epoch 10/30
26/26 [==============================] - 14s 534ms/step - loss: 0.8773 - accuracy: 0.6667 - val_loss: 1.5882 - val_accuracy: 0.5150
Epoch 11/30
26/26 [==============================] - 14s 533ms/step - loss: 0.8495 - accuracy: 0.6794 - val_loss: 1.4919 - val_accuracy: 0.5695
Epoch 12/30
26/26 [==============================] - 14s 534ms/step - loss: 0.8733 - accuracy: 0.6776 - val_loss: 1.3502 - val_accuracy: 0.4905
Epoch 13/30
26/26 [==============================] - 14s 534ms/step - loss: 0.8456 - accuracy: 0.6812 - val_loss: 0.9505 - val_accuracy: 0.6512
Epoch 14/30
26/26 [==============================] - 14s 535ms/step - loss: 0.8344 - accuracy: 0.6906 - val_loss: 3.3944 - val_accuracy: 0.4278
Epoch 15/30
26/26 [==============================] - 14s 535ms/step - loss: 0.7916 - accuracy: 0.7000 - val_loss: 1.5446 - val_accuracy: 0.6621
Epoch 16/30
26/26 [==============================] - 14s 534ms/step - loss: 0.7739 - accuracy: 0.7072 - val_loss: 1.4009 - val_accuracy: 0.6104
Epoch 17/30
26/26 [==============================] - 14s 536ms/step - loss: 0.7393 - accuracy: 0.7218 - val_loss: 1.0081 - val_accuracy: 0.6730
Epoch 18/30
26/26 [==============================] - 14s 536ms/step - loss: 0.7761 - accuracy: 0.6969 - val_loss: 1.0424 - val_accuracy: 0.6703
Epoch 19/30
26/26 [==============================] - 14s 534ms/step - loss: 0.7429 - accuracy: 0.7151 - val_loss: 0.6274 - val_accuracy: 0.7602
Epoch 20/30
26/26 [==============================] - 14s 535ms/step - loss: 0.6917 - accuracy: 0.7348 - val_loss: 1.1674 - val_accuracy: 0.6703
Epoch 21/30
26/26 [==============================] - 14s 535ms/step - loss: 0.7013 - accuracy: 0.7342 - val_loss: 0.7206 - val_accuracy: 0.7629
Epoch 22/30
26/26 [==============================] - 14s 535ms/step - loss: 0.7036 - accuracy: 0.7348 - val_loss: 0.7936 - val_accuracy: 0.7629
Epoch 23/30
26/26 [==============================] - 14s 536ms/step - loss: 0.6914 - accuracy: 0.7357 - val_loss: 1.7459 - val_accuracy: 0.6785
Epoch 24/30
26/26 [==============================] - 14s 535ms/step - loss: 0.6884 - accuracy: 0.7414 - val_loss: 1.3030 - val_accuracy: 0.7003
Epoch 25/30
26/26 [==============================] - 14s 535ms/step - loss: 0.6921 - accuracy: 0.7333 - val_loss: 1.4142 - val_accuracy: 0.6730
Epoch 26/30
26/26 [==============================] - 14s 534ms/step - loss: 0.6665 - accuracy: 0.7514 - val_loss: 0.8042 - val_accuracy: 0.7466
Epoch 27/30
26/26 [==============================] - 14s 535ms/step - loss: 0.6959 - accuracy: 0.7408 - val_loss: 2.1995 - val_accuracy: 0.6376
Epoch 28/30
26/26 [==============================] - 14s 533ms/step - loss: 0.6429 - accuracy: 0.7608 - val_loss: 1.1186 - val_accuracy: 0.6757
Epoch 29/30
26/26 [==============================] - 14s 534ms/step - loss: 0.6530 - accuracy: 0.7523 - val_loss: 1.9714 - val_accuracy: 0.6294
Epoch 30/30
26/26 [==============================] - 14s 535ms/step - loss: 0.6319 - accuracy: 0.7684 - val_loss: 1.1736 - val_accuracy: 0.7030
3/3 [==============================] - 1s 173ms/step - loss: 1.1736 - accuracy: 0.7030
Top-1 accuracy on the validation set: 70.30%.
CPU times: user 7min 48s, sys: 15.6 s, total: 8min 3s
Wall time: 10min 13s
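As the cells above show, FixRes yields better performance. The model trained at 128×128 and fine-tuned at 224×224 reaches 77.38% top-1 validation accuracy, versus 43.05% for the model trained from scratch at 224×224 (78.75% versus 70.30% in the translator's run). Note that validation accuracy fluctuates strongly from epoch to epoch when training from scratch, so a single run is only indicative. Another advantage of FixRes is a shorter total training time and lower GPU memory usage, because most of the training happens at the smaller resolution. FixRes is model-agnostic, so you can use it with any image classification model to potentially boost performance.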
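The training-time advantage can be checked roughly against the author's logs above. The sketch below simply adds up the per-epoch times Keras printed. It is only a rough estimate for that particular hardware, for two reasons:

- Keras rounds the times to whole seconds.
- The first epoch of each run is slower, typically because of one-time setup such as graph tracing.

The sketch also computes how many more pixels a 224×224 image has than a 128×128 one.

```python
# Per-epoch wall-clock seconds as printed (rounded) in the author's logs above.
initial_128 = [14] + [3] * 29  # Experiment 1: training at 128x128 (30 epochs)
finetune_224 = [9] + [3] * 9  # Experiment 1: fine-tuning at 224x224 (10 epochs)
vanilla_224 = [15] + [8] * 29  # Experiment 2: training at 224x224 (30 epochs)

fixres_total = sum(initial_128) + sum(finetune_224)
vanilla_total = sum(vanilla_224)
pixel_ratio = (224 / 128) ** 2  # pixels per image, 224x224 vs 128x128

print(f"FixRes total:  {fixres_total} s")
print(f"Vanilla total: {vanilla_total} s")
print(f"Time ratio:    {fixres_total / vanilla_total:.2f}")
print(f"Pixel ratio:   {pixel_ratio}")
```

This gives about 137 s for the FixRes pipeline versus 247 s for training from scratch at 224×224. A 224×224 image has about three times as many pixels as a 128×128 image.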
More results, gathered by running the same code with different random seeds, can be found here.