TensorFlow : Tutorials : 生成モデル : DCGAN: tf.keras と eager のサンプル (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 08/29/2018
* TensorFlow 1.10 で更に改訂されています。
* TensorFlow 1.9 でドキュメント構成が変更され、数篇が新規に追加されましたので再翻訳しました。
* 本ページは、TensorFlow の本家サイトの Tutorials – Generative Models の以下のページを翻訳した上で
適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
DCGAN : tf.keras と eager のサンプル
このノートブックは tf.keras と eager execution を使用して手書き数字の画像をどのように生成するかを示します。それを行なうために、DCGAN (Deep Convolutional Generative Adversarial Networks, 深層畳み込み敵対的生成ネットワーク) を使用します。
このモデルは 2018年7月 時点で、Colab 上単一の Tesla K80 上で訓練するためにエポック毎におよそ ~30 秒かかります (グラフ関数を作成するために tf.contrib.eager.defun を使用)。
下は、generator と discriminator モデルを 150 エポック訓練した後に生成された出力です。

# to generate gifs !pip install imageio
TensorFlow をインポートして eager execution を有効にする
from __future__ import absolute_import, division, print_function # Import TensorFlow >= 1.10 and enable eager execution import tensorflow as tf tf.enable_eager_execution() import os import time import numpy as np import glob import matplotlib.pyplot as plt import PIL import imageio from IPython import display
データセットをロードする
generator と discriminator を訓練するために MNIST データセットを使用していきます。それから generator は手書き数字を生成します。
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
# We are normalizing the images to the range of [-1, 1]
train_images = (train_images - 127.5) / 127.5
BUFFER_SIZE = 60000 BATCH_SIZE = 256
バッチを作成してデータセットをシャッフルするために tf.data を使用する
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
generator と discriminator モデルを書く
- Generator
- それは、discriminator を欺くために十分に良い、説得力のある画像を生成する 責任を負います。
- それは Conv2DTranspose (Upsampling) 層で構成されます。完全結合層で開始して、望まれる画像サイズ (mnist 画像サイズ)、(28, 28, 1) に到達するために画像を 2 回アップサンプリングします。
- tanh 活性を仕様する最後の層を除いて leaky relu 活性を使用します。
- Discriminator
- discriminator は real 画像から fake 画像を分類する 責任を負います。
- 換言すれば、discriminator は (generator から) 生成された画像と real MNIST 画像を与えられます。discriminator のジョブはこれらの画像を (生成された) fake と real (MNIST 画像) に分類することです。
- 基本的には generator は discriminator を生成された画像が本物であると騙すために十分に良くあるべきです。
class Generator(tf.keras.Model):
def __init__(self):
super(Generator, self).__init__()
self.fc1 = tf.keras.layers.Dense(7*7*64, use_bias=False)
self.batchnorm1 = tf.keras.layers.BatchNormalization()
self.conv1 = tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(1, 1), padding='same', use_bias=False)
self.batchnorm2 = tf.keras.layers.BatchNormalization()
self.conv2 = tf.keras.layers.Conv2DTranspose(32, (5, 5), strides=(2, 2), padding='same', use_bias=False)
self.batchnorm3 = tf.keras.layers.BatchNormalization()
self.conv3 = tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False)
def call(self, x, training=True):
x = self.fc1(x)
x = self.batchnorm1(x, training=training)
x = tf.nn.relu(x)
x = tf.reshape(x, shape=(-1, 7, 7, 64))
x = self.conv1(x)
x = self.batchnorm2(x, training=training)
x = tf.nn.relu(x)
x = self.conv2(x)
x = self.batchnorm3(x, training=training)
x = tf.nn.relu(x)
x = tf.nn.tanh(self.conv3(x))
return x
class Discriminator(tf.keras.Model):
def __init__(self):
super(Discriminator, self).__init__()
self.conv1 = tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')
self.conv2 = tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')
self.dropout = tf.keras.layers.Dropout(0.3)
self.flatten = tf.keras.layers.Flatten()
self.fc1 = tf.keras.layers.Dense(1)
def call(self, x, training=True):
x = tf.nn.leaky_relu(self.conv1(x))
x = self.dropout(x, training=training)
x = tf.nn.leaky_relu(self.conv2(x))
x = self.dropout(x, training=training)
x = self.flatten(x)
x = self.fc1(x)
return x
generator = Generator() discriminator = Discriminator()
# Defun gives 10 secs/epoch performance boost generator.call = tf.contrib.eager.defun(generator.call) discriminator.call = tf.contrib.eager.defun(discriminator.call)
損失関数と optimizer を定義する
- Discriminator 損失
- discriminator 損失関数は 2 つの入力を取ります ; real 画像、生成された画像です。
- real_loss は real 画像と 1 の配列の sigmoid 交差エントロピー損失です (何故ならばこれらは real 画像だからです)。
- generated_loss は生成された画像とゼロ配列の sigmod 交差エントロピー損失です (何故ならばこれらは fake 画像だからです)。
- そして total_loss は real_loss と generated_loss の合計です。
- Generator 損失
- それは生成された画像と 1 の配列の sigmoid 交差エントロピー損失です。
- discriminator と generator optimizer は異なります、何故ならばそれらを別々に訓練するからです。
def discriminator_loss(real_output, generated_output):
# [1,1,...,1] with real output since it is true and we want
# our generated examples to look like it
real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.ones_like(real_output), logits=real_output)
# [0,0,...,0] with generated images since they are fake
generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.zeros_like(generated_output), logits=generated_output)
total_loss = real_loss + generated_loss
return total_loss
def generator_loss(generated_output):
return tf.losses.sigmoid_cross_entropy(tf.ones_like(generated_output), generated_output)
discriminator_optimizer = tf.train.AdamOptimizer(1e-4) generator_optimizer = tf.train.AdamOptimizer(1e-4)
チェックポイント (オブジェクト・ベースのセーブ)
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
discriminator_optimizer=discriminator_optimizer,
generator=generator,
discriminator=discriminator)
訓練
- データセットに渡り反復することから始めます。
- generator は入力としてノイズが与えられます、これは generator モデルを通して渡されたとき手書き数字のような画像を出力します。
- discriminator は real MNIST 画像 と (generator から) 生成された画像が与えられます。
- 次に、generator と discriminator 損失を計算します。
- それから、generator と discriminator 変数 (入力) の両者に関する損失の勾配を計算してそれらを optimizer に適用します。
画像を生成する
- 訓練の後、何某かの画像を生成する時です!
- generator への入力としてノイズ配列を作成することから始めます。
- それから generator はノイズを手書き数字に変換します。
- 最後のステップは予測をプロットします。Voila !
EPOCHS = 150
noise_dim = 100
num_examples_to_generate = 16
# keeping the random vector constant for generation (prediction) so
# it will be easier to see the improvement of the gan.
random_vector_for_generation = tf.random_normal([num_examples_to_generate,
noise_dim])
def generate_and_save_images(model, epoch, test_input):
# make sure the training parameter is set to False because we
# don't want to train the batchnorm layer when doing inference.
predictions = model(test_input, training=False)
fig = plt.figure(figsize=(4,4))
for i in range(predictions.shape[0]):
plt.subplot(4, 4, i+1)
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
plt.axis('off')
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
plt.show()
def train(dataset, epochs, noise_dim):
for epoch in range(epochs):
start = time.time()
for images in dataset:
# generating noise from a uniform distribution
noise = tf.random_normal([BATCH_SIZE, noise_dim])
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise, training=True)
real_output = discriminator(images, training=True)
generated_output = discriminator(generated_images, training=True)
gen_loss = generator_loss(generated_output)
disc_loss = discriminator_loss(real_output, generated_output)
gradients_of_generator = gen_tape.gradient(gen_loss, generator.variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.variables)
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))
if epoch % 1 == 0:
display.clear_output(wait=True)
generate_and_save_images(generator,
epoch + 1,
random_vector_for_generation)
# saving (checkpoint) the model every 15 epochs
if (epoch + 1) % 15 == 0:
checkpoint.save(file_prefix = checkpoint_prefix)
print ('Time taken for epoch {} is {} sec'.format(epoch + 1,
time.time()-start))
# generating after the final epoch
display.clear_output(wait=True)
generate_and_save_images(generator,
epochs,
random_vector_for_generation)
train(train_dataset, EPOCHS, noise_dim)
最新のチェックポイントをリストアする
# restoring the latest checkpoint in checkpoint_dir checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
エポック数を使用して画像を表示する
def display_image(epoch_no):
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)
総ての保存された画像の GIF を生成する
with imageio.get_writer('dcgan.gif', mode='I') as writer:
filenames = glob.glob('image*.png')
filenames = sorted(filenames)
last = -1
for i,filename in enumerate(filenames):
frame = 2*(i**0.5)
if round(frame) > round(last):
last = frame
else:
continue
image = imageio.imread(filename)
writer.append_data(image)
image = imageio.imread(filename)
writer.append_data(image)
# this is a hack to display the gif inside the notebook
os.system('cp dcgan.gif dcgan.gif.png')
display.Image(filename="dcgan.gif.png")
Colab からアニメーションをダウンロードするためには下のコードをアンコメントします :
#from google.colab import files
#files.download('dcgan.gif')
以上