Sonnet 2.0 : Tutorials : VQ-VAE 訓練サンプル (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 06/10/2020
* 本ページは、Sonnet の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
VQ-VAE 訓練サンプル
TF 2 / Sonnet 2 を使用して、https://arxiv.org/abs/1711.00937 (Neural Discrete Representation Learning) で提案されたモデル (VQ-VAE) をどのように訓練するかを実演します。
Mac または Linux 上で、各セルを順番に実行するだけです。
!pip install dm-sonnet dm-tree
Requirement already satisfied: dm-sonnet in /tmp/sonnet-nb-env/lib/python3.7/site-packages (2.0.0) Requirement already satisfied: dm-tree in /tmp/sonnet-nb-env/lib/python3.7/site-packages (0.1.5) Requirement already satisfied: six>=1.12.0 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (1.14.0) Requirement already satisfied: tabulate>=0.7.5 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (0.8.7) Requirement already satisfied: absl-py>=0.7.1 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (0.9.0) Requirement already satisfied: numpy>=1.16.3 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (1.18.3) Requirement already satisfied: wrapt>=1.11.1 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (1.12.1)
import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import tree
# Prefer the Sonnet 2 API exposed as ``sonnet.v2`` (and switch TF to v2
# behaviour in that case); fall back to the top-level ``sonnet`` package if
# that import fails.
try:
    import sonnet.v2 as snt
    tf.enable_v2_behavior()
except ImportError:
    import sonnet as snt

print("TensorFlow version {}".format(tf.__version__))
print("Sonnet version {}".format(snt.__version__))
TensorFlow version 2.1.0 Sonnet version 2.0.0
Cifar10 データをダウンロードする
このセルはインターネット接続を必要とし、約 160MB をダウンロードします。
# Fetch CIFAR-10 with the train and test splits concatenated, as a single
# full-size batch of NumPy arrays.
cifar10 = tfds.as_numpy(
    tfds.load("cifar10:3.0.2", split="train+test", batch_size=-1))

# Only the images are used below: discard the "id" feature (tolerating its
# absence) and the "label" feature.
cifar10.pop("id", None)
cifar10.pop("label")

# Summarise the remaining features as "dtype[shape]" strings.
tree.map_structure(
    lambda array: f'{array.dtype.name}{list(array.shape)}', cifar10)
{'image': 'uint8[60000, 32, 32, 3]'}
データを Numpy にロードする
後で平均二乗誤差を正規化するために、訓練セット全体の分散を計算します。
# Partition the 60000 concatenated images: the first 40000 for training, the
# next 10000 for validation and the final 10000 for testing (the 50000 cut
# presumably coincides with the original CIFAR-10 train/test boundary given
# split="train+test" — confirm).
train_data_dict = tree.map_structure(lambda x: x[:40000], cifar10)
valid_data_dict = tree.map_structure(lambda x: x[40000:50000], cifar10)
test_data_dict = tree.map_structure(lambda x: x[50000:], cifar10)
def cast_and_normalise_images(data_dict):
    """Convert images to floating point with the range [-0.5, 0.5].

    Intended as a ``tf.data.Dataset.map`` function.

    Args:
        data_dict: Mapping with an ``'image'`` entry holding integer pixel
            values in [0, 255] (uint8 for the CIFAR-10 data in this file).

    Returns:
        A new dict with the same entries as ``data_dict``, except that
        ``'image'`` is cast to float32 and rescaled to [-0.5, 0.5]. The input
        mapping is left unmodified.
    """
    images = data_dict['image']
    normalised = (tf.cast(images, tf.float32) / 255.0) - 0.5
    # Build a new dict rather than overwriting the caller's entry in place.
    return {**data_dict, 'image': normalised}
# Pixel variance of the training images scaled to [0, 1]. The constant -0.5
# shift applied during normalisation does not affect variance, so this also
# matches the variance of the model's normalised inputs.
train_data_variance = np.var(np.divide(train_data_dict['image'], 255.0))
print('train data variance: %s' % (train_data_variance,))
train data variance: 0.06327039811675479
エンコーダ & デコーダ・アーキテクチャ
class ResidualStack(snt.Module):
    """Stack of residual blocks: ``relu -> 3x3 conv -> relu -> 1x1 conv``.

    Each block's output is added back onto its input (identity skip), and a
    final ReLU is applied after the last block.

    Args:
        num_hiddens: Output channels of each block's 1x1 convolution. Because
            the block output is added to its input, this must equal the
            channel count of ``inputs``.
        num_residual_layers: Number of residual blocks.
        num_residual_hiddens: Output channels of each block's 3x3 convolution.
        name: Optional Sonnet module name.
    """

    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 name=None):
        super(ResidualStack, self).__init__(name=name)
        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens

        # One (3x3 conv, 1x1 conv) pair per residual block.
        self._layers = []
        for i in range(num_residual_layers):
            conv3 = snt.Conv2D(
                output_channels=num_residual_hiddens,
                kernel_shape=(3, 3),
                stride=(1, 1),
                name="res3x3_%d" % i)
            conv1 = snt.Conv2D(
                output_channels=num_hiddens,
                kernel_shape=(1, 1),
                stride=(1, 1),
                name="res1x1_%d" % i)
            self._layers.append((conv3, conv1))

    def __call__(self, inputs):
        """Apply every residual block in turn, then a final ReLU."""
        h = inputs
        for conv3, conv1 in self._layers:
            conv3_out = conv3(tf.nn.relu(h))
            conv1_out = conv1(tf.nn.relu(conv3_out))
            # Rebinds h to a new tensor; ``inputs`` itself is not modified.
            h += conv1_out
        return tf.nn.relu(h)  # Resnet V1 style
