Keras 2 : examples : 知識蒸留 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 11/25/2021 (keras 2.7.0)
* 本ページは、Keras の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
- Code examples : Computer Vision : Knowledge Distillation (Author: Kenneth Borup)
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- テレワーク & オンライン授業を支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- E-Mail:sales-info@classcat.com ; WebSite: www.classcat.com ; Facebook
Keras 2 : examples : 知識蒸留
Description: 古典的な知識蒸留の実装。
知識蒸留へのイントロダクション
知識蒸留はモデル圧縮のための手続きで、そこでは大きい事前訓練済みの (教師) モデルに一致するように小さい (生徒) モデルが訓練されます。正解ラベルに加えて穏やかにされた (= softened) 教師ロジットに一致することを目的として、損失関数を最小化することにより知識が教師モデルから生徒 (モデル) に転移されます。
ロジットは softmax の “temperature” scaling (温度付きスケーリング) 関数を適用することにより穏やかにされ、確率分布を効果的に滑らかにして教師により学習されたクラス間の関係性を明らかにします。
リファレンス
セットアップ
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
Distiller() クラスの構築
カスタム Distiller() クラスは Model のメソッド train_step, test_step と compile() を override します。distiller を使用するには、以下が必要です :
- 訓練済みの教師モデル
- 訓練する生徒モデル
- 生徒予測と正解の間の差についての生徒損失関数
- soft 生徒予測と soft 教師ラベルの間の差についての (温度と連動する) 蒸留損失関数、
- 生徒と蒸留損失を重み付ける alpha 因子
- 生徒のための optimizer とパフォーマンスを評価するための (オプションの) メトリクス
train_step メソッドでは、教師と生徒の両方の forward パスを実行し、student_loss と distillation_loss をそれぞれ alpha と 1 – alpha で重み付けして損失を計算し、そして backward パスを実行します。Note: 生徒重みだけが更新されますので、生徒重みに対する勾配だけを計算します。
test_step メソッドでは、提供されたデータセットで生徒モデルを評価します。
class Distiller(keras.Model):
def __init__(self, student, teacher):
super(Distiller, self).__init__()
self.teacher = teacher
self.student = student
def compile(
self,
optimizer,
metrics,
student_loss_fn,
distillation_loss_fn,
alpha=0.1,
temperature=3,
):
""" Configure the distiller.
Args:
optimizer: Keras optimizer for the student weights
metrics: Keras metrics for evaluation
student_loss_fn: Loss function of difference between student
predictions and ground-truth
distillation_loss_fn: Loss function of difference between soft
student predictions and soft teacher predictions
alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
temperature: Temperature for softening probability distributions.
Larger temperature gives softer distributions.
"""
super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
self.student_loss_fn = student_loss_fn
self.distillation_loss_fn = distillation_loss_fn
self.alpha = alpha
self.temperature = temperature
def train_step(self, data):
# Unpack data
x, y = data
# Forward pass of teacher
teacher_predictions = self.teacher(x, training=False)
with tf.GradientTape() as tape:
# Forward pass of student
student_predictions = self.student(x, training=True)
# Compute losses
student_loss = self.student_loss_fn(y, student_predictions)
distillation_loss = self.distillation_loss_fn(
tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
tf.nn.softmax(student_predictions / self.temperature, axis=1),
)
loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
# Compute gradients
trainable_vars = self.student.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
# Update the metrics configured in `compile()`.
self.compiled_metrics.update_state(y, student_predictions)
# Return a dict of performance
results = {m.name: m.result() for m in self.metrics}
results.update(
{"student_loss": student_loss, "distillation_loss": distillation_loss}
)
return results
def test_step(self, data):
# Unpack the data
x, y = data
# Compute predictions
y_prediction = self.student(x, training=False)
# Calculate the loss
student_loss = self.student_loss_fn(y, y_prediction)
# Update the metrics.
self.compiled_metrics.update_state(y, y_prediction)
# Return a dict of performance
results = {m.name: m.result() for m in self.metrics}
results.update({"student_loss": student_loss})
return results
生徒と教師モデルの作成
最初に、教師モデルと (それ) より小さい生徒モデルを作成します。両方のモデルは畳み込みニューラルネットワークで Sequential() を使用して作成されますが、どのような Keras モデルでもあり得ます。
# Create the teacher
teacher = keras.Sequential(
[
keras.Input(shape=(28, 28, 1)),
layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
layers.Flatten(),
layers.Dense(10),
],
name="teacher",
)
# Create the student
student = keras.Sequential(
[
keras.Input(shape=(28, 28, 1)),
layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
layers.Flatten(),
layers.Dense(10),
],
name="student",
)
# Clone student for later comparison
student_scratch = keras.models.clone_model(student)
データセットの準備
教師を訓練して教師を蒸留するために使用されるデータセットは MNIST で、そしてこの手続きは例えば CIFAR-10 のような任意の他のデータセットのためにも、適切なモデルの選択をすれば、等値です。生徒と教師の両方は訓練セットで訓練されてテストセットで評価されます。
# Prepare the train and test dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
教師の訓練
知識蒸留では教師が訓練されて固定されていることを仮定しています。そのため、通常の方法で訓練セットで教師モデルを訓練することから始めます。
# Train teacher as usual
teacher.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=5)
teacher.evaluate(x_test, y_test)
Epoch 1/5 1875/1875 [==============================] - 248s 132ms/step - loss: 0.2438 - sparse_categorical_accuracy: 0.9220 Epoch 2/5 1875/1875 [==============================] - 263s 140ms/step - loss: 0.0881 - sparse_categorical_accuracy: 0.9738 Epoch 3/5 1875/1875 [==============================] - 245s 131ms/step - loss: 0.0650 - sparse_categorical_accuracy: 0.9811 Epoch 5/5 363/1875 [====>.........................] - ETA: 3:18 - loss: 0.0555 - sparse_categorical_accuracy: 0.9839
(訳者注: 実験結果)
Epoch 1/5 1875/1875 [==============================] - 16s 4ms/step - loss: 0.1440 - sparse_categorical_accuracy: 0.9564 Epoch 2/5 1875/1875 [==============================] - 7s 4ms/step - loss: 0.0898 - sparse_categorical_accuracy: 0.9733 Epoch 3/5 1875/1875 [==============================] - 7s 4ms/step - loss: 0.0801 - sparse_categorical_accuracy: 0.9769 Epoch 4/5 1875/1875 [==============================] - 7s 4ms/step - loss: 0.0744 - sparse_categorical_accuracy: 0.9785 Epoch 5/5 1875/1875 [==============================] - 7s 4ms/step - loss: 0.0688 - sparse_categorical_accuracy: 0.9807 313/313 [==============================] - 1s 3ms/step - loss: 0.0961 - sparse_categorical_accuracy: 0.9762 [0.09605662524700165, 0.9761999845504761]
教師から生徒への蒸留
既に教師モデルを訓練しましたので、Distiller(student, teacher) インスタンスを初期化し、それを必要な損失, ハイパーパラメータと optimizer で compile() し、そして教師から生徒へ蒸留します。
# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
optimizer=keras.optimizers.Adam(),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
distillation_loss_fn=keras.losses.KLDivergence(),
alpha=0.1,
temperature=10,
)
# Distill teacher to student
distiller.fit(x_train, y_train, epochs=3)
# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)
Epoch 1/3 1875/1875 [==============================] - 242s 129ms/step - sparse_categorical_accuracy: 0.9761 - student_loss: 0.1526 - distillation_loss: 0.0226 Epoch 2/3 1875/1875 [==============================] - 281s 150ms/step - sparse_categorical_accuracy: 0.9863 - student_loss: 0.1384 - distillation_loss: 0.0185 Epoch 3/3 399/1875 [=====>........................] - ETA: 3:27 - sparse_categorical_accuracy: 0.9896 - student_loss: 0.1300 - distillation_loss: 0.0182
Epoch 1/3 1875/1875 [==============================] - 7s 4ms/step - sparse_categorical_accuracy: 0.9169 - student_loss: 0.3630 - distillation_loss: 0.1083 Epoch 2/3 1875/1875 [==============================] - 7s 4ms/step - sparse_categorical_accuracy: 0.9706 - student_loss: 0.1223 - distillation_loss: 0.0315 Epoch 3/3 1875/1875 [==============================] - 7s 4ms/step - sparse_categorical_accuracy: 0.9772 - student_loss: 0.0902 - distillation_loss: 0.0204 313/313 [==============================] - 1s 2ms/step - sparse_categorical_accuracy: 0.9790 - student_loss: 0.0831 [0.9789999723434448, 1.7090986148105003e-05]
比較のために生徒をスクラッチから訓練する
知識蒸留により得られる性能 gain を評価するために、教師なしで同値の生徒モデルをスクラッチから訓練することもできます。
# Train student as doen usually
student_scratch.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)
Epoch 1/3 1875/1875 [==============================] - 4s 2ms/step - loss: 0.4731 - sparse_categorical_accuracy: 0.8550 Epoch 2/3 1875/1875 [==============================] - 4s 2ms/step - loss: 0.0966 - sparse_categorical_accuracy: 0.9710 Epoch 3/3 1875/1875 [==============================] - 4s 2ms/step - loss: 0.0750 - sparse_categorical_accuracy: 0.9773 313/313 [==============================] - 0s 963us/step - loss: 0.0691 - sparse_categorical_accuracy: 0.9778 [0.06905383616685867, 0.9778000116348267]
Epoch 1/3 1875/1875 [==============================] - 5s 3ms/step - loss: 0.2420 - sparse_categorical_accuracy: 0.9278 Epoch 2/3 1875/1875 [==============================] - 5s 3ms/step - loss: 0.0914 - sparse_categorical_accuracy: 0.9721 Epoch 3/3 1875/1875 [==============================] - 5s 3ms/step - loss: 0.0744 - sparse_categorical_accuracy: 0.9765 313/313 [==============================] - 1s 2ms/step - loss: 0.0706 - sparse_categorical_accuracy: 0.9776 [0.07058624178171158, 0.9775999784469604]
教師を 5 full エポック訓練してこの教師で生徒を 3 full エポック蒸留する場合、このサンプルではスクラッチからの同じ生徒モデルの訓練と比較して、更には教師自身と比較してさえも、性能ブーストを体験するはずです。 教師が約 97.6% の精度、スクラッチから訓練された生徒が約 96.6 を持ち、そして蒸留され生徒が約 98.1% であることを期待できるはずです。異なる重みの初期化を使用するためにシードを削除するか、異なるものを試してください。
以上