Keras 2 : examples : 生成深層学習 – 条件付き画像生成のための GauGAN (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 07/09/2022 (keras 2.9.0)
* 本ページは、Keras の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
- Code examples : Generative Deep Learning : GauGAN for conditional image generation (Author: Soumik Rakshit, Sayak Paul)
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Web: www.classcat.com ; ClassCatJP
Keras 2 : examples : 生成深層学習 – 条件付き画像生成のための GauGAN
Description : 条件付き画像生成のための GauGAN の実装。
イントロダクション
このサンプルでは、Semantic Image Synthesis with Spatially-Adaptive Normalization で提案された GauGAN アーキテクチャの実装を提示します。簡潔には、GauGAN は、以下に示されるように、キュー (= cue) 画像とセグメンテーション・マップにより条件付けられたリアルな画像を生成するために敵対的生成ネットワーク (GAN) を使用します (画像ソース) :
GauGAN の主要コンポーネントは :
- SPADE (aka spatially-adaptive 正規化) : GauGAN の著者らは、(バッチ正規化 のような) 従来の正規化層は、入力として提供されたセグメンテーション・マップから得られた意味的情報を破壊することを主張しています。この問題に対処するため、著者らは SPADE, 空間的に適応可能なアフィン・パラメータ (スケールとバイアス) を学習するために特に適している正規化層を導入しました。これは、各意味的ラベルに対してスケーリングとバイアス・パラメータの異なるセットを学習することにより成されます。
- 変分エンコーダ : 変分オートエンコーダ にインスパイアされ、GauGAN は、cue 画像から正規 (ガウス) 分布の平均と分散を学習する、変分定式化を利用しています。そこから GauGAN は名前を得ました。GauGAN の generator は入力として、ガウス分布からサンプリングされた潜在値 (= latent) と one-hot エンコードされたセマンティックセグメンテーションのラベルマップを受けとります。cue 画像は generator をスタイルの生成へガイドするスタイル画像として機能します。この変分定式化は GauGAN が画像の多様性と忠実性を獲得するのに役立ちます。
- マルチスケール・パッチ discriminator : PatchGAN にインスパイアされ、GauGAN は与えられた画像をパッチベースで評価して平均スコアを生成する discriminator を使用します。
サンプルを進めながら、様々なコンポーネントの各々を詳細に議論します。
GauGAN の徹底的なレビューについては、この記事 を参照してください。また 公式 GauGAN web サイト をチェックすることも勧めます、これは GauGAN の多くの創造的なアプリケーションを持っています。このサンプルは読者が GAN の基礎的な概念に既に馴染みがあることを仮定しています。refresher が必要であれば、以下のリソースが有用であるかもしれません :
- Chapter on GANs from the Deep Learning with Python book by François Chollet.
- GAN implementations on keras.io:
- [Data efficient GANs](https://keras.io/examples/generative/gan_ada)
- [CycleGAN](https://keras.io/examples/generative/cyclegan)
- [Conditional GAN](https://keras.io/examples/generative/conditional_gan)
データ・コレクション
GauGAN モデルを訓練するために Facades データセット を使用していきます。まずはそれをダウンロードしましょう。また TensorFlow Addons もインストールします。
!gdown https://drive.google.com/uc?id=1q4FEjQg1YSb4mPx2VdxL7LXKYu3voTMj
!unzip -q facades_data.zip
!pip install -qqq tensorflow_addons
Downloading... From: https://drive.google.com/uc?id=1q4FEjQg1YSb4mPx2VdxL7LXKYu3voTMj To: /content/keras-io/scripts/tmp_2820468/facades_data.zip 100% 26.0M/26.0M [00:00<00:00, 261MB/s] [K |████████████████████████████████| 1.1 MB 8.5 MB/s [?25h
インポート
import os
import random
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers
from glob import glob
from PIL import Image
データ分割
PATH = "./facades_data/"
SPLIT = 0.2
files = glob(PATH + "*.jpg")
np.random.shuffle(files)
split_index = int(len(files) * (1 - SPLIT))
train_files = files[:split_index]
val_files = files[split_index:]
print(f"Total samples: {len(files)}.")
print(f"Total training samples: {len(train_files)}.")
print(f"Total validation samples: {len(val_files)}.")
Total samples: 378. Total training samples: 302. Total validation samples: 76.
データローダ
BATCH_SIZE = 4
IMG_HEIGHT = IMG_WIDTH = 256
NUM_CLASSES = 12
AUTOTUNE = tf.data.AUTOTUNE
def load(image_files, batch_size, is_train=True):
def _random_crop(
segmentation_map, image, labels, crop_size=(IMG_HEIGHT, IMG_WIDTH),
):
crop_size = tf.convert_to_tensor(crop_size)
image_shape = tf.shape(image)[:2]
margins = image_shape - crop_size
y1 = tf.random.uniform(shape=(), maxval=margins[0], dtype=tf.int32)
x1 = tf.random.uniform(shape=(), maxval=margins[1], dtype=tf.int32)
y2 = y1 + crop_size[0]
x2 = x1 + crop_size[1]
cropped_images = []
images = [segmentation_map, image, labels]
for img in images:
cropped_images.append(img[y1:y2, x1:x2])
return cropped_images
def _load_data_tf(image_file, segmentation_map_file, label_file):
image = tf.image.decode_png(tf.io.read_file(image_file), channels=3)
segmentation_map = tf.image.decode_png(
tf.io.read_file(segmentation_map_file), channels=3
)
labels = tf.image.decode_bmp(tf.io.read_file(label_file), channels=0)
labels = tf.squeeze(labels)
image = tf.cast(image, tf.float32) / 127.5 - 1
segmentation_map = tf.cast(segmentation_map, tf.float32) / 127.5 - 1
return segmentation_map, image, labels
segmentation_map_files = [
image_file.replace("images", "segmentation_map").replace("jpg", "png")
for image_file in image_files
]
label_files = [
image_file.replace("images", "segmentation_labels").replace("jpg", "bmp")
for image_file in image_files
]
dataset = tf.data.Dataset.from_tensor_slices(
(image_files, segmentation_map_files, label_files)
)
dataset = dataset.shuffle(batch_size * 10) if is_train else dataset
dataset = dataset.map(_load_data_tf, num_parallel_calls=AUTOTUNE)
dataset = dataset.map(_random_crop, num_parallel_calls=AUTOTUNE)
dataset = dataset.map(
lambda x, y, z: (x, y, tf.one_hot(z, NUM_CLASSES)), num_parallel_calls=AUTOTUNE
)
return dataset.batch(batch_size, drop_remainder=True)
train_dataset = load(train_files, batch_size=BATCH_SIZE, is_train=True)
val_dataset = load(val_files, batch_size=BATCH_SIZE, is_train=False)
次に、訓練セットから幾つかサンプルを可視化しましょう。
sample_train_batch = next(iter(train_dataset))
print(f"Segmentation map batch shape: {sample_train_batch[0].shape}.")
print(f"Image batch shape: {sample_train_batch[1].shape}.")
print(f"One-hot encoded label map shape: {sample_train_batch[2].shape}.")
# Plot a view samples from the training set.
for segmentation_map, real_image in zip(sample_train_batch[0], sample_train_batch[1]):
fig = plt.figure(figsize=(10, 10))
fig.add_subplot(1, 2, 1).set_title("Segmentation Map")
plt.imshow((segmentation_map + 1) / 2)
fig.add_subplot(1, 2, 2).set_title("Real Image")
plt.imshow((real_image + 1) / 2)
plt.show()
Segmentation map batch shape: (4, 256, 256, 3). Image batch shape: (4, 256, 256, 3). One-hot encoded label map shape: (4, 256, 256, 12).
この例の残りでは、便宜上、オリジナルの GauGAN 論文 からの幾つかの図を使用します。
カスタム層
以下のセクションで、以下の層を実装します :
- SPADE
- SPADE を含む残差ブロック
- Gaussian サンプラー
Some more notes on SPADE
SPatially-Adaptive (DE) 正規化 or SPADE は、入力セマンティック・レイアウトが与えられたときに写真のようにリアルな画像を合成するための単純ですが効果的な層です。Pix2Pix (Isola et al.) や Pix2PixHD (Wang et al.) のようなセマンティック入力から条件付き画像生成するための以前の方法は深層ネットワークへの入力としてセマンティック・レイアウトを直接供給し、それから畳み込み、正規化、非線形層のスタックを通して処理します。これは、正規化層が意味的情報を洗い流す傾向があるために最適ではない場合が多いです。
SPADE では、セグメンテーション・マスクが最初に埋め込み空間上に射影されてから、変調 (= modulation) パラメータ γ と β を生成するために畳み込まれます。前の条件付き正規化法と違い、γ と β はベクトルではなく、空間次元を持つテンソルです。生成された γ と β は正規化された活性に要素毎に乗算されて加算されます。modulation パラメータは入力セグメンテーション・マスクに適応可能ですので、SPADE は意味的画像合成に対してより適切です。
class SPADE(layers.Layer):
def __init__(self, filters, epsilon=1e-5, **kwargs):
super().__init__(**kwargs)
self.epsilon = epsilon
self.conv = layers.Conv2D(128, 3, padding="same", activation="relu")
self.conv_gamma = layers.Conv2D(filters, 3, padding="same")
self.conv_beta = layers.Conv2D(filters, 3, padding="same")
def build(self, input_shape):
self.resize_shape = input_shape[1:3]
def call(self, input_tensor, raw_mask):
mask = tf.image.resize(raw_mask, self.resize_shape, method="nearest")
x = self.conv(mask)
gamma = self.conv_gamma(x)
beta = self.conv_beta(x)
mean, var = tf.nn.moments(input_tensor, axes=(0, 1, 2), keepdims=True)
std = tf.sqrt(var + self.epsilon)
normalized = (input_tensor - mean) / std
output = gamma * normalized + beta
return output
class ResBlock(layers.Layer):
def __init__(self, filters, **kwargs):
super().__init__(**kwargs)
self.filters = filters
def build(self, input_shape):
input_filter = input_shape[-1]
self.spade_1 = SPADE(input_filter)
self.spade_2 = SPADE(self.filters)
self.conv_1 = layers.Conv2D(self.filters, 3, padding="same")
self.conv_2 = layers.Conv2D(self.filters, 3, padding="same")
self.learned_skip = False
if self.filters != input_filter:
self.learned_skip = True
self.spade_3 = SPADE(input_filter)
self.conv_3 = layers.Conv2D(self.filters, 3, padding="same")
def call(self, input_tensor, mask):
x = self.spade_1(input_tensor, mask)
x = self.conv_1(tf.nn.leaky_relu(x, 0.2))
x = self.spade_2(x, mask)
x = self.conv_2(tf.nn.leaky_relu(x, 0.2))
skip = (
self.conv_3(tf.nn.leaky_relu(self.spade_3(input_tensor, mask), 0.2))
if self.learned_skip
else input_tensor
)
output = skip + x
return output
class GaussianSampler(layers.Layer):
def __init__(self, batch_size, latent_dim, **kwargs):
super().__init__(**kwargs)
self.batch_size = batch_size
self.latent_dim = latent_dim
def call(self, inputs):
means, variance = inputs
epsilon = tf.random.normal(
shape=(self.batch_size, self.latent_dim), mean=0.0, stddev=1.0
)
samples = means + tf.exp(0.5 * variance) * epsilon
return samples
次に、エンコーダのためのダウンサンプリング・ブロックを実装します。
def downsample(
channels,
kernels,
strides=2,
apply_norm=True,
apply_activation=True,
apply_dropout=False,
):
block = keras.Sequential()
block.add(
layers.Conv2D(
channels,
kernels,
strides=strides,
padding="same",
use_bias=False,
kernel_initializer=keras.initializers.GlorotNormal(),
)
)
if apply_norm:
block.add(tfa.layers.InstanceNormalization())
if apply_activation:
block.add(layers.LeakyReLU(0.2))
if apply_dropout:
block.add(layers.Dropout(0.5))
return block
GauGAN エンコーダは幾つかのダウンサンプリング・ブロックから構成されます。それは分布の平均と分散を出力します。
def build_encoder(image_shape, encoder_downsample_factor=64, latent_dim=256):
input_image = keras.Input(shape=image_shape)
x = downsample(encoder_downsample_factor, 3, apply_norm=False)(input_image)
x = downsample(2 * encoder_downsample_factor, 3)(x)
x = downsample(4 * encoder_downsample_factor, 3)(x)
x = downsample(8 * encoder_downsample_factor, 3)(x)
x = downsample(8 * encoder_downsample_factor, 3)(x)
x = layers.Flatten()(x)
mean = layers.Dense(latent_dim, name="mean")(x)
variance = layers.Dense(latent_dim, name="variance")(x)
return keras.Model(input_image, [mean, variance], name="encoder")
次に、generator を実装します、これは変更された残差ブロックとアップサンプリング・ブロックから構成されます。それは潜在ベクトルと one-hot エンコードされたセグメンテーション・ラベルと取り、新しい画像を生成します。
SPADE では、generator の最初の層にセグメンテーション・マップを供給する必要はありません、何故ならば潜在入力は generator にエミュレートすることを望むスタイルについて十分な構造的情報を持つからです。以前のアーキテクチャで一般に使用される、generator のエンコーダ部も捨てます。これはより軽量な generator ネットワークという結果になります、これはまた入力としてランダム・ベクトルを取り、多様な合成への単純で自然なパスを有効にします。
def build_generator(mask_shape, latent_dim=256):
latent = keras.Input(shape=(latent_dim))
mask = keras.Input(shape=mask_shape)
x = layers.Dense(16384)(latent)
x = layers.Reshape((4, 4, 1024))(x)
x = ResBlock(filters=1024)(x, mask)
x = layers.UpSampling2D((2, 2))(x)
x = ResBlock(filters=1024)(x, mask)
x = layers.UpSampling2D((2, 2))(x)
x = ResBlock(filters=1024)(x, mask)
x = layers.UpSampling2D((2, 2))(x)
x = ResBlock(filters=512)(x, mask)
x = layers.UpSampling2D((2, 2))(x)
x = ResBlock(filters=256)(x, mask)
x = layers.UpSampling2D((2, 2))(x)
x = ResBlock(filters=128)(x, mask)
x = layers.UpSampling2D((2, 2))(x)
x = tf.nn.leaky_relu(x, 0.2)
output_image = tf.nn.tanh(layers.Conv2D(3, 4, padding="same")(x))
return keras.Model([latent, mask], output_image, name="generator")
discriminator はセグメンテーション・マップと画像を取り、それらを連結します。そして連結された画像のパッチが本物か偽物かを予測します。
def build_discriminator(image_shape, downsample_factor=64):
input_image_A = keras.Input(shape=image_shape, name="discriminator_image_A")
input_image_B = keras.Input(shape=image_shape, name="discriminator_image_B")
x = layers.Concatenate()([input_image_A, input_image_B])
x1 = downsample(downsample_factor, 4, apply_norm=False)(x)
x2 = downsample(2 * downsample_factor, 4)(x1)
x3 = downsample(4 * downsample_factor, 4)(x2)
x4 = downsample(8 * downsample_factor, 4, strides=1)(x3)
x5 = layers.Conv2D(1, 4)(x4)
outputs = [x1, x2, x3, x4, x5]
return keras.Model([input_image_A, input_image_B], outputs)
損失関数
GauGAN は以下の損失関数を使用します :
- Generator:
- discriminator 予測に渡る期待値。
- KL divergence エンコーダにより予測された平均と分散を学習するため
- generator の特徴空間をアラインするためのオリジナルと生成画像の discriminator 予測間の最小化。
- Perceptual loss for encouraging the generated images to have perceptual quality.
- Discriminator:
-
def generator_loss(y): return -tf.reduce_mean(y) def kl_divergence_loss(mean, variance): return -0.5 * tf.reduce_sum(1 + variance - tf.square(mean) - tf.exp(variance)) class FeatureMatchingLoss(keras.losses.Loss): def __init__(self, **kwargs): super().__init__(**kwargs) self.mae = keras.losses.MeanAbsoluteError() def call(self, y_true, y_pred): loss = 0 for i in range(len(y_true) - 1): loss += self.mae(y_true[i], y_pred[i]) return loss class VGGFeatureMatchingLoss(keras.losses.Loss): def __init__(self, **kwargs): super().__init__(**kwargs) self.encoder_layers = [ "block1_conv1", "block2_conv1", "block3_conv1", "block4_conv1", "block5_conv1", ] self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] vgg = keras.applications.VGG19(include_top=False, weights="imagenet") layer_outputs = [vgg.get_layer(x).output for x in self.encoder_layers] self.vgg_model = keras.Model(vgg.input, layer_outputs, name="VGG") self.mae = keras.losses.MeanAbsoluteError() def call(self, y_true, y_pred): y_true = keras.applications.vgg19.preprocess_input(127.5 * (y_true + 1)) y_pred = keras.applications.vgg19.preprocess_input(127.5 * (y_pred + 1)) real_features = self.vgg_model(y_true) fake_features = self.vgg_model(y_pred) loss = 0 for i in range(len(real_features)): loss += self.weights[i] * self.mae(real_features[i], fake_features[i]) return loss class DiscriminatorLoss(keras.losses.Loss): def __init__(self, **kwargs): super().__init__(**kwargs) self.hinge_loss = keras.losses.Hinge() def call(self, y, is_real): label = 1.0 if is_real else -1.0 return self.hinge_loss(label, y)
- Hinge loss
-
GAN モニタ・コールバック
次に、訓練の間に GauGAN の結果をモニタするコールバックを実装します。
class GanMonitor(keras.callbacks.Callback):
def __init__(self, val_dataset, n_samples, epoch_interval=5):
self.val_images = next(iter(val_dataset))
self.n_samples = n_samples
self.epoch_interval = epoch_interval
def infer(self):
latent_vector = tf.random.normal(
shape=(self.model.batch_size, self.model.latent_dim), mean=0.0, stddev=2.0
)
return self.model.predict([latent_vector, self.val_images[2]])
def on_epoch_end(self, epoch, logs=None):
if epoch % self.epoch_interval == 0:
generated_images = self.infer()
for _ in range(self.n_samples):
grid_row = min(generated_images.shape[0], 3)
f, axarr = plt.subplots(grid_row, 3, figsize=(18, grid_row * 6))
for row in range(grid_row):
ax = axarr if grid_row == 1 else axarr[row]
ax[0].imshow((self.val_images[0][row] + 1) / 2)
ax[0].axis("off")
ax[0].set_title("Mask", fontsize=20)
ax[1].imshow((self.val_images[1][row] + 1) / 2)
ax[1].axis("off")
ax[1].set_title("Ground Truth", fontsize=20)
ax[2].imshow((generated_images[row] + 1) / 2)
ax[2].axis("off")
ax[2].set_title("Generated", fontsize=20)
plt.show()
サブクラス化された GauGAN モデル
最後に、train_step() メソッドをオーバライドして総てを (from tf.keras.Model から) サブクラス化されたモデルにまとめます。
class GauGAN(keras.Model):
def __init__(
self,
image_size,
num_classes,
batch_size,
latent_dim,
feature_loss_coeff=10,
vgg_feature_loss_coeff=0.1,
kl_divergence_loss_coeff=0.1,
**kwargs,
):
super().__init__(**kwargs)
self.image_size = image_size
self.latent_dim = latent_dim
self.batch_size = batch_size
self.num_classes = num_classes
self.image_shape = (image_size, image_size, 3)
self.mask_shape = (image_size, image_size, num_classes)
self.feature_loss_coeff = feature_loss_coeff
self.vgg_feature_loss_coeff = vgg_feature_loss_coeff
self.kl_divergence_loss_coeff = kl_divergence_loss_coeff
self.discriminator = build_discriminator(self.image_shape)
self.generator = build_generator(self.mask_shape)
self.encoder = build_encoder(self.image_shape)
self.sampler = GaussianSampler(batch_size, latent_dim)
self.patch_size, self.combined_model = self.build_combined_generator()
self.disc_loss_tracker = tf.keras.metrics.Mean(name="disc_loss")
self.gen_loss_tracker = tf.keras.metrics.Mean(name="gen_loss")
self.feat_loss_tracker = tf.keras.metrics.Mean(name="feat_loss")
self.vgg_loss_tracker = tf.keras.metrics.Mean(name="vgg_loss")
self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
@property
def metrics(self):
return [
self.disc_loss_tracker,
self.gen_loss_tracker,
self.feat_loss_tracker,
self.vgg_loss_tracker,
self.kl_loss_tracker,
]
def build_combined_generator(self):
# This method builds a model that takes as inputs the following:
# latent vector, one-hot encoded segmentation label map, and
# a segmentation map. It then (i) generates an image with the generator,
# (ii) passes the generated images and segmentation map to the discriminator.
# Finally, the model produces the following outputs: (a) discriminator outputs,
# (b) generated image.
# We will be using this model to simplify the implementation.
self.discriminator.trainable = False
mask_input = keras.Input(shape=self.mask_shape, name="mask")
image_input = keras.Input(shape=self.image_shape, name="image")
latent_input = keras.Input(shape=(self.latent_dim), name="latent")
generated_image = self.generator([latent_input, mask_input])
discriminator_output = self.discriminator([image_input, generated_image])
patch_size = discriminator_output[-1].shape[1]
combined_model = keras.Model(
[latent_input, mask_input, image_input],
[discriminator_output, generated_image],
)
return patch_size, combined_model
def compile(self, gen_lr=1e-4, disc_lr=4e-4, **kwargs):
super().compile(**kwargs)
self.generator_optimizer = keras.optimizers.Adam(
gen_lr, beta_1=0.0, beta_2=0.999
)
self.discriminator_optimizer = keras.optimizers.Adam(
disc_lr, beta_1=0.0, beta_2=0.999
)
self.discriminator_loss = DiscriminatorLoss()
self.feature_matching_loss = FeatureMatchingLoss()
self.vgg_loss = VGGFeatureMatchingLoss()
def train_discriminator(self, latent_vector, segmentation_map, real_image, labels):
fake_images = self.generator([latent_vector, labels])
with tf.GradientTape() as gradient_tape:
pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
pred_real = self.discriminator([segmentation_map, real_image])[-1]
loss_fake = self.discriminator_loss(pred_fake, False)
loss_real = self.discriminator_loss(pred_real, True)
total_loss = 0.5 * (loss_fake + loss_real)
self.discriminator.trainable = True
gradients = gradient_tape.gradient(
total_loss, self.discriminator.trainable_variables
)
self.discriminator_optimizer.apply_gradients(
zip(gradients, self.discriminator.trainable_variables)
)
return total_loss
def train_generator(
self, latent_vector, segmentation_map, labels, image, mean, variance
):
# Generator learns through the signal provided by the discriminator. During
# backpropagation, we only update the generator parameters.
self.discriminator.trainable = False
with tf.GradientTape() as tape:
real_d_output = self.discriminator([segmentation_map, image])
fake_d_output, fake_image = self.combined_model(
[latent_vector, labels, segmentation_map]
)
pred = fake_d_output[-1]
# Compute generator losses.
g_loss = generator_loss(pred)
kl_loss = self.kl_divergence_loss_coeff * kl_divergence_loss(mean, variance)
vgg_loss = self.vgg_feature_loss_coeff * self.vgg_loss(image, fake_image)
feature_loss = self.feature_loss_coeff * self.feature_matching_loss(
real_d_output, fake_d_output
)
total_loss = g_loss + kl_loss + vgg_loss + feature_loss
all_trainable_variables = (
self.combined_model.trainable_variables + self.encoder.trainable_variables
)
gradients = tape.gradient(total_loss, all_trainable_variables)
self.generator_optimizer.apply_gradients(
zip(gradients, all_trainable_variables)
)
return total_loss, feature_loss, vgg_loss, kl_loss
def train_step(self, data):
segmentation_map, image, labels = data
mean, variance = self.encoder(image)
latent_vector = self.sampler([mean, variance])
discriminator_loss = self.train_discriminator(
latent_vector, segmentation_map, image, labels
)
(generator_loss, feature_loss, vgg_loss, kl_loss) = self.train_generator(
latent_vector, segmentation_map, labels, image, mean, variance
)
# Report progress.
self.disc_loss_tracker.update_state(discriminator_loss)
self.gen_loss_tracker.update_state(generator_loss)
self.feat_loss_tracker.update_state(feature_loss)
self.vgg_loss_tracker.update_state(vgg_loss)
self.kl_loss_tracker.update_state(kl_loss)
results = {m.name: m.result() for m in self.metrics}
return results
def test_step(self, data):
segmentation_map, image, labels = data
# Obtain the learned moments of the real image distribution.
mean, variance = self.encoder(image)
# Sample a latent from the distribution defined by the learned moments.
latent_vector = self.sampler([mean, variance])
# Generate the fake images.
fake_images = self.generator([latent_vector, labels])
# Calculate the losses.
pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
pred_real = self.discriminator([segmentation_map, image])[-1]
loss_fake = self.discriminator_loss(pred_fake, False)
loss_real = self.discriminator_loss(pred_real, True)
total_discriminator_loss = 0.5 * (loss_fake + loss_real)
real_d_output = self.discriminator([segmentation_map, image])
fake_d_output, fake_image = self.combined_model(
[latent_vector, labels, segmentation_map]
)
pred = fake_d_output[-1]
g_loss = generator_loss(pred)
kl_loss = self.kl_divergence_loss_coeff * kl_divergence_loss(mean, variance)
vgg_loss = self.vgg_feature_loss_coeff * self.vgg_loss(image, fake_image)
feature_loss = self.feature_loss_coeff * self.feature_matching_loss(
real_d_output, fake_d_output
)
total_generator_loss = g_loss + kl_loss + vgg_loss + feature_loss
# Report progress.
self.disc_loss_tracker.update_state(total_discriminator_loss)
self.gen_loss_tracker.update_state(total_generator_loss)
self.feat_loss_tracker.update_state(feature_loss)
self.vgg_loss_tracker.update_state(vgg_loss)
self.kl_loss_tracker.update_state(kl_loss)
results = {m.name: m.result() for m in self.metrics}
return results
def call(self, inputs):
latent_vectors, labels = inputs
return self.generator([latent_vectors, labels])
GauGAN 訓練
gaugan = GauGAN(IMG_HEIGHT, NUM_CLASSES, BATCH_SIZE, latent_dim=256)
gaugan.compile()
history = gaugan.fit(
train_dataset,
validation_data=val_dataset,
epochs=15,
callbacks=[GanMonitor(val_dataset, BATCH_SIZE)],
)
def plot_history(item):
plt.plot(history.history[item], label=item)
plt.plot(history.history["val_" + item], label="val_" + item)
plt.xlabel("Epochs")
plt.ylabel(item)
plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
plt.legend()
plt.grid()
plt.show()
plot_history("disc_loss")
plot_history("gen_loss")
plot_history("feat_loss")
plot_history("vgg_loss")
plot_history("kl_loss")
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5 80142336/80134624 [==============================] - 0s 0us/step 80150528/80134624 [==============================] - 0s 0us/step Epoch 1/15 WARNING:tensorflow:Gradients do not exist for variables ['conv2d_6/kernel:0', 'conv2d_7/kernel:0', 'instance_normalization_3/gamma:0', 'instance_normalization_3/beta:0', 'conv2d_8/kernel:0', 'instance_normalization_4/gamma:0', 'instance_normalization_4/beta:0', 'conv2d_9/kernel:0', 'instance_normalization_5/gamma:0', 'instance_normalization_5/beta:0', 'conv2d_10/kernel:0', 'instance_normalization_6/gamma:0', 'instance_normalization_6/beta:0', 'mean/kernel:0', 'mean/bias:0', 'variance/kernel:0', 'variance/bias:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument? WARNING:tensorflow:Gradients do not exist for variables ['conv2d_6/kernel:0', 'conv2d_7/kernel:0', 'instance_normalization_3/gamma:0', 'instance_normalization_3/beta:0', 'conv2d_8/kernel:0', 'instance_normalization_4/gamma:0', 'instance_normalization_4/beta:0', 'conv2d_9/kernel:0', 'instance_normalization_5/gamma:0', 'instance_normalization_5/beta:0', 'conv2d_10/kernel:0', 'instance_normalization_6/gamma:0', 'instance_normalization_6/beta:0', 'mean/kernel:0', 'mean/bias:0', 'variance/kernel:0', 'variance/bias:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument? 75/75 [==============================] - ETA: 0s - disc_loss: 1.1359 - gen_loss: 114.6762 - feat_loss: 9.6107 - vgg_loss: 17.5540 - kl_loss: 87.3495
75/75 [==============================] - 69s 620ms/step - disc_loss: 1.1359 - gen_loss: 114.6762 - feat_loss: 9.6107 - vgg_loss: 17.5540 - kl_loss: 87.3495 - val_disc_loss: 0.9339 - val_gen_loss: 115.9361 - val_feat_loss: 12.0053 - val_vgg_loss: 17.6576 - val_kl_loss: 86.4241 Epoch 2/15 75/75 [==============================] - 39s 522ms/step - disc_loss: 0.9191 - gen_loss: 115.4342 - feat_loss: 11.1384 - vgg_loss: 16.6970 - kl_loss: 87.1299 - val_disc_loss: 0.9073 - val_gen_loss: 117.0431 - val_feat_loss: 10.9293 - val_vgg_loss: 17.3108 - val_kl_loss: 87.1910 Epoch 3/15 75/75 [==============================] - 40s 530ms/step - disc_loss: 0.7783 - gen_loss: 116.0476 - feat_loss: 11.2018 - vgg_loss: 16.4767 - kl_loss: 87.7456 - val_disc_loss: 0.7877 - val_gen_loss: 115.6750 - val_feat_loss: 11.3406 - val_vgg_loss: 16.9246 - val_kl_loss: 87.8907 Epoch 4/15 75/75 [==============================] - 39s 521ms/step - disc_loss: 0.6915 - gen_loss: 115.8905 - feat_loss: 10.7578 - vgg_loss: 16.3213 - kl_loss: 88.0270 - val_disc_loss: 0.7651 - val_gen_loss: 115.5427 - val_feat_loss: 11.5930 - val_vgg_loss: 17.0086 - val_kl_loss: 87.3675 Epoch 5/15 75/75 [==============================] - 39s 521ms/step - disc_loss: 0.6652 - gen_loss: 115.3557 - feat_loss: 10.7736 - vgg_loss: 16.3333 - kl_loss: 87.4493 - val_disc_loss: 0.9139 - val_gen_loss: 115.3157 - val_feat_loss: 11.3612 - val_vgg_loss: 17.0591 - val_kl_loss: 87.6537 Epoch 6/15 75/75 [==============================] - ETA: 0s - disc_loss: 0.6541 - gen_loss: 115.2529 - feat_loss: 10.6386 - vgg_loss: 16.2342 - kl_loss: 87.5053
75/75 [==============================] - 43s 573ms/step - disc_loss: 0.6541 - gen_loss: 115.2529 - feat_loss: 10.6386 - vgg_loss: 16.2342 - kl_loss: 87.5053 - val_disc_loss: 0.3999 - val_gen_loss: 116.1638 - val_feat_loss: 11.1031 - val_vgg_loss: 16.9759 - val_kl_loss: 87.5684 Epoch 7/15 75/75 [==============================] - 40s 530ms/step - disc_loss: 0.6029 - gen_loss: 115.3866 - feat_loss: 10.6807 - vgg_loss: 16.2025 - kl_loss: 87.5448 - val_disc_loss: 0.4571 - val_gen_loss: 117.5491 - val_feat_loss: 10.7403 - val_vgg_loss: 16.8212 - val_kl_loss: 88.4036 Epoch 8/15 75/75 [==============================] - 39s 522ms/step - disc_loss: 0.5798 - gen_loss: 115.1903 - feat_loss: 10.5906 - vgg_loss: 16.1720 - kl_loss: 87.4315 - val_disc_loss: 0.4470 - val_gen_loss: 114.3039 - val_feat_loss: 10.8104 - val_vgg_loss: 16.9426 - val_kl_loss: 86.1938 Epoch 9/15 75/75 [==============================] - 39s 521ms/step - disc_loss: 0.5412 - gen_loss: 115.6245 - feat_loss: 10.5598 - vgg_loss: 16.1985 - kl_loss: 87.8032 - val_disc_loss: 0.3365 - val_gen_loss: 116.6229 - val_feat_loss: 10.9437 - val_vgg_loss: 16.9026 - val_kl_loss: 87.6305 Epoch 10/15 75/75 [==============================] - 39s 521ms/step - disc_loss: 0.5822 - gen_loss: 115.2743 - feat_loss: 10.5127 - vgg_loss: 16.1609 - kl_loss: 87.5487 - val_disc_loss: 0.4711 - val_gen_loss: 116.2798 - val_feat_loss: 11.4271 - val_vgg_loss: 16.7643 - val_kl_loss: 87.5068 Epoch 11/15 75/75 [==============================] - ETA: 0s - disc_loss: 0.5588 - gen_loss: 115.1784 - feat_loss: 10.5053 - vgg_loss: 16.0580 - kl_loss: 87.5601
75/75 [==============================] - 43s 571ms/step - disc_loss: 0.5588 - gen_loss: 115.1784 - feat_loss: 10.5053 - vgg_loss: 16.0580 - kl_loss: 87.5601 - val_disc_loss: 0.7086 - val_gen_loss: 116.6372 - val_feat_loss: 11.6103 - val_vgg_loss: 16.9735 - val_kl_loss: 88.3342 Epoch 12/15 75/75 [==============================] - 39s 521ms/step - disc_loss: 0.5461 - gen_loss: 115.2417 - feat_loss: 10.5498 - vgg_loss: 16.0916 - kl_loss: 87.5209 - val_disc_loss: 0.5056 - val_gen_loss: 116.8908 - val_feat_loss: 10.3318 - val_vgg_loss: 16.8074 - val_kl_loss: 87.9361 Epoch 13/15 75/75 [==============================] - 39s 521ms/step - disc_loss: 0.5482 - gen_loss: 114.9858 - feat_loss: 10.3671 - vgg_loss: 16.0029 - kl_loss: 87.5513 - val_disc_loss: 0.5819 - val_gen_loss: 116.6370 - val_feat_loss: 11.3866 - val_vgg_loss: 17.0710 - val_kl_loss: 88.0975 Epoch 14/15 75/75 [==============================] - 39s 522ms/step - disc_loss: 0.5596 - gen_loss: 114.5251 - feat_loss: 10.3841 - vgg_loss: 16.0361 - kl_loss: 86.9755 - val_disc_loss: 0.4472 - val_gen_loss: 115.9854 - val_feat_loss: 10.2750 - val_vgg_loss: 16.9934 - val_kl_loss: 87.2017 Epoch 15/15 75/75 [==============================] - 39s 521ms/step - disc_loss: 0.5417 - gen_loss: 114.3627 - feat_loss: 10.1977 - vgg_loss: 15.9871 - kl_loss: 87.1026 - val_disc_loss: 0.4202 - val_gen_loss: 115.1249 - val_feat_loss: 10.2537 - val_vgg_loss: 16.8011 - val_kl_loss: 86.5379
推論
val_iterator = iter(val_dataset)
for _ in range(5):
val_images = next(val_iterator)
# Sample latent from a normal distribution.
latent_vector = tf.random.normal(
shape=(gaugan.batch_size, gaugan.latent_dim), mean=0.0, stddev=2.0
)
# Generate fake images.
fake_images = gaugan.predict([latent_vector, val_images[2]])
real_images = val_images
grid_row = min(fake_images.shape[0], 3)
grid_col = 3
f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col * 6, grid_row * 6))
for row in range(grid_row):
ax = axarr if grid_row == 1 else axarr[row]
ax[0].imshow((real_images[0][row] + 1) / 2)
ax[0].axis("off")
ax[0].set_title("Mask", fontsize=20)
ax[1].imshow((real_images[1][row] + 1) / 2)
ax[1].axis("off")
ax[1].set_title("Ground Truth", fontsize=20)
ax[2].imshow((fake_images[row] + 1) / 2)
ax[2].axis("off")
ax[2].set_title("Generated", fontsize=20)
plt.show()
終わりに
- この例で使用したデータセットは小さいものです。より良い結果を得るためにはより大きなデータセットを使用することを勧めます。GauGAN の結果は COCO-Stuff と CityScapes データセットで実演されました。
- このサンプルは Soon-Yau Cheong による Hands-On Image Generation with TensorFlow の 6 章と Divyansh Jha による Implementing SPADE using fastai にインスパイアされました。
- If you found this example interesting and exciting, you might want to check out our repository which we are currently building. It will include reimplementations of popular GANs and pretrained models. Our focus will be on readibility and making the code as accessible as possible. Our plan is to first train our implementation of GauGAN (following the code of this example) on a bigger dataset and then make the repository public. We welcome contributions!
- Recently GauGAN2 was also released. You can check it out here.
以上