Keras 2 : examples : Generative Deep Learning – WGAN-GP (Translation/Commentary)
Translation: ClassCat Sales Information
Created: 06/30/2022 (keras 2.9.0)
* This page is a translation, with supplementary notes added where appropriate, of the following Keras documentation:
- Code examples : Generative Deep Learning : WGAN-GP overriding Model.train_step (Author: A_K_Nain)
* The sample code has been verified to run; it has been modified where necessary.
Keras 2 : examples : Generative Deep Learning – WGAN-GP
Description: Implementation of Wasserstein GAN with Gradient Penalty.
The original Wasserstein GAN leverages the Wasserstein distance to produce a value function that has better theoretical properties than the value function used in the original GAN paper. WGAN requires that the discriminator (a.k.a. the critic) lie within the space of 1-Lipschitz functions. The authors proposed the idea of weight clipping to achieve this constraint. Though weight clipping works, it is a problematic way to enforce 1-Lipschitzness and can cause undesirable behavior; for example, a very deep WGAN discriminator (critic) often fails to converge.
The WGAN-GP method proposes an alternative to weight clipping that ensures smooth training. Instead of clipping the weights, the authors proposed a "gradient penalty": adding a loss term that keeps the L2 norm of the discriminator gradients close to 1.
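For reference, the critic loss from the WGAN-GP paper ("Improved Training of Wasserstein GANs", Gulrajani et al., 2017), which the code below implements as `d_loss_fn` plus a weighted `gradient_penalty` term, can be written as:

$$L = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big] - \mathbb{E}_{x \sim \mathbb{P}_r}\big[D(x)\big] + \lambda \, \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big]$$

where $\hat{x}$ is sampled on lines between real and generated samples, and $\lambda$ is the penalty weight (`gp_weight`, 10.0 in this example).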
Setup
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
Prepare the Fashion-MNIST data
To demonstrate how to train a WGAN-GP, we will be using the Fashion-MNIST dataset. Each sample in this dataset is a 28×28 grayscale image associated with a label from 10 classes (e.g. trouser, pullover, sneaker, etc.).
IMG_SHAPE = (28, 28, 1)
BATCH_SIZE = 512
# Size of the noise vector
noise_dim = 128
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
print(f"Number of examples: {len(train_images)}")
print(f"Shape of the images in the dataset: {train_images.shape[1:]}")
# Reshape each sample to (28, 28, 1) and normalize the pixel values in the [-1, 1] range
train_images = train_images.reshape(train_images.shape[0], *IMG_SHAPE).astype("float32")
train_images = (train_images - 127.5) / 127.5
Number of examples: 60000
Shape of the images in the dataset: (28, 28)
Create the discriminator (the critic in the original WGAN)
The samples in the dataset have a (28, 28, 1) shape. Because we will be using strided convolutions, this can result in shapes with odd dimensions. For example, (28, 28) -> Conv_s2 -> (14, 14) -> Conv_s2 -> (7, 7) -> Conv_s2 -> (3, 3).
While performing upsampling in the generator part of the network, if we aren't careful we won't get the same output shape as the original images. To avoid this, we will do something much simpler:
- In the discriminator: "zero pad" the input to change the shape of each sample to (32, 32, 1); and
- In the generator: crop the final output to match the input shape.
A quick standalone shape check of this pad/crop round trip is shown below.
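As a sanity check (a standalone snippet, not part of the original example), padding to 32 makes every stride-2 halving exact (32 -> 16 -> 8 -> 4 -> 2), and Cropping2D restores the original shape:
import tensorflow as tf
from tensorflow.keras import layers

x = tf.zeros((1, 28, 28, 1))                 # dummy batch with one image
padded = layers.ZeroPadding2D((2, 2))(x)     # -> (1, 32, 32, 1)
cropped = layers.Cropping2D((2, 2))(padded)  # -> (1, 28, 28, 1), back to the input shape
print(padded.shape, cropped.shape)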
def conv_block(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="same",
    use_bias=True,
    use_bn=False,
    use_dropout=False,
    drop_value=0.5,
):
    x = layers.Conv2D(
        filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias
    )(x)
    if use_bn:
        x = layers.BatchNormalization()(x)
    x = activation(x)
    if use_dropout:
        x = layers.Dropout(drop_value)(x)
    return x
def get_discriminator_model():
    img_input = layers.Input(shape=IMG_SHAPE)
    # Zero pad the input to make the input images size to (32, 32, 1).
    x = layers.ZeroPadding2D((2, 2))(img_input)
    x = conv_block(
        x,
        64,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        use_bias=True,
        activation=layers.LeakyReLU(0.2),
        use_dropout=False,
        drop_value=0.3,
    )
    x = conv_block(
        x,
        128,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        activation=layers.LeakyReLU(0.2),
        use_bias=True,
        use_dropout=True,
        drop_value=0.3,
    )
    x = conv_block(
        x,
        256,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        activation=layers.LeakyReLU(0.2),
        use_bias=True,
        use_dropout=True,
        drop_value=0.3,
    )
    x = conv_block(
        x,
        512,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        activation=layers.LeakyReLU(0.2),
        use_bias=True,
        use_dropout=False,
        drop_value=0.3,
    )
    x = layers.Flatten()(x)
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(1)(x)

    d_model = keras.models.Model(img_input, x, name="discriminator")
    return d_model
d_model = get_discriminator_model()
d_model.summary()
Model: "discriminator" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, 28, 28, 1)] 0 _________________________________________________________________ zero_padding2d (ZeroPadding2 (None, 32, 32, 1) 0 _________________________________________________________________ conv2d (Conv2D) (None, 16, 16, 64) 1664 _________________________________________________________________ leaky_re_lu (LeakyReLU) (None, 16, 16, 64) 0 _________________________________________________________________ conv2d_1 (Conv2D) (None, 8, 8, 128) 204928 _________________________________________________________________ leaky_re_lu_1 (LeakyReLU) (None, 8, 8, 128) 0 _________________________________________________________________ dropout (Dropout) (None, 8, 8, 128) 0 _________________________________________________________________ conv2d_2 (Conv2D) (None, 4, 4, 256) 819456 _________________________________________________________________ leaky_re_lu_2 (LeakyReLU) (None, 4, 4, 256) 0 _________________________________________________________________ dropout_1 (Dropout) (None, 4, 4, 256) 0 _________________________________________________________________ conv2d_3 (Conv2D) (None, 2, 2, 512) 3277312 _________________________________________________________________ leaky_re_lu_3 (LeakyReLU) (None, 2, 2, 512) 0 _________________________________________________________________ flatten (Flatten) (None, 2048) 0 _________________________________________________________________ dropout_2 (Dropout) (None, 2048) 0 _________________________________________________________________ dense (Dense) (None, 1) 2049 ================================================================= Total params: 4,305,409 Trainable params: 4,305,409 Non-trainable params: 0 _________________________________________________________________
Create the generator
def upsample_block(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(1, 1),
    up_size=(2, 2),
    padding="same",
    use_bn=False,
    use_bias=True,
    use_dropout=False,
    drop_value=0.3,
):
    x = layers.UpSampling2D(up_size)(x)
    x = layers.Conv2D(
        filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias
    )(x)
    if use_bn:
        x = layers.BatchNormalization()(x)
    if activation:
        x = activation(x)
    if use_dropout:
        x = layers.Dropout(drop_value)(x)
    return x
def get_generator_model():
    noise = layers.Input(shape=(noise_dim,))
    x = layers.Dense(4 * 4 * 256, use_bias=False)(noise)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Reshape((4, 4, 256))(x)
    x = upsample_block(
        x,
        128,
        layers.LeakyReLU(0.2),
        strides=(1, 1),
        use_bias=False,
        use_bn=True,
        padding="same",
        use_dropout=False,
    )
    x = upsample_block(
        x,
        64,
        layers.LeakyReLU(0.2),
        strides=(1, 1),
        use_bias=False,
        use_bn=True,
        padding="same",
        use_dropout=False,
    )
    x = upsample_block(
        x, 1, layers.Activation("tanh"), strides=(1, 1), use_bias=False, use_bn=True
    )
    # At this point, we have an output which has the same shape as the input, (32, 32, 1).
    # We will use a Cropping2D layer to make it (28, 28, 1).
    x = layers.Cropping2D((2, 2))(x)

    g_model = keras.models.Model(noise, x, name="generator")
    return g_model
g_model = get_generator_model()
g_model.summary()
Model: "generator" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_2 (InputLayer) [(None, 128)] 0 _________________________________________________________________ dense_1 (Dense) (None, 4096) 524288 _________________________________________________________________ batch_normalization (BatchNo (None, 4096) 16384 _________________________________________________________________ leaky_re_lu_4 (LeakyReLU) (None, 4096) 0 _________________________________________________________________ reshape (Reshape) (None, 4, 4, 256) 0 _________________________________________________________________ up_sampling2d (UpSampling2D) (None, 8, 8, 256) 0 _________________________________________________________________ conv2d_4 (Conv2D) (None, 8, 8, 128) 294912 _________________________________________________________________ batch_normalization_1 (Batch (None, 8, 8, 128) 512 _________________________________________________________________ leaky_re_lu_5 (LeakyReLU) (None, 8, 8, 128) 0 _________________________________________________________________ up_sampling2d_1 (UpSampling2 (None, 16, 16, 128) 0 _________________________________________________________________ conv2d_5 (Conv2D) (None, 16, 16, 64) 73728 _________________________________________________________________ batch_normalization_2 (Batch (None, 16, 16, 64) 256 _________________________________________________________________ leaky_re_lu_6 (LeakyReLU) (None, 16, 16, 64) 0 _________________________________________________________________ up_sampling2d_2 (UpSampling2 (None, 32, 32, 64) 0 _________________________________________________________________ conv2d_6 (Conv2D) (None, 32, 32, 1) 576 _________________________________________________________________ batch_normalization_3 (Batch (None, 32, 32, 1) 4 _________________________________________________________________ activation (Activation) (None, 32, 32, 1) 0 _________________________________________________________________ cropping2d (Cropping2D) (None, 28, 28, 1) 0 ================================================================= Total params: 910,660 Trainable params: 902,082 Non-trainable params: 8,578 _________________________________________________________________
Create the WGAN-GP model
Now that we have defined our generator and discriminator, it's time to implement the WGAN-GP model. We will also override train_step for training.
class WGAN(keras.Model):
    def __init__(
        self,
        discriminator,
        generator,
        latent_dim,
        discriminator_extra_steps=3,
        gp_weight=10.0,
    ):
        super(WGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def gradient_penalty(self, batch_size, real_images, fake_images):
        """Calculates the gradient penalty.

        This loss is calculated on an interpolated image
        and added to the discriminator loss.
        """
        # Get the interpolated image
        alpha = tf.random.normal([batch_size, 1, 1, 1], 0.0, 1.0)
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 1. Get the discriminator output for this interpolated image.
            pred = self.discriminator(interpolated, training=True)

        # 2. Calculate the gradients w.r.t to this interpolated image.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 3. Calculate the norm of the gradients.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        # Get the batch size
        batch_size = tf.shape(real_images)[0]

        # For each batch, we are going to perform the
        # following steps as laid out in the original paper:
        # 1. Train the generator and get the generator loss
        # 2. Train the discriminator and get the discriminator loss
        # 3. Calculate the gradient penalty
        # 4. Multiply this gradient penalty with a constant weight factor
        # 5. Add the gradient penalty to the discriminator loss
        # 6. Return the generator and discriminator losses as a loss dictionary

        # Train the discriminator first. The original paper recommends training
        # the discriminator for `x` more steps (typically 5) as compared to
        # one step of the generator. Here we will train it for 3 extra steps
        # (instead of 5) to reduce the training time.
        for i in range(self.d_steps):
            # Get the latent vector
            random_latent_vectors = tf.random.normal(
                shape=(batch_size, self.latent_dim)
            )
            with tf.GradientTape() as tape:
                # Generate fake images from the latent vector
                fake_images = self.generator(random_latent_vectors, training=True)
                # Get the logits for the fake images
                fake_logits = self.discriminator(fake_images, training=True)
                # Get the logits for the real images
                real_logits = self.discriminator(real_images, training=True)

                # Calculate the discriminator loss using the fake and real image logits
                d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
                # Calculate the gradient penalty
                gp = self.gradient_penalty(batch_size, real_images, fake_images)
                # Add the gradient penalty to the original discriminator loss
                d_loss = d_cost + gp * self.gp_weight

            # Get the gradients w.r.t the discriminator loss
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            # Update the weights of the discriminator using the discriminator optimizer
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )

        # Train the generator
        # Get the latent vector
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            # Generate fake images using the generator
            generated_images = self.generator(random_latent_vectors, training=True)
            # Get the discriminator logits for fake images
            gen_img_logits = self.discriminator(generated_images, training=True)
            # Calculate the generator loss
            g_loss = self.g_loss_fn(gen_img_logits)

        # Get the gradients w.r.t the generator loss
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # Update the weights of the generator using the generator optimizer
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}
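As a sanity check (a hypothetical snippet, not part of the original example; `smoke_wgan` and the inline lambda losses are illustrative only), a single train_step can be executed eagerly on a random batch. wgan.fit below drives exactly this method:
# Hypothetical smoke test: run one custom training step on random data.
smoke_wgan = WGAN(
    discriminator=get_discriminator_model(),
    generator=get_generator_model(),
    latent_dim=noise_dim,
    discriminator_extra_steps=1,
)
smoke_wgan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0002),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0002),
    d_loss_fn=lambda real_img, fake_img: tf.reduce_mean(fake_img) - tf.reduce_mean(real_img),
    g_loss_fn=lambda fake_img: -tf.reduce_mean(fake_img),
)
print(smoke_wgan.train_step(tf.random.normal((4, 28, 28, 1))))  # {'d_loss': ..., 'g_loss': ...}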
Create a Keras callback that periodically saves generated images
class GANMonitor(keras.callbacks.Callback):
    def __init__(self, num_img=6, latent_dim=128):
        self.num_img = num_img
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        generated_images = self.model.generator(random_latent_vectors)
        generated_images = (generated_images * 127.5) + 127.5

        for i in range(self.num_img):
            img = generated_images[i].numpy()
            img = keras.preprocessing.image.array_to_img(img)
            img.save("generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch))
Train the end-to-end model
# Instantiate the optimizer for both networks
# (learning_rate=0.0002, beta_1=0.5 are recommended)
generator_optimizer = keras.optimizers.Adam(
learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)
discriminator_optimizer = keras.optimizers.Adam(
learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)
# Define the loss functions for the discriminator,
# which should be (fake_loss - real_loss).
# We will add the gradient penalty later to this loss function.
def discriminator_loss(real_img, fake_img):
    real_loss = tf.reduce_mean(real_img)
    fake_loss = tf.reduce_mean(fake_img)
    return fake_loss - real_loss
# Define the loss functions for the generator.
def generator_loss(fake_img):
    return -tf.reduce_mean(fake_img)
# Set the number of epochs for training.
epochs = 20
# Instantiate the custom `GANMonitor` Keras callback.
cbk = GANMonitor(num_img=3, latent_dim=noise_dim)
# Instantiate the WGAN model.
wgan = WGAN(
discriminator=d_model,
generator=g_model,
latent_dim=noise_dim,
discriminator_extra_steps=3,
)
# Compile the WGAN model.
wgan.compile(
d_optimizer=discriminator_optimizer,
g_optimizer=generator_optimizer,
g_loss_fn=generator_loss,
d_loss_fn=discriminator_loss,
)
# Start training the model.
wgan.fit(train_images, batch_size=BATCH_SIZE, epochs=epochs, callbacks=[cbk])
Epoch 1/20
118/118 [==============================] - 39s 334ms/step - d_loss: -7.6571 - g_loss: -16.9272
Epoch 2/20
118/118 [==============================] - 39s 334ms/step - d_loss: -7.2396 - g_loss: -8.5466
Epoch 3/20
118/118 [==============================] - 40s 335ms/step - d_loss: -6.3892 - g_loss: 1.3971
Epoch 4/20
118/118 [==============================] - 40s 335ms/step - d_loss: -5.7705 - g_loss: 6.5997
Epoch 5/20
118/118 [==============================] - 40s 336ms/step - d_loss: -5.2659 - g_loss: 7.4743
Epoch 6/20
118/118 [==============================] - 40s 335ms/step - d_loss: -4.9563 - g_loss: 6.2071
Epoch 7/20
118/118 [==============================] - 40s 335ms/step - d_loss: -4.5759 - g_loss: 6.4767
Epoch 8/20
118/118 [==============================] - 40s 335ms/step - d_loss: -4.3748 - g_loss: 5.4304
Epoch 9/20
118/118 [==============================] - 40s 335ms/step - d_loss: -4.1142 - g_loss: 6.4326
Epoch 10/20
118/118 [==============================] - 40s 335ms/step - d_loss: -3.7956 - g_loss: 7.1200
Epoch 11/20
118/118 [==============================] - 40s 335ms/step - d_loss: -3.5723 - g_loss: 7.1837
Epoch 12/20
118/118 [==============================] - 40s 335ms/step - d_loss: -3.4374 - g_loss: 9.0537
Epoch 13/20
118/118 [==============================] - 40s 335ms/step - d_loss: -3.3402 - g_loss: 8.4949
Epoch 14/20
118/118 [==============================] - 40s 335ms/step - d_loss: -3.1252 - g_loss: 8.6130
Epoch 15/20
118/118 [==============================] - 40s 336ms/step - d_loss: -3.0130 - g_loss: 9.4563
Epoch 16/20
118/118 [==============================] - 40s 335ms/step - d_loss: -2.9330 - g_loss: 8.8075
Epoch 17/20
118/118 [==============================] - 40s 336ms/step - d_loss: -2.7980 - g_loss: 8.0775
Epoch 18/20
118/118 [==============================] - 40s 335ms/step - d_loss: -2.7835 - g_loss: 8.7983
Epoch 19/20
118/118 [==============================] - 40s 335ms/step - d_loss: -2.6409 - g_loss: 7.8309
Epoch 20/20
118/118 [==============================] - 40s 336ms/step - d_loss: -2.5134 - g_loss: 8.6653

<tensorflow.python.keras.callbacks.History at 0x7fc1a410a278>
Display the last generated images:
from IPython.display import Image, display
display(Image("generated_img_0_19.png"))
display(Image("generated_img_1_19.png"))
display(Image("generated_img_2_19.png"))
That's all.