class Encoder(snt.Module):
    """Convolutional encoder feeding into a ``ResidualStack``.

    Layers:
      * 4x4, stride-2 conv to ``num_hiddens // 2`` channels, then ReLU;
      * 4x4, stride-2 conv to ``num_hiddens`` channels, then ReLU;
      * 3x3, stride-1 conv to ``num_hiddens`` channels, then ReLU;
      * a ``ResidualStack`` (which ends in its own ReLU).

    The two stride-2 convolutions downsample the spatial resolution twice
    (presumably 32x32 -> 8x8 for CIFAR-10 with Sonnet's default padding —
    confirm).

    Args:
        num_hiddens: Channel width of the encoder's final feature map.
        num_residual_layers: Number of blocks in the residual stack.
        num_residual_hiddens: Width of the 3x3 conv inside each residual block.
        name: Optional Sonnet module name.
    """

    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 name=None):
        super(Encoder, self).__init__(name=name)
        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens

        self._enc_1 = snt.Conv2D(
            output_channels=self._num_hiddens // 2,
            kernel_shape=(4, 4),
            stride=(2, 2),
            name="enc_1")
        self._enc_2 = snt.Conv2D(
            output_channels=self._num_hiddens,
            kernel_shape=(4, 4),
            stride=(2, 2),
            name="enc_2")
        self._enc_3 = snt.Conv2D(
            output_channels=self._num_hiddens,
            kernel_shape=(3, 3),
            stride=(1, 1),
            name="enc_3")
        self._residual_stack = ResidualStack(
            self._num_hiddens,
            self._num_residual_layers,
            self._num_residual_hiddens)

    def __call__(self, x):
        """Encode a batch of images ``x`` into a feature map."""
        h = tf.nn.relu(self._enc_1(x))
        h = tf.nn.relu(self._enc_2(h))
        h = tf.nn.relu(self._enc_3(h))
        return self._residual_stack(h)
