TensorFlow 2.0 Beta : 上級 Tutorials : 画像生成 :- CycleGAN (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 07/11/2019
* 本ページは、TensorFlow の本家サイトの TF 2.0 Beta – Advanced Tutorials – Image generation の以下のページを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
画像生成 :- CycleGAN
このノートブックは Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks で説明されている、CycleGAN としても知られる、条件付き GAN を使用して不対の (= unpaired) 画像から画像への変換 (= translation) を実演します。このペーパーはメソッドを提案します、それを通して一つの画像ドメインの特質を捉えてどのような対となる訓練サンプルも欠落しながら、これらの特質がどのようにもう一つの画像ドメインに翻訳できるかを見出します。
このノートブックは貴方が Pix2Pix について馴染みがあることを仮定しています、これについては Pix2Pix チュートリアル で学習することができます。CycleGAN のためのコードも類似していますが、主な違いは追加の損失関数と、不対の訓練データの使用です。
CycleGAN はペアデータを必要とせずに訓練を可能にするために cycle consistency 損失を使用します。換言すれば、それはソースとターゲットドメイン間の 1対1 マッピングなしに一つのドメインからもう一つのドメインへ変換することができます。
これは写真拡張、画像彩色、画風変換, etc. のような多くの興味深いタスクを行なう可能性を広げます。貴方が必要なことの総てはソースとターゲット・データセットです (これは単純に画像のディレクトリです)。
入力パイプラインをセットアップする
tensorflow_examples パッケージをインストールします、これは generator と discriminator のインポートを可能にします。
!pip install -q git+https://github.com/tensorflow/examples.git
!pip install -q tensorflow-gpu==2.0.0-beta1 import tensorflow as tf
from __future__ import absolute_import, division, print_function, unicode_literals import tensorflow_datasets as tfds from tensorflow_examples.models.pix2pix import pix2pix import os import time import matplotlib.pyplot as plt from IPython.display import clear_output tfds.disable_progress_bar() AUTOTUNE = tf.data.experimental.AUTOTUNE
入力パイプライン
このチュートリアルは馬の画像からシマウマの画像に変換するためにモデルを訓練します。このデータセットと類似のものは ここ で見つけられます。
ペーパー で言及されているように、訓練データセットにランダムな jittering とミラーリングを適用します。これらは overfitting を回避する画像増強テクニックの幾つかです。
これは pix2pix で行われたことに類似しています * ランダムな jittering では、画像は 286 x 286 にリサイズされてから 256 x 256 にランダムにクロップされます。* ランダムなミラーリングでは、画像はランダムに水平 i.e. 左から右に反転されます。
dataset, metadata = tfds.load('cycle_gan/horse2zebra', with_info=True, as_supervised=True) train_horses, train_zebras = dataset['trainA'], dataset['trainB'] test_horses, test_zebras = dataset['testA'], dataset['testB']
Downloading and preparing dataset cycle_gan (111.45 MiB) to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/0.1.0... /home/kbuilder/.local/lib/python3.5/site-packages/urllib3/connectionpool.py:851: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings InsecureRequestWarning) WARNING: Logging before flag parsing goes to stderr. W0628 01:59:12.853749 140396912944896 deprecation.py:323] From /home/kbuilder/.local/lib/python3.5/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version. Instructions for updating: Use eager execution and: `tf.data.TFRecordDataset(path)` Dataset cycle_gan downloaded and prepared to /home/kbuilder/tensorflow_datasets/cycle_gan/horse2zebra/0.1.0. Subsequent calls will reuse this data.
BUFFER_SIZE = 1000 BATCH_SIZE = 1 IMG_WIDTH = 256 IMG_HEIGHT = 256
def random_crop(image): cropped_image = tf.image.random_crop( image, size=[IMG_HEIGHT, IMG_WIDTH, 3]) return cropped_image
# normalizing the images to [-1, 1] def normalize(image): image = tf.cast(image, tf.float32) image = (image / 127.5) - 1 return image
def random_jitter(image): # resizing to 286 x 286 x 3 image = tf.image.resize(image, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) # randomly cropping to 256 x 256 x 3 image = random_crop(image) # random mirroring image = tf.image.random_flip_left_right(image) return image
def preprocess_image_train(image, label): image = random_jitter(image) image = normalize(image) return image
def preprocess_image_test(image, label): image = normalize(image) return image
train_horses = train_horses.map( preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle( BUFFER_SIZE).batch(1) train_zebras = train_zebras.map( preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle( BUFFER_SIZE).batch(1) test_horses = test_horses.map( preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle( BUFFER_SIZE).batch(1) test_zebras = test_zebras.map( preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle( BUFFER_SIZE).batch(1)
sample_horse = next(iter(train_horses)) sample_zebra = next(iter(train_zebras))
plt.subplot(121) plt.title('Horse') plt.imshow(sample_horse[0] * 0.5 + 0.5) plt.subplot(122) plt.title('Horse with random jitter') plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7fb0100e4e10>
plt.subplot(121) plt.title('Zebra') plt.imshow(sample_zebra[0] * 0.5 + 0.5) plt.subplot(122) plt.title('Zebra with random jitter') plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7fb029232898>
Pix2Pix モデルをインポートして再利用する
Pix2Pix で使用された generator と discriminator をインストールされた tensorflow_examples パッケージを通してインポートします。
このチュートリアルで使用されるモデル・アーキテクチャは pix2pix で使用されたものに非常に類似しています。違いの幾つかは :
- Cyclegan は バッチ正規化 の代わりに インスタンス正規化 を使用します。
- CycleGAN ペーパー では修正された resnet ベースの generator を使用します。このチュートリアルは単純化のために修正された unet generator を使用しています。
ここで訓練される 2 generator ($G$ と $F$) と 2 discriminators (X と Y) (訳注: 原文ママ、正しくは $D_X$ と $D_Y$) があります。
- Generator $G$ は画像 X から画像 Y への変換を学習します。$(G: X -> Y)$
- Generator $F$ は画像 Y から画像 X への変換を学習します。$(F: Y -> X)$
- Discriminator $D_X$ は画像 X と生成された画像 X (F(Y)) を識別します。
- Discriminator $D_Y$ は画像 Y と生成された画像 Y (G(X)) を識別します。
OUTPUT_CHANNELS = 3 generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm') generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm') discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False) discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
to_zebra = generator_g(sample_horse) to_horse = generator_f(sample_zebra) plt.figure(figsize=(8, 8)) contrast = 8 imgs = [sample_horse, to_zebra, sample_zebra, to_horse] title = ['Horse', 'To Zebra', 'Zebra', 'To Horse'] for i in range(len(imgs)): plt.subplot(2, 2, i+1) plt.title(title[i]) if i % 2 == 0: plt.imshow(imgs[i][0] * 0.5 + 0.5) else: plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5) plt.show()
plt.figure(figsize=(8, 8)) plt.subplot(121) plt.title('Is a real zebra?') plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r') plt.subplot(122) plt.title('Is a real horse?') plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r') plt.show()
損失関数
CycleGAN では、訓練するペアとなるデータはありません、そのため訓練の間入力 x とターゲット y ペアに意味がある保証はありません。このためネットワークが正しいマッピングを学習することを強要するために、著者は cycle consistency 損失を提案しています。
discriminator 損失と generator 損失は pix2pix で使用されたものに類似しています。
LAMBDA = 10
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real, generated): real_loss = loss_obj(tf.ones_like(real), real) generated_loss = loss_obj(tf.zeros_like(generated), generated) total_disc_loss = real_loss + generated_loss return total_disc_loss * 0.5
def generator_loss(generated): return loss_obj(tf.ones_like(generated), generated)
cycle consistency は結果は元の入力に近くあるべきといういう意味です。例えば、もしセンテンスを英語からフランス語に翻訳して、それからフランス語から英語に翻訳し戻す場合、結果としてのセンテンスは元のセンテンスと同じであるべきということです。
cycle consistency 損失では、
- 画像 $X$ は generator $G$ を通して渡されます、これは生成画像 $\hat{Y}$ を生成します。
- 生成画像 $\hat{Y}$ は generator $F$ を通して渡されます、これは cycled 画像 $\hat{X}$ を生成します。
- $X$ と $\hat{X}$ の間の mean absolute error (平均絶対誤差) が計算されます。
def calc_cycle_loss(real_image, cycled_image): loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image)) return LAMBDA * loss1
上で示されるように、generator $G$ は画像 $X$ を画像 $Y$ に変換する責任を負います。Identity 損失は、画像 $Y$ が generator $G$ に供給された場合、それは real 画像 $Y$ か画像 $Y$ に近い何かを生成するべきであると言っています。
def identity_loss(real_image, same_image): loss = tf.reduce_mean(tf.abs(real_image - same_image)) return LAMBDA * 0.5 * loss
総ての generator と discriminator のために optimizer を初期化します。
generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5) generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5) discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5) discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
チェックポイント
checkpoint_path = "./checkpoints/train" ckpt = tf.train.Checkpoint(generator_g=generator_g, generator_f=generator_f, discriminator_x=discriminator_x, discriminator_y=discriminator_y, generator_g_optimizer=generator_g_optimizer, generator_f_optimizer=generator_f_optimizer, discriminator_x_optimizer=discriminator_x_optimizer, discriminator_y_optimizer=discriminator_y_optimizer) ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5) # if a checkpoint exists, restore the latest checkpoint. if ckpt_manager.latest_checkpoint: ckpt.restore(ckpt_manager.latest_checkpoint) print ('Latest checkpoint restored!!')
訓練
Note: このチュートリアルのために訓練時間を合理的に保つためにこのサンプルモデルはペーパー (200) よりも少ないエポック (40) の間訓練されます。予測は少し正確ではないかもしれません。
EPOCHS = 40
def generate_images(model, test_input): prediction = model(test_input) plt.figure(figsize=(12, 12)) display_list = [test_input[0], prediction[0]] title = ['Input Image', 'Predicted Image'] for i in range(2): plt.subplot(1, 2, i+1) plt.title(title[i]) # getting the pixel values between [0, 1] to plot it. plt.imshow(display_list[i] * 0.5 + 0.5) plt.axis('off') plt.show()
訓練ループは複雑に見えますけれども、それは 4 つの基本ステップから成ります : * 予測を得る。* 損失を計算する。* 逆伝播を使用して勾配を計算する。* optimizer に勾配を適用する。
@tf.function def train_step(real_x, real_y): # persistent is set to True because gen_tape and disc_tape is used more than # once to calculate the gradients. with tf.GradientTape(persistent=True) as gen_tape, tf.GradientTape( persistent=True) as disc_tape: # Generator G translates X -> Y # Generator F translates Y -> X. fake_y = generator_g(real_x, training=True) cycled_x = generator_f(fake_y, training=True) fake_x = generator_f(real_y, training=True) cycled_y = generator_g(fake_x, training=True) # same_x and same_y are used for identity loss. same_x = generator_f(real_x, training=True) same_y = generator_g(real_y, training=True) disc_real_x = discriminator_x(real_x, training=True) disc_real_y = discriminator_y(real_y, training=True) disc_fake_x = discriminator_x(fake_x, training=True) disc_fake_y = discriminator_y(fake_y, training=True) # calculate the loss gen_g_loss = generator_loss(disc_fake_y) gen_f_loss = generator_loss(disc_fake_x) # Total generator loss = adversarial loss + cycle loss total_gen_g_loss = gen_g_loss + calc_cycle_loss(real_x, cycled_x) + identity_loss(real_x, same_x) total_gen_f_loss = gen_f_loss + calc_cycle_loss(real_y, cycled_y) + identity_loss(real_y, same_y) disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x) disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y) # Calculate the gradients for generator and discriminator generator_g_gradients = gen_tape.gradient(total_gen_g_loss, generator_g.trainable_variables) generator_f_gradients = gen_tape.gradient(total_gen_f_loss, generator_f.trainable_variables) discriminator_x_gradients = disc_tape.gradient( disc_x_loss, discriminator_x.trainable_variables) discriminator_y_gradients = disc_tape.gradient( disc_y_loss, discriminator_y.trainable_variables) # Apply the gradients to the optimizer generator_g_optimizer.apply_gradients(zip(generator_g_gradients, generator_g.trainable_variables)) generator_f_optimizer.apply_gradients(zip(generator_f_gradients, generator_f.trainable_variables)) discriminator_x_optimizer.apply_gradients( zip(discriminator_x_gradients, discriminator_x.trainable_variables)) discriminator_y_optimizer.apply_gradients( zip(discriminator_y_gradients, discriminator_y.trainable_variables))
for epoch in range(EPOCHS): start = time.time() n = 0 for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)): train_step(image_x, image_y) if n % 10 == 0: print ('.', end='') n+=1 clear_output(wait=True) # Using a consistent image (sample_horse) so that the progress of the model # is clearly visible. generate_images(generator_g, sample_horse) if (epoch + 1) % 5 == 0: ckpt_save_path = ckpt_manager.save() print ('Saving checkpoint for epoch {} at {}'.format(epoch+1, ckpt_save_path)) print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1, time.time()-start))
Time taken for epoch 26 is 109.30256915092468 sec ........................................
テストデータセットを使用して生成する
# Run the trained model on the test dataset for inp in test_horses.take(5): generate_images(generator_g, inp)
Next Steps
このチュートリアルは Pix2Pix チュートリアルで実装された generator と discriminator から始めて CycleGAN をどのように実装するかを示しました。次のステップとして、TensorFlow データセット からの異なるデータセットを使用して試すことができます。
貴方はまた結果を改善するためにより大きなエポック数の間訓練することもできます、あるいはここで使用された U-Net generator の代わりに ペーパー で使用された修正された ResNet generator を実装することもできます。
以上