Keras 2 : examples : AdaMatch による半教師あり学習とドメイン適応 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 11/07/2021 (keras 2.6.0)
* 本ページは、Keras の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
- Code examples : Computer Vision : Semi-supervision and domain adaptation with AdaMatch (Author: Sayak Paul)
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- テレワーク & オンライン授業を支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
Keras 2 : examples : AdaMatch による半教師あり学習とドメイン適応
イントロダクション
このサンプルでは、AdaMatch: A Unified Approach to Semi-Supervised Learning and Domain Adaptation (AdaMatch: 半教師あり学習とドメイン適応への統一的なアプローチ) by Berthelot et al で提案された、AdaMatch アルゴリズムを実装します。それは教師なしドメイン適応の新しい最先端技術です (2021/6 の時点で)。AdaMatch は特に興味深いです、何故ならばそれは一つのフレームワークの下で半教師あり学習 (SSL) と教師なしドメイン適応 (UDA) を統合するからです。従って半教師ありドメイン適応 (SSDA) を実行する方法を提供します。
このサンプルは TensorFlow 2.5 またはそれ以上、及び TensorFlow Models を必要とします、これは次のコマンドでインストールできます :
!pip install -q tf-models-official
先に進む前に、このサンプルの基礎となる幾つかの予備的なコンセプトをレビューしましょう。
準備
半教師あり学習 (SSL) では、より大きなラベル付けられていないデータセット上のモデルを訓練するために少量のラベル付きデータを使用します。コンピュータビジョンのためのポピュラーな半教師あり学習法は FixMatch, MixMatch, Noisy Student Training 等を含みます。標準的な SSL ワークフローがどのようなものか考えを得るために このサンプル を参考にできます。
教師なしドメイン適応 では、ソースのラベル付けられたデータセットとターゲットのラベル付けされていないデータセットへのアクセスを持ちます。そしてタスクはターゲット・データセットに上手く一般化できるモデルを学習することです。ソースとターゲット・データセットは分布の観点から変化します。次の図はこの考えの図示しています。現在のサンプルでは、ソース・データセットとして MNIST を、ターゲット・データセットとして SVHN を使用しています、これは家のナンバーの画像から成ります。両者のデータセットはテクスチャ, 視点, 外観 等の観点から様々な変化する要因を持ちます : それらのドメイン、分布は互いに異なります。
深層学習のポピュラーなドメイン適応アルゴリズムは Deep CORAL, Moment Matching 等を含みます。
セットアップ
import tensorflow as tf
tf.random.set_seed(42)
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import regularizers
from official.vision.image_classification.augment import RandAugment
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
データの準備
# MNIST
(
(mnist_x_train, mnist_y_train),
(mnist_x_test, mnist_y_test),
) = keras.datasets.mnist.load_data()
# Add a channel dimension
mnist_x_train = tf.expand_dims(mnist_x_train, -1)
mnist_x_test = tf.expand_dims(mnist_x_test, -1)
# Convert the labels to one-hot encoded vectors
mnist_y_train = tf.one_hot(mnist_y_train, 10).numpy()
# SVHN
svhn_train, svhn_test = tfds.load(
"svhn_cropped", split=["train", "test"], as_supervised=True
)
定数とハイパーパラメータを定義する
RESIZE_TO = 32
SOURCE_BATCH_SIZE = 64
TARGET_BATCH_SIZE = 3 * SOURCE_BATCH_SIZE # Reference: Section 3.2
EPOCHS = 10
STEPS_PER_EPOCH = len(mnist_x_train) // SOURCE_BATCH_SIZE
TOTAL_STEPS = EPOCHS * STEPS_PER_EPOCH
AUTO = tf.data.AUTOTUNE
LEARNING_RATE = 0.03
WEIGHT_DECAY = 0.0005
INIT = "he_normal"
DEPTH = 28
WIDTH_MULT = 2
データ増強ユティリティ
SSL アルゴリズムの標準的な要素は、 学習モデルに予測に一貫性を持たせるために同じ画像の弱くそして強く増強されたバージョンを供給することです。 強い増強については、RandAugment が標準的な選択です。弱い増強については、水平反転とランダム・クロッピングを使用します。
# Initialize `RandAugment` object with 2 layers of
# augmentation transforms and strength of 5.
augmenter = RandAugment(num_layers=2, magnitude=5)
def weak_augment(image, source=True):
if image.dtype != tf.float32:
image = tf.cast(image, tf.float32)
# MNIST images are grayscale, this is why we first convert them to
# RGB images.
if source:
image = tf.image.resize_with_pad(image, RESIZE_TO, RESIZE_TO)
image = tf.tile(image, [1, 1, 3])
image = tf.image.random_flip_left_right(image)
image = tf.image.random_crop(image, (RESIZE_TO, RESIZE_TO, 3))
return image
def strong_augment(image, source=True):
if image.dtype != tf.float32:
image = tf.cast(image, tf.float32)
if source:
image = tf.image.resize_with_pad(image, RESIZE_TO, RESIZE_TO)
image = tf.tile(image, [1, 1, 3])
image = augmenter.distort(image)
return image
データ・ローディング・ユティリティ
def create_individual_ds(ds, aug_func, source=True):
if source:
batch_size = SOURCE_BATCH_SIZE
else:
# During training 3x more target unlabeled samples are shown
# to the model in AdaMatch (Section 3.2 of the paper).
batch_size = TARGET_BATCH_SIZE
ds = ds.shuffle(batch_size * 10, seed=42)
if source:
ds = ds.map(lambda x, y: (aug_func(x), y), num_parallel_calls=AUTO)
else:
ds = ds.map(lambda x, y: (aug_func(x, False), y), num_parallel_calls=AUTO)
ds = ds.batch(batch_size).prefetch(AUTO)
return ds
_w と _s のサフィックスはそれぞれ弱いと強いを表します。
source_ds = tf.data.Dataset.from_tensor_slices((mnist_x_train, mnist_y_train))
source_ds_w = create_individual_ds(source_ds, weak_augment)
source_ds_s = create_individual_ds(source_ds, strong_augment)
final_source_ds = tf.data.Dataset.zip((source_ds_w, source_ds_s))
target_ds_w = create_individual_ds(svhn_train, weak_augment, source=False)
target_ds_s = create_individual_ds(svhn_train, strong_augment, source=False)
final_target_ds = tf.data.Dataset.zip((target_ds_w, target_ds_s))
ここにシングル画像バッチがどのように見えるかがあります :
損失計算ユティリティ
def compute_loss_source(source_labels, logits_source_w, logits_source_s):
loss_func = keras.losses.CategoricalCrossentropy(from_logits=True)
# First compute the losses between original source labels and
# predictions made on the weakly and strongly augmented versions
# of the same images.
w_loss = loss_func(source_labels, logits_source_w)
s_loss = loss_func(source_labels, logits_source_s)
return w_loss + s_loss
def compute_loss_target(target_pseudo_labels_w, logits_target_s, mask):
loss_func = keras.losses.CategoricalCrossentropy(from_logits=True, reduction="none")
target_pseudo_labels_w = tf.stop_gradient(target_pseudo_labels_w)
# For calculating loss for the target samples, we treat the pseudo labels
# as the ground-truth. These are not considered during backpropagation
# which is a standard SSL practice.
target_loss = loss_func(target_pseudo_labels_w, logits_target_s)
# More on `mask` later.
mask = tf.cast(mask, target_loss.dtype)
target_loss *= mask
return tf.reduce_mean(target_loss, 0)
AdaMatch 訓練のためのサブクラス化モデル
下図は AdaMatch の全体的なワークフローを表しています (元の論文 からの引用) :
ここにワークフローの簡潔なステップ毎の分解があります :
- 最初にソースとターゲットのデータセットから画像の弱いそして強い増強のペアを取得します。
- 2 つの連結されたコピーを準備します : i. 両者のペアが連結されたもの。ii. ソースデータ画像ペアだけが連結されたもの。
- モデルを通して 2 つの forward パスを実行します : i. 最初の forwrad パスは 2.i から得られた連結されたコピーを使用します。この forward パスでは、Batch Normalization 統計が更新されます。ii. 2 番目の forward パスでは、2.ii で得られた連結されたコピーだけを使用します。Batch Normalization 層は推論モードで実行されます。
- 両方の forward パスに対してそれぞれのロジットが計算されます。
- ロジットは論文で紹介されている、変換のシリーズを通り抜けます (これを短く説明します)。
- 損失を計算して基礎となるモデルの勾配を更新します。
class AdaMatch(keras.Model):
def __init__(self, model, total_steps, tau=0.9):
super(AdaMatch, self).__init__()
self.model = model
self.tau = tau # Denotes the confidence threshold
self.loss_tracker = tf.keras.metrics.Mean(name="loss")
self.total_steps = total_steps
self.current_step = tf.Variable(0, dtype="int64")
@property
def metrics(self):
return [self.loss_tracker]
# This is a warmup schedule to update the weight of the
# loss contributed by the target unlabeled samples. More
# on this in the text.
def compute_mu(self):
pi = tf.constant(np.pi, dtype="float32")
step = tf.cast(self.current_step, dtype="float32")
return 0.5 - tf.cos(tf.math.minimum(pi, (2 * pi * step) / self.total_steps)) / 2
def train_step(self, data):
## Unpack and organize the data ##
source_ds, target_ds = data
(source_w, source_labels), (source_s, _) = source_ds
(
(target_w, _),
(target_s, _),
) = target_ds # Notice that we are NOT using any labels here.
combined_images = tf.concat([source_w, source_s, target_w, target_s], 0)
combined_source = tf.concat([source_w, source_s], 0)
total_source = tf.shape(combined_source)[0]
total_target = tf.shape(tf.concat([target_w, target_s], 0))[0]
with tf.GradientTape() as tape:
## Forward passes ##
combined_logits = self.model(combined_images, training=True)
z_d_prime_source = self.model(
combined_source, training=False
) # No BatchNorm update.
z_prime_source = combined_logits[:total_source]
## 1. Random logit interpolation for the source images ##
lambd = tf.random.uniform((total_source, 10), 0, 1)
final_source_logits = (lambd * z_prime_source) + (
(1 - lambd) * z_d_prime_source
)
## 2. Distribution alignment (only consider weakly augmented images) ##
# Compute softmax for logits of the WEAKLY augmented SOURCE images.
y_hat_source_w = tf.nn.softmax(final_source_logits[: tf.shape(source_w)[0]])
# Extract logits for the WEAKLY augmented TARGET images and compute softmax.
logits_target = combined_logits[total_source:]
logits_target_w = logits_target[: tf.shape(target_w)[0]]
y_hat_target_w = tf.nn.softmax(logits_target_w)
# Align the target label distribution to that of the source.
expectation_ratio = tf.reduce_mean(y_hat_source_w) / tf.reduce_mean(
y_hat_target_w
)
y_tilde_target_w = tf.math.l2_normalize(
y_hat_target_w * expectation_ratio, 1
)
## 3. Relative confidence thresholding ##
row_wise_max = tf.reduce_max(y_hat_source_w, axis=-1)
final_sum = tf.reduce_mean(row_wise_max, 0)
c_tau = self.tau * final_sum
mask = tf.reduce_max(y_tilde_target_w, axis=-1) >= c_tau
## Compute losses (pay attention to the indexing) ##
source_loss = compute_loss_source(
source_labels,
final_source_logits[: tf.shape(source_w)[0]],
final_source_logits[tf.shape(source_w)[0] :],
)
target_loss = compute_loss_target(
y_tilde_target_w, logits_target[tf.shape(target_w)[0] :], mask
)
t = self.compute_mu() # Compute weight for the target loss
total_loss = source_loss + (t * target_loss)
self.current_step.assign_add(
1
) # Update current training step for the scheduler
gradients = tape.gradient(total_loss, self.model.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
self.loss_tracker.update_state(total_loss)
return {"loss": self.loss_tracker.result()}
著者は論文で 3 つの改良を紹介しています :
- AdaMatch では、2 つの forward パスを実行してそれらの 1 つだけが Batch Normalization 統計の更新について責任を負います。これはターゲットデータセットにおける分布シフトを説明するために行なわれます。他方の forward パスでは、ソースサンプルだけを使用して、Batch Normalization 層は推論モードで実行されます。これら 2 つのパスからのソースサンプル (弱いそして強い増強バージョン) のためのロジットは Batch Normalization がどのように実行されるかにより互いに少し異なります。ソースサンプルのための最終的なロジットはこれら 2 つの異なるロジットペアの間の線形補間により計算されます。これは一貫性正則化の形式を誘導します。このステップは ランダム・ロジット補間 と呼ばれています。
- 分布アラインメント (= Distribution alignment) がソースとターゲット・ラベル分布を揃える (= align) ために使用されます。これは基礎的なモデルがドメイン不変な表現を学習するのに更に役立ちます。教師なしドメイン適応の場合、ターゲットデータセットの任意のラベルへのアクセスを持ちません。これが疑似ラベルが基礎的なモデルから生成される理由です。
- 基礎となるモデルはターゲットサンプルのために疑似ラベルを生成します。モデルが不完全な予測を作成することは可能性は高いです。それらは訓練が進むにつれて逆伝播し、全体のパフォーマンスに悪影響を与える可能性があります。それを補うため、閾値に基づいて高い信頼度の予測をフィルタリングします (そのため compute_loss_target() 内でマスクを使用しています)。AdaMatch では、この閾値は相対的に調整されますので、それが relative confidence thresholding (相対的信頼度閾値) と呼ばれる理由です。
これらの方法の詳細とそれらの各々がどのように寄与するかを知るには、論文 を参照してください。
About compute_mu():
AdaMatch では固定スカラー量を使用するのではなく、変換するスカラーが使用されます。それはターゲットサンプルにより寄与される損失の重みを表します。視覚的には、重みスケジューラは次のようなものです :
Wide-ResNet-28-2 のインスタンス化
著者はこのサンプルで使用するデータセット・ペアのために WideResNet-28-2 を使用しています。以下のコードの殆どは このスクリプト から参照されています。次のモデルはその内部にピクセル値を [0, 1] にスケールするスケーリング層を持つことに注意してください。
def wide_basic(x, n_input_plane, n_output_plane, stride):
conv_params = [[3, 3, stride, "same"], [3, 3, (1, 1), "same"]]
n_bottleneck_plane = n_output_plane
# Residual block
for i, v in enumerate(conv_params):
if i == 0:
if n_input_plane != n_output_plane:
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
convs = x
else:
convs = layers.BatchNormalization()(x)
convs = layers.Activation("relu")(convs)
convs = layers.Conv2D(
n_bottleneck_plane,
(v[0], v[1]),
strides=v[2],
padding=v[3],
kernel_initializer=INIT,
kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
use_bias=False,
)(convs)
else:
convs = layers.BatchNormalization()(convs)
convs = layers.Activation("relu")(convs)
convs = layers.Conv2D(
n_bottleneck_plane,
(v[0], v[1]),
strides=v[2],
padding=v[3],
kernel_initializer=INIT,
kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
use_bias=False,
)(convs)
# Shortcut connection: identity function or 1x1
# convolutional
# (depends on difference between input & output shape - this
# corresponds to whether we are using the first block in
# each
# group; see `block_series()`).
if n_input_plane != n_output_plane:
shortcut = layers.Conv2D(
n_output_plane,
(1, 1),
strides=stride,
padding="same",
kernel_initializer=INIT,
kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
use_bias=False,
)(x)
else:
shortcut = x
return layers.Add()([convs, shortcut])
# Stacking residual units on the same stage
def block_series(x, n_input_plane, n_output_plane, count, stride):
x = wide_basic(x, n_input_plane, n_output_plane, stride)
for i in range(2, int(count + 1)):
x = wide_basic(x, n_output_plane, n_output_plane, stride=1)
return x
def get_network(image_size=32, num_classes=10):
n = (DEPTH - 4) / 6
n_stages = [16, 16 * WIDTH_MULT, 32 * WIDTH_MULT, 64 * WIDTH_MULT]
inputs = keras.Input(shape=(image_size, image_size, 3))
x = layers.Rescaling(scale=1.0 / 255)(inputs)
conv1 = layers.Conv2D(
n_stages[0],
(3, 3),
strides=1,
padding="same",
kernel_initializer=INIT,
kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
use_bias=False,
)(x)
## Add wide residual blocks ##
conv2 = block_series(
conv1,
n_input_plane=n_stages[0],
n_output_plane=n_stages[1],
count=n,
stride=(1, 1),
) # Stage 1
conv3 = block_series(
conv2,
n_input_plane=n_stages[1],
n_output_plane=n_stages[2],
count=n,
stride=(2, 2),
) # Stage 2
conv4 = block_series(
conv3,
n_input_plane=n_stages[2],
n_output_plane=n_stages[3],
count=n,
stride=(2, 2),
) # Stage 3
batch_norm = layers.BatchNormalization()(conv4)
relu = layers.Activation("relu")(batch_norm)
# Classifier
trunk_outputs = layers.GlobalAveragePooling2D()(relu)
outputs = layers.Dense(
num_classes, kernel_regularizer=regularizers.l2(WEIGHT_DECAY)
)(trunk_outputs)
return keras.Model(inputs, outputs)
今ではそのようにして Wide ResNet をインスタンス化できます。ここで Wide ResNet を使用する目的は実装を元のものに出来る限り近付けるためであることに注意してください。
wrn_model = get_network()
print(f"Model has {wrn_model.count_params()/1e6} Million parameters.")
Model has 1.471226 Million parameters.
AdaMatch モデルをインスタンス化してそれをコンパイルする
reduce_lr = keras.optimizers.schedules.CosineDecay(LEARNING_RATE, TOTAL_STEPS, 0.25)
optimizer = keras.optimizers.Adam(reduce_lr)
adamatch_trainer = AdaMatch(model=wrn_model, total_steps=TOTAL_STEPS)
adamatch_trainer.compile(optimizer=optimizer)
モデル訓練
total_ds = tf.data.Dataset.zip((final_source_ds, final_target_ds))
adamatch_trainer.fit(total_ds, epochs=EPOCHS)
Epoch 1/10 382/382 [==============================] - 53s 96ms/step - loss: 117866954752.0000 Epoch 2/10 382/382 [==============================] - 36s 95ms/step - loss: 2.6231 Epoch 3/10 382/382 [==============================] - 36s 94ms/step - loss: 4.1699 Epoch 4/10 382/382 [==============================] - 36s 95ms/step - loss: 8.2748 Epoch 5/10 382/382 [==============================] - 36s 95ms/step - loss: 28.8679 Epoch 6/10 382/382 [==============================] - 36s 94ms/step - loss: 14.7112 Epoch 7/10 382/382 [==============================] - 36s 94ms/step - loss: 7.8206 Epoch 8/10 382/382 [==============================] - 36s 94ms/step - loss: 18.1182 Epoch 9/10 382/382 [==============================] - 36s 94ms/step - loss: 22.4258 Epoch 10/10 382/382 [==============================] - 36s 95ms/step - loss: 22.1107 <tensorflow.python.keras.callbacks.History at 0x7f9bc4990b50>
ターゲットとソース・テストセット上で評価
# Compile the AdaMatch model to yield accuracy.
adamatch_trained_model = adamatch_trainer.model
adamatch_trained_model.compile(metrics=keras.metrics.SparseCategoricalAccuracy())
# Score on the target test set.
svhn_test = svhn_test.batch(TARGET_BATCH_SIZE).prefetch(AUTO)
_, accuracy = adamatch_trained_model.evaluate(svhn_test)
print(f"Accuracy on target test set: {accuracy * 100:.2f}%")
136/136 [==============================] - 2s 10ms/step - loss: 572.9810 - sparse_categorical_accuracy: 0.1960 Accuracy on target test set: 19.11%
より訓練すれば、このスコアは向上します。この同じネットワークが標準的な分類目的 (関数) で訓練される場合、それは 7.20% の精度を生成し、これは AdaMatch で得たものよりも遥かに低いです。ハイパーパラメータと他の実験の詳細について学習するために このノートブック を確認できます。
# Utility function for preprocessing the source test set.
def prepare_test_ds_source(image, label):
image = tf.image.resize_with_pad(image, RESIZE_TO, RESIZE_TO)
image = tf.tile(image, [1, 1, 3])
return image, label
source_test_ds = tf.data.Dataset.from_tensor_slices((mnist_x_test, mnist_y_test))
source_test_ds = (
source_test_ds.map(prepare_test_ds_source, num_parallel_calls=AUTO)
.batch(TARGET_BATCH_SIZE)
.prefetch(AUTO)
)
# Evaluation on the source test set.
_, accuracy = adamatch_trained_model.evaluate(source_test_ds)
print(f"Accuracy on source test set: {accuracy * 100:.2f}%")
53/53 [==============================] - 1s 10ms/step - loss: 572.9810 - sparse_categorical_accuracy: 0.6532 Accuracy on source test set: 65.32%
これらの モデル重み を使用することにより結果を再現できます。
以上