class Decoder(snt.Module):
    """Convolutional decoder mapping (quantised) feature maps back to images.

    Layers:
      * 3x3, stride-1 conv to ``num_hiddens`` channels (no activation);
      * a ``ResidualStack``;
      * 4x4, stride-2 transposed conv to ``num_hiddens // 2`` channels, ReLU;
      * 4x4, stride-2 transposed conv to 3 (RGB) channels.

    The final layer has no activation, so reconstructions are not constrained
    to the [-0.5, 0.5] range of the normalised inputs.

    Args:
        num_hiddens: Channel width after the first convolution.
        num_residual_layers: Number of blocks in the residual stack.
        num_residual_hiddens: Width of the 3x3 conv inside each residual block.
        name: Optional Sonnet module name.
    """

    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 name=None):
        super(Decoder, self).__init__(name=name)
        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens

        self._dec_1 = snt.Conv2D(
            output_channels=self._num_hiddens,
            kernel_shape=(3, 3),
            stride=(1, 1),
            name="dec_1")
        self._residual_stack = ResidualStack(
            self._num_hiddens,
            self._num_residual_layers,
            self._num_residual_hiddens)
        self._dec_2 = snt.Conv2DTranspose(
            output_channels=self._num_hiddens // 2,
            output_shape=None,
            kernel_shape=(4, 4),
            stride=(2, 2),
            name="dec_2")
        self._dec_3 = snt.Conv2DTranspose(
            output_channels=3,
            output_shape=None,
            kernel_shape=(4, 4),
            stride=(2, 2),
            name="dec_3")

    def __call__(self, x):
        """Decode feature map ``x`` into an image reconstruction."""
        h = self._dec_1(x)
        h = self._residual_stack(h)
        h = tf.nn.relu(self._dec_2(h))
        x_recon = self._dec_3(h)
        return x_recon
class VQVAEModel(snt.Module):
    """VQ-VAE: encoder -> pre-quantisation conv -> vector quantiser -> decoder.

    Args:
        encoder: Module mapping images to feature maps.
        decoder: Module mapping quantised feature maps back to images.
        vqvae: Vector-quantiser module. It is called as
            ``vqvae(z, is_training=...)`` and must return a dict containing at
            least ``'quantize'`` (the quantised features) and ``'loss'``.
        pre_vq_conv1: Module applied to the encoder output before
            quantisation (in this file, a 1x1 conv to ``embedding_dim``
            channels).
        data_variance: Variance of the training data; the reconstruction MSE
            is divided by it.
        name: Optional Sonnet module name.
    """

    def __init__(self, encoder, decoder, vqvae, pre_vq_conv1,
                 data_variance, name=None):
        super(VQVAEModel, self).__init__(name=name)
        self._encoder = encoder
        self._decoder = decoder
        self._vqvae = vqvae
        self._pre_vq_conv1 = pre_vq_conv1
        self._data_variance = data_variance

    def __call__(self, inputs, is_training):
        """Encode, quantise and reconstruct ``inputs``.

        Args:
            inputs: Batch of normalised images.
            is_training: Forwarded to the quantiser only (for the EMA
                quantiser this controls whether the codebook is updated).

        Returns:
            A dict with keys ``'z'`` (pre-quantisation features),
            ``'x_recon'`` (reconstruction), ``'loss'`` (normalised
            reconstruction error plus the quantiser loss), ``'recon_error'``
            and ``'vq_output'`` (the quantiser's full output dict).
        """
        z = self._pre_vq_conv1(self._encoder(inputs))
        vq_output = self._vqvae(z, is_training=is_training)
        x_recon = self._decoder(vq_output['quantize'])
        # MSE normalised by the data variance, so the loss scale does not
        # depend on the pixel scale of the dataset.
        recon_error = tf.reduce_mean((x_recon - inputs) ** 2) / self._data_variance
        loss = recon_error + vq_output['loss']
        return {
            'z': z,
            'x_recon': x_recon,
            'loss': loss,
            'recon_error': recon_error,
            'vq_output': vq_output,
        }
モデルを構築して訓練する
%%time
# Set hyper-parameters.

# --- Data and training loop ---------------------------------------------
batch_size = 32
image_size = 32
# 100k steps should take < 30 minutes on a modern (>= 2017) GPU; 10k steps
# already gives reasonable accuracy with a VQ-VAE on CIFAR-10.
num_training_updates = 10_000

# --- Model size ----------------------------------------------------------
# These define the size of the model (number of parameters and layers).
# For ImageNet the paper used: batch_size = 128, image_size = 128,
# num_hiddens = 128, num_residual_hiddens = 32, num_residual_layers = 2.
num_hiddens = 128
num_residual_hiddens = 32
num_residual_layers = 2

# --- Vector quantiser ----------------------------------------------------
# The embedding dimensionality is not that important (64 usually works) and
# does not change the capacity of the information bottleneck.
embedding_dim = 64
# The codebook size does set that capacity: more embeddings, more capacity.
num_embeddings = 512
# commitment_cost should be set appropriately; it is often useful to try a
# couple of values. It mostly depends on the scale of the reconstruction cost
# (log p(x|z)): if the reconstruction cost is 100x higher, commitment_cost
# should be multiplied by the same amount.
commitment_cost = 0.25
# Use EMA updates for the codebook instead of the Adam optimizer. This
# typically converges faster and makes the model less dependent on the choice
# of optimizer. EMA updates were not used in the VQ-VAE paper (they were
# developed afterwards); see the paper's appendix for details.
vq_use_ema = True
# EMA decay rate for the codebook; only passed to the EMA quantiser.
decay = 0.99

learning_rate = 3e-4
# # Data Loading.
# Training pipeline: normalise, shuffle with a 10000-example buffer, repeat
# forever, and emit fixed-size batches (a trailing partial batch is dropped).
# The -1 passed to prefetch is tf.data's AUTOTUNE value (verify for the TF
# version in use).
train_dataset = tf.data.Dataset.from_tensor_slices(train_data_dict)
train_dataset = train_dataset.map(cast_and_normalise_images)
train_dataset = train_dataset.shuffle(10000)
train_dataset = train_dataset.repeat(-1)  # repeat indefinitely
train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
train_dataset = train_dataset.prefetch(-1)

# Validation pipeline: one unshuffled pass over the validation split.
valid_dataset = tf.data.Dataset.from_tensor_slices(valid_data_dict)
valid_dataset = valid_dataset.map(cast_and_normalise_images)
valid_dataset = valid_dataset.repeat(1)  # 1 epoch
valid_dataset = valid_dataset.batch(batch_size)
valid_dataset = valid_dataset.prefetch(-1)
# # Build modules.
encoder = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)
decoder = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)
# 1x1 conv projecting encoder features to the codebook's embedding_dim.
pre_vq_conv1 = snt.Conv2D(output_channels=embedding_dim,
                          kernel_shape=(1, 1),
                          stride=(1, 1),
                          name="to_vq")

# The codebook is learned either by exponential moving averages or, without
# EMA, by the optimizer through the quantiser's loss.
if vq_use_ema:
    vq_vae = snt.nets.VectorQuantizerEMA(
        embedding_dim=embedding_dim,
        num_embeddings=num_embeddings,
        commitment_cost=commitment_cost,
        decay=decay)
else:
    vq_vae = snt.nets.VectorQuantizer(
        embedding_dim=embedding_dim,
        num_embeddings=num_embeddings,
        commitment_cost=commitment_cost)

model = VQVAEModel(encoder, decoder, vq_vae, pre_vq_conv1,
                   data_variance=train_data_variance)
optimizer = snt.optimizers.Adam(learning_rate=learning_rate)
@tf.function
def train_step(data):
    """Run one optimisation step on a batch and return the model outputs.

    Args:
        data: A batch dict from ``train_dataset``; only ``data['image']`` is
            used.

    Returns:
        The dict produced by ``model(data['image'], is_training=True)``.
    """
    with tf.GradientTape() as tape:
        model_output = model(data['image'], is_training=True)
    # Read the variable list after the forward pass (presumably needed because
    # Sonnet creates module variables lazily on the first call — confirm).
    trainable_variables = model.trainable_variables
    grads = tape.gradient(model_output['loss'], trainable_variables)
    optimizer.apply(grads, trainable_variables)
    return model_output
# Per-step metrics, kept for the running averages below and for plotting.
train_losses = []
train_recon_errors = []
train_perplexities = []
train_vqvae_loss = []

# train_dataset repeats indefinitely; take() bounds it so that exactly
# num_training_updates steps run. The previous `break` on
# `step_index == num_training_updates` ran one extra step.
for step_index, data in enumerate(train_dataset.take(num_training_updates)):
    train_results = train_step(data)
    train_losses.append(train_results['loss'])
    train_recon_errors.append(train_results['recon_error'])
    train_perplexities.append(train_results['vq_output']['perplexity'])
    train_vqvae_loss.append(train_results['vq_output']['loss'])

    # Every 100 steps, report averages over the most recent 100 steps.
    if (step_index + 1) % 100 == 0:
        print('%d train loss: %f ' % (step_index + 1,
                                      np.mean(train_losses[-100:])) +
              ('recon_error: %.3f ' % np.mean(train_recon_errors[-100:])) +
              ('perplexity: %.3f ' % np.mean(train_perplexities[-100:])) +
              ('vqvae loss: %.3f' % np.mean(train_vqvae_loss[-100:])))
WARNING:tensorflow:AutoGraph could not transform <function train_step at 0x7f1016cb5f80> and will run it as-is. Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: Unable to locate the source code of <function train_step at 0x7f1016cb5f80>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code WARNING:tensorflow:AutoGraph could not transform <function train_step at 0x7f1016cb5f80> and will run it as-is. Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: Unable to locate the source code of <function train_step at 0x7f1016cb5f80>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code WARNING: AutoGraph could not transform <function train_step at 0x7f1016cb5f80> and will run it as-is. Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: Unable to locate the source code of <function train_step at 0x7f1016cb5f80>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. 
If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code WARNING:tensorflow:From /tmp/sonnet-nb-env/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1786: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version. Instructions for updating: If using Keras pass *_constraint arguments to layers. WARNING:tensorflow:From /tmp/sonnet-nb-env/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1786: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version. Instructions for updating: If using Keras pass *_constraint arguments to layers. 100 train loss: 0.523625 recon_error: 0.483 perplexity: 10.356 vqvae loss: 0.041 200 train loss: 0.248232 recon_error: 0.223 perplexity: 18.294 vqvae loss: 0.026 300 train loss: 0.215068 recon_error: 0.190 perplexity: 23.106 vqvae loss: 0.025 400 train loss: 0.191891 recon_error: 0.164 perplexity: 29.139 vqvae loss: 0.028 500 train loss: 0.180945 recon_error: 0.147 perplexity: 34.253 vqvae loss: 0.033 600 train loss: 0.167115 recon_error: 0.134 perplexity: 39.961 vqvae loss: 0.033 700 train loss: 0.157724 recon_error: 0.124 perplexity: 46.521 vqvae loss: 0.033 800 train loss: 0.153761 recon_error: 0.119 perplexity: 53.559 vqvae loss: 0.035 900 train loss: 0.145033 recon_error: 0.112 perplexity: 62.442 vqvae loss: 0.033 1000 train loss: 0.137589 recon_error: 0.105 perplexity: 71.831 vqvae loss: 0.033 1100 train loss: 0.133044 recon_error: 0.101 perplexity: 79.135 vqvae loss: 0.032 1200 train loss: 0.129990 recon_error: 0.098 perplexity: 87.959 vqvae loss: 0.032 1300 train loss: 0.126507 recon_error: 0.095 perplexity: 96.704 vqvae loss: 0.031 1400 train loss: 0.122403 recon_error: 0.092 perplexity: 
104.202 vqvae loss: 0.031 1500 train loss: 0.122003 recon_error: 0.091 perplexity: 112.476 vqvae loss: 0.031 1600 train loss: 0.120192 recon_error: 0.089 perplexity: 122.269 vqvae loss: 0.032 1700 train loss: 0.117041 recon_error: 0.086 perplexity: 129.887 vqvae loss: 0.031 1800 train loss: 0.115004 recon_error: 0.083 perplexity: 138.603 vqvae loss: 0.032 1900 train loss: 0.114134 recon_error: 0.082 perplexity: 147.545 vqvae loss: 0.032 2000 train loss: 0.112840 recon_error: 0.081 perplexity: 153.993 vqvae loss: 0.032 2100 train loss: 0.108815 recon_error: 0.077 perplexity: 161.729 vqvae loss: 0.031 2200 train loss: 0.108596 recon_error: 0.078 perplexity: 171.971 vqvae loss: 0.031 2300 train loss: 0.108132 recon_error: 0.077 perplexity: 181.157 vqvae loss: 0.031 2400 train loss: 0.106273 recon_error: 0.076 perplexity: 186.200 vqvae loss: 0.031 2500 train loss: 0.105936 recon_error: 0.075 perplexity: 194.301 vqvae loss: 0.031 2600 train loss: 0.103880 recon_error: 0.073 perplexity: 201.674 vqvae loss: 0.030 2700 train loss: 0.101655 recon_error: 0.072 perplexity: 207.131 vqvae loss: 0.030 2800 train loss: 0.102564 recon_error: 0.072 perplexity: 216.983 vqvae loss: 0.030 2900 train loss: 0.101613 recon_error: 0.072 perplexity: 219.649 vqvae loss: 0.030 3000 train loss: 0.101227 recon_error: 0.071 perplexity: 226.789 vqvae loss: 0.030 3100 train loss: 0.100786 recon_error: 0.071 perplexity: 235.522 vqvae loss: 0.030 3200 train loss: 0.100130 recon_error: 0.070 perplexity: 243.282 vqvae loss: 0.030 3300 train loss: 0.097764 recon_error: 0.067 perplexity: 249.584 vqvae loss: 0.030 3400 train loss: 0.100630 recon_error: 0.069 perplexity: 260.551 vqvae loss: 0.031 3500 train loss: 0.099929 recon_error: 0.068 perplexity: 266.012 vqvae loss: 0.032 3600 train loss: 0.099245 recon_error: 0.067 perplexity: 272.031 vqvae loss: 0.032 3700 train loss: 0.097812 recon_error: 0.066 perplexity: 279.691 vqvae loss: 0.032 3800 train loss: 0.097137 recon_error: 0.064 perplexity: 284.240 
vqvae loss: 0.033 3900 train loss: 0.099217 recon_error: 0.066 perplexity: 293.507 vqvae loss: 0.034 4000 train loss: 0.098570 recon_error: 0.065 perplexity: 300.891 vqvae loss: 0.034 4100 train loss: 0.099238 recon_error: 0.065 perplexity: 306.762 vqvae loss: 0.034 4200 train loss: 0.098172 recon_error: 0.064 perplexity: 311.918 vqvae loss: 0.035 4300 train loss: 0.096449 recon_error: 0.063 perplexity: 316.246 vqvae loss: 0.034 4400 train loss: 0.096487 recon_error: 0.062 perplexity: 319.591 vqvae loss: 0.034 4500 train loss: 0.096092 recon_error: 0.062 perplexity: 322.313 vqvae loss: 0.034 4600 train loss: 0.096474 recon_error: 0.062 perplexity: 324.620 vqvae loss: 0.035 4700 train loss: 0.097075 recon_error: 0.063 perplexity: 324.357 vqvae loss: 0.035 4800 train loss: 0.094709 recon_error: 0.060 perplexity: 326.024 vqvae loss: 0.034 4900 train loss: 0.096557 recon_error: 0.061 perplexity: 327.701 vqvae loss: 0.035 5000 train loss: 0.096185 recon_error: 0.061 perplexity: 326.664 vqvae loss: 0.035 5100 train loss: 0.095646 recon_error: 0.060 perplexity: 327.617 vqvae loss: 0.035 5200 train loss: 0.094689 recon_error: 0.059 perplexity: 328.692 vqvae loss: 0.035 5300 train loss: 0.097047 recon_error: 0.061 perplexity: 327.988 vqvae loss: 0.036 5400 train loss: 0.096259 recon_error: 0.060 perplexity: 327.075 vqvae loss: 0.036 5500 train loss: 0.094588 recon_error: 0.059 perplexity: 327.083 vqvae loss: 0.036 5600 train loss: 0.095947 recon_error: 0.060 perplexity: 328.213 vqvae loss: 0.036 5700 train loss: 0.095466 recon_error: 0.059 perplexity: 329.375 vqvae loss: 0.036 5800 train loss: 0.094849 recon_error: 0.059 perplexity: 326.821 vqvae loss: 0.036 5900 train loss: 0.093799 recon_error: 0.058 perplexity: 328.409 vqvae loss: 0.036 6000 train loss: 0.095373 recon_error: 0.059 perplexity: 326.791 vqvae loss: 0.036 6100 train loss: 0.093989 recon_error: 0.059 perplexity: 325.959 vqvae loss: 0.035 6200 train loss: 0.095549 recon_error: 0.059 perplexity: 330.829 vqvae 
loss: 0.036 6300 train loss: 0.094730 recon_error: 0.058 perplexity: 330.906 vqvae loss: 0.036 6400 train loss: 0.095038 recon_error: 0.058 perplexity: 329.353 vqvae loss: 0.037 6500 train loss: 0.095891 recon_error: 0.059 perplexity: 330.197 vqvae loss: 0.037 6600 train loss: 0.094342 recon_error: 0.058 perplexity: 331.240 vqvae loss: 0.036 6700 train loss: 0.095096 recon_error: 0.058 perplexity: 330.618 vqvae loss: 0.037 6800 train loss: 0.095581 recon_error: 0.059 perplexity: 324.493 vqvae loss: 0.037 6900 train loss: 0.094467 recon_error: 0.058 perplexity: 328.868 vqvae loss: 0.037 7000 train loss: 0.092967 recon_error: 0.057 perplexity: 328.276 vqvae loss: 0.036 7100 train loss: 0.094339 recon_error: 0.058 perplexity: 327.318 vqvae loss: 0.037 7200 train loss: 0.095227 recon_error: 0.058 perplexity: 326.306 vqvae loss: 0.037 7300 train loss: 0.093832 recon_error: 0.057 perplexity: 328.262 vqvae loss: 0.037 7400 train loss: 0.093331 recon_error: 0.057 perplexity: 327.987 vqvae loss: 0.037 7500 train loss: 0.094718 recon_error: 0.058 perplexity: 328.948 vqvae loss: 0.037 7600 train loss: 0.094199 recon_error: 0.058 perplexity: 328.468 vqvae loss: 0.037 7700 train loss: 0.094603 recon_error: 0.058 perplexity: 327.501 vqvae loss: 0.037 7800 train loss: 0.092299 recon_error: 0.056 perplexity: 327.630 vqvae loss: 0.037 7900 train loss: 0.095228 recon_error: 0.058 perplexity: 329.946 vqvae loss: 0.037 8000 train loss: 0.094291 recon_error: 0.058 perplexity: 326.790 vqvae loss: 0.037 8100 train loss: 0.094481 recon_error: 0.057 perplexity: 328.667 vqvae loss: 0.037 8200 train loss: 0.093992 recon_error: 0.057 perplexity: 329.655 vqvae loss: 0.037 8300 train loss: 0.093976 recon_error: 0.057 perplexity: 323.950 vqvae loss: 0.037 8400 train loss: 0.093422 recon_error: 0.057 perplexity: 324.523 vqvae loss: 0.036 8500 train loss: 0.092898 recon_error: 0.056 perplexity: 325.402 vqvae loss: 0.037 8600 train loss: 0.094298 recon_error: 0.057 perplexity: 329.251 vqvae loss: 
0.037 8700 train loss: 0.094489 recon_error: 0.057 perplexity: 331.027 vqvae loss: 0.037 8800 train loss: 0.093022 recon_error: 0.056 perplexity: 327.495 vqvae loss: 0.037 8900 train loss: 0.093427 recon_error: 0.057 perplexity: 328.008 vqvae loss: 0.037 9000 train loss: 0.094884 recon_error: 0.058 perplexity: 327.057 vqvae loss: 0.037 9100 train loss: 0.093559 recon_error: 0.056 perplexity: 331.800 vqvae loss: 0.037 9200 train loss: 0.093282 recon_error: 0.056 perplexity: 328.689 vqvae loss: 0.037 9300 train loss: 0.092217 recon_error: 0.056 perplexity: 323.903 vqvae loss: 0.036 9400 train loss: 0.093902 recon_error: 0.057 perplexity: 326.350 vqvae loss: 0.037 9500 train loss: 0.093772 recon_error: 0.057 perplexity: 325.627 vqvae loss: 0.037 9600 train loss: 0.093123 recon_error: 0.056 perplexity: 327.352 vqvae loss: 0.037 9700 train loss: 0.092934 recon_error: 0.056 perplexity: 328.674 vqvae loss: 0.037 9800 train loss: 0.093284 recon_error: 0.056 perplexity: 329.437 vqvae loss: 0.037 9900 train loss: 0.094147 recon_error: 0.057 perplexity: 330.146 vqvae loss: 0.037 10000 train loss: 0.092876 recon_error: 0.056 perplexity: 326.349 vqvae loss: 0.037 CPU times: user 1h 47min 46s, sys: 14min 12s, total: 2h 1min 59s Wall time: 4min 29s
損失をプロットする
f = plt.figure(figsize=(16, 8))
# Left panel: normalised reconstruction error on a log scale.
# Right panel: codebook perplexity, on a linear scale.
for position, series, title, use_log_scale in (
        (1, train_recon_errors, 'NMSE.', True),
        (2, train_perplexities, 'Average codebook usage (perplexity).', False)):
    ax = f.add_subplot(1, 2, position)
    ax.plot(series)
    if use_log_scale:
        ax.set_yscale('log')
    ax.set_title(title)
Text(0.5, 1.0, 'Average codebook usage (perplexity).')

再構築を見る
# Reconstructions
# One batch from each dataset.
train_batch = next(iter(train_dataset))
valid_batch = next(iter(valid_dataset))

# Evaluate with is_training=False so that, when the EMA quantiser is used,
# this pass leaves the codebook untouched.
train_output = model(train_batch['image'], is_training=False)
train_reconstructions = train_output['x_recon'].numpy()
valid_output = model(valid_batch['image'], is_training=False)
valid_reconstructions = valid_output['x_recon'].numpy()
def convert_batch_to_image_grid(image_batch, rows=4, cols=8):
    """Tile a batch of images into a single ``rows`` x ``cols`` grid image.

    Args:
        image_batch: Array-like of shape ``(rows * cols, height, width,
            channels)`` holding images normalised to [-0.5, 0.5].
        rows: Number of grid rows (default 4).
        cols: Number of grid columns (default 8). The defaults tile the
            32-image batches used in this file.

    Returns:
        ``np.ndarray`` of shape ``(rows * height, cols * width, channels)``,
        shifted by +0.5 to undo the -0.5 normalisation offset. Image
        ``r * cols + c`` is placed in grid row ``r``, column ``c``.

    Raises:
        ValueError: If the batch size is not ``rows * cols``.
    """
    image_batch = np.asarray(image_batch)
    height, width, channels = image_batch.shape[-3:]
    # (rows, cols, H, W, C) -> (rows, H, cols, W, C) so that a row-major
    # reshape places the images side by side.
    grid = (image_batch.reshape(rows, cols, height, width, channels)
            .transpose(0, 2, 1, 3, 4)
            .reshape(rows * height, cols * width, channels))
    return grid + 0.5
f = plt.figure(figsize=(16, 8))
# 2x2 figure: training originals / reconstructions on top, validation
# originals / reconstructions below.
for position, batch_images, title in (
        (1, train_batch['image'].numpy(), 'training data originals'),
        (2, train_reconstructions, 'training data reconstructions'),
        (3, valid_batch['image'].numpy(), 'validation data originals'),
        (4, valid_reconstructions, 'validation data reconstructions')):
    ax = f.add_subplot(2, 2, position)
    # The decoder output is unbounded, so clip to imshow's valid float RGB
    # range explicitly instead of relying on matplotlib's warned-about
    # clipping.
    ax.imshow(np.clip(convert_batch_to_image_grid(batch_images), 0.0, 1.0),
              interpolation='nearest')
    ax.set_title(title)
    plt.axis('off')
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
(-0.5, 255.5, 127.5, -0.5)

以上