Keras 2 : examples : 生成深層学習 – Real NVPによる密度推定 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 07/18/2022 (keras 2.9.0)
* 本ページは、Keras の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
- Code examples : Generative Deep Learning : Density estimation using Real NVP (Author: Mandolini Giorgio Maria, Sanna Daniele, Zannini Quirini Giorgio)
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
クラスキャット 人工知能 研究開発支援サービス
◆ クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Web: www.classcat.com ; ClassCatJP
Keras 2 : examples : 生成深層学習 – Real NVPによる密度推定
Description : “double moon” データセットの密度分布の推定。
イントロダクション
このワークの目的は単純な分布 – これはサンプリングが容易でその密度は単純に推定できます – をデータから学習された、より複雑なものにマップすることです。この種類の生成モデルは “normalizing flow” としても知られています。
これを行なうため、「変数変換 (= change of variable)」公式を使用して、最尤原理によりモデルは訓練されます。
アフィン・カップリング関数を使用します。その逆もヤコビ行列の行列式とともに簡単に取得できるようにそれを作成します (詳細は参照論文で)。
要件 :
- Tensorflow 2.9.1
- Tensorflow probability 0.17.0
Reference :
セットアップ
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import regularizers
from sklearn.datasets import make_moons
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_probability as tfp
データのロード
data = make_moons(3000, noise=0.05)[0].astype("float32")
norm = layers.Normalization()
norm.adapt(data)
normalized_data = norm(data)
アフィン・カップリング層
# Creating a custom layer with keras API.
output_dim = 256
reg = 0.01
def Coupling(input_shape):
input = keras.layers.Input(shape=input_shape)
t_layer_1 = keras.layers.Dense(
output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
)(input)
t_layer_2 = keras.layers.Dense(
output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
)(t_layer_1)
t_layer_3 = keras.layers.Dense(
output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
)(t_layer_2)
t_layer_4 = keras.layers.Dense(
output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
)(t_layer_3)
t_layer_5 = keras.layers.Dense(
input_shape, activation="linear", kernel_regularizer=regularizers.l2(reg)
)(t_layer_4)
s_layer_1 = keras.layers.Dense(
output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
)(input)
s_layer_2 = keras.layers.Dense(
output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
)(s_layer_1)
s_layer_3 = keras.layers.Dense(
output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
)(s_layer_2)
s_layer_4 = keras.layers.Dense(
output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
)(s_layer_3)
s_layer_5 = keras.layers.Dense(
input_shape, activation="tanh", kernel_regularizer=regularizers.l2(reg)
)(s_layer_4)
return keras.Model(inputs=input, outputs=[s_layer_5, t_layer_5])
Real NVP
class RealNVP(keras.Model):
def __init__(self, num_coupling_layers):
super(RealNVP, self).__init__()
self.num_coupling_layers = num_coupling_layers
# Distribution of the latent space.
self.distribution = tfp.distributions.MultivariateNormalDiag(
loc=[0.0, 0.0], scale_diag=[1.0, 1.0]
)
self.masks = np.array(
[[0, 1], [1, 0]] * (num_coupling_layers // 2), dtype="float32"
)
self.loss_tracker = keras.metrics.Mean(name="loss")
self.layers_list = [Coupling(2) for i in range(num_coupling_layers)]
@property
def metrics(self):
"""List of the model's metrics.
We make sure the loss tracker is listed as part of `model.metrics`
so that `fit()` and `evaluate()` are able to `reset()` the loss tracker
at the start of each epoch and at the start of an `evaluate()` call.
"""
return [self.loss_tracker]
def call(self, x, training=True):
log_det_inv = 0
direction = 1
if training:
direction = -1
for i in range(self.num_coupling_layers)[::direction]:
x_masked = x * self.masks[i]
reversed_mask = 1 - self.masks[i]
s, t = self.layers_list[i](x_masked)
s *= reversed_mask
t *= reversed_mask
gate = (direction - 1) / 2
x = (
reversed_mask
* (x * tf.exp(direction * s) + direction * t * tf.exp(gate * s))
+ x_masked
)
log_det_inv += gate * tf.reduce_sum(s, [1])
return x, log_det_inv
# Log likelihood of the normal distribution plus the log determinant of the jacobian.
def log_loss(self, x):
y, logdet = self(x)
log_likelihood = self.distribution.log_prob(y) + logdet
return -tf.reduce_mean(log_likelihood)
def train_step(self, data):
with tf.GradientTape() as tape:
loss = self.log_loss(data)
g = tape.gradient(loss, self.trainable_variables)
self.optimizer.apply_gradients(zip(g, self.trainable_variables))
self.loss_tracker.update_state(loss)
return {"loss": self.loss_tracker.result()}
def test_step(self, data):
loss = self.log_loss(data)
self.loss_tracker.update_state(loss)
return {"loss": self.loss_tracker.result()}
モデル訓練
model = RealNVP(num_coupling_layers=6)
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.0001))
history = model.fit(
normalized_data, batch_size=256, epochs=300, verbose=2, validation_split=0.2
)
Epoch 1/300 10/10 - 2s - loss: 2.7104 - val_loss: 2.6385 - 2s/epoch - 248ms/step Epoch 2/300 10/10 - 0s - loss: 2.5951 - val_loss: 2.5818 - 162ms/epoch - 16ms/step Epoch 3/300 10/10 - 0s - loss: 2.5487 - val_loss: 2.5299 - 165ms/epoch - 17ms/step Epoch 4/300 10/10 - 0s - loss: 2.5081 - val_loss: 2.4989 - 150ms/epoch - 15ms/step Epoch 5/300 10/10 - 0s - loss: 2.4729 - val_loss: 2.4641 - 168ms/epoch - 17ms/step Epoch 6/300 10/10 - 0s - loss: 2.4457 - val_loss: 2.4443 - 155ms/epoch - 16ms/step Epoch 7/300 10/10 - 0s - loss: 2.4183 - val_loss: 2.4078 - 155ms/epoch - 16ms/step Epoch 8/300 10/10 - 0s - loss: 2.3840 - val_loss: 2.3852 - 160ms/epoch - 16ms/step Epoch 9/300 10/10 - 0s - loss: 2.3604 - val_loss: 2.3700 - 172ms/epoch - 17ms/step Epoch 10/300 10/10 - 0s - loss: 2.3392 - val_loss: 2.3354 - 156ms/epoch - 16ms/step Epoch 11/300 10/10 - 0s - loss: 2.3042 - val_loss: 2.3099 - 170ms/epoch - 17ms/step Epoch 12/300 10/10 - 0s - loss: 2.2769 - val_loss: 2.3126 - 171ms/epoch - 17ms/step Epoch 13/300 10/10 - 0s - loss: 2.2541 - val_loss: 2.2784 - 174ms/epoch - 17ms/step Epoch 14/300 10/10 - 0s - loss: 2.2175 - val_loss: 2.2354 - 174ms/epoch - 17ms/step Epoch 15/300 10/10 - 0s - loss: 2.1957 - val_loss: 2.1990 - 173ms/epoch - 17ms/step Epoch 16/300 10/10 - 0s - loss: 2.1533 - val_loss: 2.1686 - 167ms/epoch - 17ms/step Epoch 17/300 10/10 - 0s - loss: 2.1232 - val_loss: 2.1276 - 178ms/epoch - 18ms/step Epoch 18/300 10/10 - 0s - loss: 2.0932 - val_loss: 2.1220 - 173ms/epoch - 17ms/step Epoch 19/300 10/10 - 0s - loss: 2.1068 - val_loss: 2.1515 - 152ms/epoch - 15ms/step Epoch 20/300 10/10 - 0s - loss: 2.0793 - val_loss: 2.1661 - 161ms/epoch - 16ms/step Epoch 21/300 10/10 - 0s - loss: 2.0784 - val_loss: 2.0634 - 180ms/epoch - 18ms/step Epoch 22/300 10/10 - 0s - loss: 2.0060 - val_loss: 2.0076 - 160ms/epoch - 16ms/step Epoch 23/300 10/10 - 0s - loss: 1.9845 - val_loss: 1.9773 - 174ms/epoch - 17ms/step Epoch 24/300 10/10 - 0s - loss: 1.9462 - val_loss: 2.0097 - 174ms/epoch - 17ms/step Epoch 25/300 10/10 - 0s - loss: 1.8892 - val_loss: 1.9023 - 173ms/epoch - 17ms/step Epoch 26/300 10/10 - 0s - loss: 1.8011 - val_loss: 1.8128 - 182ms/epoch - 18ms/step Epoch 27/300 10/10 - 0s - loss: 1.7604 - val_loss: 1.8415 - 167ms/epoch - 17ms/step Epoch 28/300 10/10 - 0s - loss: 1.7474 - val_loss: 1.7635 - 172ms/epoch - 17ms/step Epoch 29/300 10/10 - 0s - loss: 1.7313 - val_loss: 1.7154 - 175ms/epoch - 18ms/step Epoch 30/300 10/10 - 0s - loss: 1.6801 - val_loss: 1.7269 - 183ms/epoch - 18ms/step Epoch 31/300 10/10 - 0s - loss: 1.6892 - val_loss: 1.6588 - 170ms/epoch - 17ms/step Epoch 32/300 10/10 - 0s - loss: 1.6384 - val_loss: 1.6467 - 159ms/epoch - 16ms/step Epoch 33/300 10/10 - 0s - loss: 1.6263 - val_loss: 1.6214 - 164ms/epoch - 16ms/step Epoch 34/300 10/10 - 0s - loss: 1.6035 - val_loss: 1.6022 - 154ms/epoch - 15ms/step Epoch 35/300 10/10 - 0s - loss: 1.5872 - val_loss: 1.6203 - 159ms/epoch - 16ms/step Epoch 36/300 10/10 - 0s - loss: 1.5928 - val_loss: 1.6312 - 159ms/epoch - 16ms/step Epoch 37/300 10/10 - 0s - loss: 1.5895 - val_loss: 1.6337 - 148ms/epoch - 15ms/step Epoch 38/300 10/10 - 0s - loss: 1.5726 - val_loss: 1.6192 - 153ms/epoch - 15ms/step Epoch 39/300 10/10 - 0s - loss: 1.5537 - val_loss: 1.5919 - 168ms/epoch - 17ms/step Epoch 40/300 10/10 - 0s - loss: 1.5741 - val_loss: 1.6646 - 173ms/epoch - 17ms/step Epoch 41/300 10/10 - 0s - loss: 1.5737 - val_loss: 1.5718 - 181ms/epoch - 18ms/step Epoch 42/300 10/10 - 0s - loss: 1.5573 - val_loss: 1.6395 - 173ms/epoch - 17ms/step Epoch 43/300 10/10 - 0s - loss: 1.5574 - val_loss: 1.5779 - 183ms/epoch - 18ms/step Epoch 44/300 10/10 - 0s - loss: 1.5345 - val_loss: 1.5549 - 173ms/epoch - 17ms/step Epoch 45/300 10/10 - 0s - loss: 1.5256 - val_loss: 1.5944 - 161ms/epoch - 16ms/step Epoch 46/300 10/10 - 0s - loss: 1.5291 - val_loss: 1.5325 - 169ms/epoch - 17ms/step Epoch 47/300 10/10 - 0s - loss: 1.5341 - val_loss: 1.5929 - 177ms/epoch - 18ms/step Epoch 48/300 10/10 - 0s - loss: 1.5190 - val_loss: 1.5563 - 174ms/epoch - 17ms/step Epoch 49/300 10/10 - 0s - loss: 1.5059 - val_loss: 1.5079 - 187ms/epoch - 19ms/step Epoch 50/300 10/10 - 0s - loss: 1.4971 - val_loss: 1.5163 - 177ms/epoch - 18ms/step Epoch 51/300 10/10 - 0s - loss: 1.4923 - val_loss: 1.5549 - 168ms/epoch - 17ms/step Epoch 52/300 10/10 - 0s - loss: 1.5345 - val_loss: 1.7131 - 171ms/epoch - 17ms/step Epoch 53/300 10/10 - 0s - loss: 1.5381 - val_loss: 1.5102 - 174ms/epoch - 17ms/step Epoch 54/300 10/10 - 0s - loss: 1.4955 - val_loss: 1.5432 - 167ms/epoch - 17ms/step Epoch 55/300 10/10 - 0s - loss: 1.4829 - val_loss: 1.5166 - 172ms/epoch - 17ms/step Epoch 56/300 10/10 - 0s - loss: 1.4672 - val_loss: 1.5297 - 180ms/epoch - 18ms/step Epoch 57/300 10/10 - 0s - loss: 1.4814 - val_loss: 1.5115 - 166ms/epoch - 17ms/step Epoch 58/300 10/10 - 0s - loss: 1.4738 - val_loss: 1.5143 - 165ms/epoch - 17ms/step Epoch 59/300 10/10 - 0s - loss: 1.4639 - val_loss: 1.5326 - 175ms/epoch - 17ms/step Epoch 60/300 10/10 - 0s - loss: 1.4727 - val_loss: 1.5419 - 175ms/epoch - 18ms/step Epoch 61/300 10/10 - 0s - loss: 1.4610 - val_loss: 1.4663 - 177ms/epoch - 18ms/step Epoch 62/300 10/10 - 0s - loss: 1.4512 - val_loss: 1.5624 - 179ms/epoch - 18ms/step Epoch 63/300 10/10 - 0s - loss: 1.4816 - val_loss: 1.5711 - 176ms/epoch - 18ms/step Epoch 64/300 10/10 - 0s - loss: 1.4735 - val_loss: 1.4988 - 181ms/epoch - 18ms/step Epoch 65/300 10/10 - 0s - loss: 1.4443 - val_loss: 1.4850 - 185ms/epoch - 19ms/step Epoch 66/300 10/10 - 0s - loss: 1.4441 - val_loss: 1.5275 - 179ms/epoch - 18ms/step Epoch 67/300 10/10 - 0s - loss: 1.4751 - val_loss: 1.5191 - 177ms/epoch - 18ms/step Epoch 68/300 10/10 - 0s - loss: 1.4874 - val_loss: 1.4888 - 175ms/epoch - 18ms/step Epoch 69/300 10/10 - 0s - loss: 1.4442 - val_loss: 1.5044 - 167ms/epoch - 17ms/step Epoch 70/300 10/10 - 0s - loss: 1.4645 - val_loss: 1.4801 - 174ms/epoch - 17ms/step Epoch 71/300 10/10 - 0s - loss: 1.4648 - val_loss: 1.5016 - 174ms/epoch - 17ms/step Epoch 72/300 10/10 - 0s - loss: 1.4336 - val_loss: 1.4970 - 171ms/epoch - 17ms/step Epoch 73/300 10/10 - 0s - loss: 1.4852 - val_loss: 1.4561 - 176ms/epoch - 18ms/step Epoch 74/300 10/10 - 0s - loss: 1.4656 - val_loss: 1.5156 - 167ms/epoch - 17ms/step Epoch 75/300 10/10 - 0s - loss: 1.4359 - val_loss: 1.4154 - 175ms/epoch - 17ms/step Epoch 76/300 10/10 - 0s - loss: 1.5187 - val_loss: 1.5395 - 182ms/epoch - 18ms/step Epoch 77/300 10/10 - 0s - loss: 1.5554 - val_loss: 1.5524 - 179ms/epoch - 18ms/step Epoch 78/300 10/10 - 0s - loss: 1.4679 - val_loss: 1.4742 - 175ms/epoch - 18ms/step Epoch 79/300 10/10 - 0s - loss: 1.4433 - val_loss: 1.5862 - 177ms/epoch - 18ms/step Epoch 80/300 10/10 - 0s - loss: 1.4775 - val_loss: 1.5030 - 189ms/epoch - 19ms/step Epoch 81/300 10/10 - 0s - loss: 1.4020 - val_loss: 1.5264 - 169ms/epoch - 17ms/step Epoch 82/300 10/10 - 0s - loss: 1.4298 - val_loss: 1.4841 - 170ms/epoch - 17ms/step Epoch 83/300 10/10 - 0s - loss: 1.4329 - val_loss: 1.3966 - 177ms/epoch - 18ms/step Epoch 84/300 10/10 - 0s - loss: 1.4106 - val_loss: 1.4472 - 183ms/epoch - 18ms/step Epoch 85/300 10/10 - 0s - loss: 1.3902 - val_loss: 1.5917 - 174ms/epoch - 17ms/step Epoch 86/300 10/10 - 0s - loss: 1.6401 - val_loss: 1.6188 - 181ms/epoch - 18ms/step Epoch 87/300 10/10 - 0s - loss: 1.5748 - val_loss: 1.5913 - 177ms/epoch - 18ms/step Epoch 88/300 10/10 - 0s - loss: 1.5449 - val_loss: 1.5437 - 185ms/epoch - 19ms/step Epoch 89/300 10/10 - 0s - loss: 1.4769 - val_loss: 1.5344 - 185ms/epoch - 19ms/step Epoch 90/300 10/10 - 0s - loss: 1.4805 - val_loss: 1.4814 - 173ms/epoch - 17ms/step Epoch 91/300 10/10 - 0s - loss: 1.4540 - val_loss: 1.5087 - 170ms/epoch - 17ms/step Epoch 92/300 10/10 - 0s - loss: 1.4266 - val_loss: 1.4554 - 169ms/epoch - 17ms/step Epoch 93/300 10/10 - 0s - loss: 1.4014 - val_loss: 1.4492 - 185ms/epoch - 19ms/step Epoch 94/300 10/10 - 0s - loss: 1.3701 - val_loss: 1.3875 - 182ms/epoch - 18ms/step Epoch 95/300 10/10 - 0s - loss: 1.3792 - val_loss: 1.4288 - 193ms/epoch - 19ms/step Epoch 96/300 10/10 - 0s - loss: 1.3813 - val_loss: 1.4452 - 180ms/epoch - 18ms/step Epoch 97/300 10/10 - 0s - loss: 1.3505 - val_loss: 1.3954 - 173ms/epoch - 17ms/step Epoch 98/300 10/10 - 0s - loss: 1.3870 - val_loss: 1.6328 - 178ms/epoch - 18ms/step Epoch 99/300 10/10 - 0s - loss: 1.5100 - val_loss: 1.5139 - 174ms/epoch - 17ms/step Epoch 100/300 10/10 - 0s - loss: 1.4355 - val_loss: 1.4654 - 176ms/epoch - 18ms/step Epoch 101/300 10/10 - 0s - loss: 1.3967 - val_loss: 1.4168 - 156ms/epoch - 16ms/step Epoch 102/300 10/10 - 0s - loss: 1.3466 - val_loss: 1.3765 - 164ms/epoch - 16ms/step Epoch 103/300 10/10 - 0s - loss: 1.3188 - val_loss: 1.3783 - 182ms/epoch - 18ms/step Epoch 104/300 10/10 - 0s - loss: 1.3659 - val_loss: 1.4572 - 190ms/epoch - 19ms/step Epoch 105/300 10/10 - 0s - loss: 1.6089 - val_loss: 1.6353 - 184ms/epoch - 18ms/step Epoch 106/300 10/10 - 0s - loss: 1.6317 - val_loss: 1.6007 - 171ms/epoch - 17ms/step Epoch 107/300 10/10 - 0s - loss: 1.5652 - val_loss: 1.5409 - 184ms/epoch - 18ms/step Epoch 108/300 10/10 - 0s - loss: 1.5273 - val_loss: 1.5030 - 165ms/epoch - 17ms/step Epoch 109/300 10/10 - 0s - loss: 1.4750 - val_loss: 1.4796 - 179ms/epoch - 18ms/step Epoch 110/300 10/10 - 0s - loss: 1.4710 - val_loss: 1.4996 - 175ms/epoch - 18ms/step Epoch 111/300 10/10 - 0s - loss: 1.4805 - val_loss: 1.5006 - 179ms/epoch - 18ms/step Epoch 112/300 10/10 - 0s - loss: 1.4518 - val_loss: 1.5023 - 184ms/epoch - 18ms/step Epoch 113/300 10/10 - 0s - loss: 1.4313 - val_loss: 1.4234 - 162ms/epoch - 16ms/step Epoch 114/300 10/10 - 0s - loss: 1.4113 - val_loss: 1.4629 - 178ms/epoch - 18ms/step Epoch 115/300 10/10 - 0s - loss: 1.3999 - val_loss: 1.4300 - 170ms/epoch - 17ms/step Epoch 116/300 10/10 - 0s - loss: 1.3886 - val_loss: 1.4042 - 179ms/epoch - 18ms/step Epoch 117/300 10/10 - 0s - loss: 1.3659 - val_loss: 1.4245 - 182ms/epoch - 18ms/step Epoch 118/300 10/10 - 0s - loss: 1.3605 - val_loss: 1.4482 - 162ms/epoch - 16ms/step Epoch 119/300 10/10 - 0s - loss: 1.4003 - val_loss: 1.4756 - 163ms/epoch - 16ms/step Epoch 120/300 10/10 - 0s - loss: 1.3749 - val_loss: 1.4237 - 189ms/epoch - 19ms/step Epoch 121/300 10/10 - 0s - loss: 1.3323 - val_loss: 1.3580 - 189ms/epoch - 19ms/step Epoch 122/300 10/10 - 0s - loss: 1.3464 - val_loss: 1.3684 - 187ms/epoch - 19ms/step Epoch 123/300 10/10 - 0s - loss: 1.3430 - val_loss: 1.3345 - 183ms/epoch - 18ms/step Epoch 124/300 10/10 - 0s - loss: 1.3402 - val_loss: 1.4077 - 183ms/epoch - 18ms/step Epoch 125/300 10/10 - 0s - loss: 1.3861 - val_loss: 1.4208 - 165ms/epoch - 17ms/step Epoch 126/300 10/10 - 0s - loss: 1.3665 - val_loss: 1.4796 - 163ms/epoch - 16ms/step Epoch 127/300 10/10 - 0s - loss: 1.3912 - val_loss: 1.4770 - 169ms/epoch - 17ms/step Epoch 128/300 10/10 - 0s - loss: 1.4114 - val_loss: 1.4261 - 166ms/epoch - 17ms/step Epoch 129/300 10/10 - 0s - loss: 1.3687 - val_loss: 1.4488 - 165ms/epoch - 17ms/step Epoch 130/300 10/10 - 0s - loss: 1.3576 - val_loss: 1.4333 - 173ms/epoch - 17ms/step Epoch 131/300 10/10 - 0s - loss: 1.3413 - val_loss: 1.4298 - 180ms/epoch - 18ms/step Epoch 132/300 10/10 - 0s - loss: 1.3331 - val_loss: 1.4388 - 190ms/epoch - 19ms/step Epoch 133/300 10/10 - 0s - loss: 1.5913 - val_loss: 1.5962 - 188ms/epoch - 19ms/step Epoch 134/300 10/10 - 0s - loss: 1.6076 - val_loss: 1.5921 - 179ms/epoch - 18ms/step Epoch 135/300 10/10 - 0s - loss: 1.5387 - val_loss: 1.5856 - 183ms/epoch - 18ms/step Epoch 136/300 10/10 - 0s - loss: 1.5088 - val_loss: 1.5209 - 159ms/epoch - 16ms/step Epoch 137/300 10/10 - 0s - loss: 1.4640 - val_loss: 1.4599 - 175ms/epoch - 18ms/step Epoch 138/300 10/10 - 0s - loss: 1.4140 - val_loss: 1.4659 - 177ms/epoch - 18ms/step Epoch 139/300 10/10 - 0s - loss: 1.4138 - val_loss: 1.4327 - 179ms/epoch - 18ms/step Epoch 140/300 10/10 - 0s - loss: 1.3911 - val_loss: 1.4366 - 178ms/epoch - 18ms/step Epoch 141/300 10/10 - 0s - loss: 1.3870 - val_loss: 1.3962 - 182ms/epoch - 18ms/step Epoch 142/300 10/10 - 0s - loss: 1.3699 - val_loss: 1.4742 - 154ms/epoch - 15ms/step Epoch 143/300 10/10 - 0s - loss: 1.3630 - val_loss: 1.3963 - 158ms/epoch - 16ms/step Epoch 144/300 10/10 - 0s - loss: 1.3818 - val_loss: 1.4538 - 184ms/epoch - 18ms/step Epoch 145/300 10/10 - 0s - loss: 1.3564 - val_loss: 1.4057 - 182ms/epoch - 18ms/step Epoch 146/300 10/10 - 0s - loss: 1.3353 - val_loss: 1.4064 - 186ms/epoch - 19ms/step Epoch 147/300 10/10 - 0s - loss: 1.3285 - val_loss: 1.3835 - 172ms/epoch - 17ms/step Epoch 148/300 10/10 - 0s - loss: 1.3100 - val_loss: 1.3817 - 188ms/epoch - 19ms/step Epoch 149/300 10/10 - 0s - loss: 1.3761 - val_loss: 1.4566 - 189ms/epoch - 19ms/step Epoch 150/300 10/10 - 0s - loss: 1.3473 - val_loss: 1.4378 - 188ms/epoch - 19ms/step Epoch 151/300 10/10 - 0s - loss: 1.3106 - val_loss: 1.3616 - 182ms/epoch - 18ms/step Epoch 152/300 10/10 - 0s - loss: 1.3239 - val_loss: 1.3468 - 177ms/epoch - 18ms/step Epoch 153/300 10/10 - 0s - loss: 1.2947 - val_loss: 1.3523 - 172ms/epoch - 17ms/step Epoch 154/300 10/10 - 0s - loss: 1.2850 - val_loss: 1.3530 - 170ms/epoch - 17ms/step Epoch 155/300 10/10 - 0s - loss: 1.2834 - val_loss: 1.3878 - 171ms/epoch - 17ms/step Epoch 156/300 10/10 - 0s - loss: 1.3192 - val_loss: 1.3747 - 179ms/epoch - 18ms/step Epoch 157/300 10/10 - 0s - loss: 1.3567 - val_loss: 1.4031 - 174ms/epoch - 17ms/step Epoch 158/300 10/10 - 0s - loss: 1.3240 - val_loss: 1.3735 - 167ms/epoch - 17ms/step Epoch 159/300 10/10 - 0s - loss: 1.3272 - val_loss: 1.4563 - 183ms/epoch - 18ms/step Epoch 160/300 10/10 - 0s - loss: 1.3329 - val_loss: 1.3321 - 179ms/epoch - 18ms/step Epoch 161/300 10/10 - 0s - loss: 1.3120 - val_loss: 1.3779 - 185ms/epoch - 19ms/step Epoch 162/300 10/10 - 0s - loss: 1.3093 - val_loss: 1.3739 - 191ms/epoch - 19ms/step Epoch 163/300 10/10 - 0s - loss: 1.3251 - val_loss: 1.4781 - 182ms/epoch - 18ms/step Epoch 164/300 10/10 - 0s - loss: 1.3762 - val_loss: 1.4035 - 165ms/epoch - 17ms/step Epoch 165/300 10/10 - 0s - loss: 1.3655 - val_loss: 1.3693 - 189ms/epoch - 19ms/step Epoch 166/300 10/10 - 0s - loss: 1.3453 - val_loss: 1.3694 - 170ms/epoch - 17ms/step Epoch 167/300 10/10 - 0s - loss: 1.3019 - val_loss: 1.3496 - 180ms/epoch - 18ms/step Epoch 168/300 10/10 - 0s - loss: 1.2801 - val_loss: 1.3375 - 190ms/epoch - 19ms/step Epoch 169/300 10/10 - 0s - loss: 1.2966 - val_loss: 1.3712 - 178ms/epoch - 18ms/step Epoch 170/300 10/10 - 0s - loss: 1.3060 - val_loss: 1.3237 - 177ms/epoch - 18ms/step Epoch 171/300 10/10 - 0s - loss: 1.3299 - val_loss: 1.5022 - 177ms/epoch - 18ms/step Epoch 172/300 10/10 - 0s - loss: 1.3665 - val_loss: 1.4224 - 186ms/epoch - 19ms/step Epoch 173/300 10/10 - 0s - loss: 1.3432 - val_loss: 1.5198 - 172ms/epoch - 17ms/step Epoch 174/300 10/10 - 0s - loss: 1.3434 - val_loss: 1.4113 - 188ms/epoch - 19ms/step Epoch 175/300 10/10 - 0s - loss: 1.3016 - val_loss: 1.3920 - 175ms/epoch - 18ms/step Epoch 176/300 10/10 - 0s - loss: 1.2833 - val_loss: 1.4342 - 166ms/epoch - 17ms/step Epoch 177/300 10/10 - 0s - loss: 1.3334 - val_loss: 1.4225 - 178ms/epoch - 18ms/step Epoch 178/300 10/10 - 0s - loss: 1.4085 - val_loss: 1.4848 - 170ms/epoch - 17ms/step Epoch 179/300 10/10 - 0s - loss: 1.4262 - val_loss: 1.5149 - 176ms/epoch - 18ms/step Epoch 180/300 10/10 - 0s - loss: 1.4076 - val_loss: 1.5736 - 175ms/epoch - 18ms/step Epoch 181/300 10/10 - 0s - loss: 1.5085 - val_loss: 1.6339 - 179ms/epoch - 18ms/step Epoch 182/300 10/10 - 0s - loss: 1.5028 - val_loss: 1.5327 - 179ms/epoch - 18ms/step Epoch 183/300 10/10 - 0s - loss: 1.4710 - val_loss: 1.4611 - 196ms/epoch - 20ms/step Epoch 184/300 10/10 - 0s - loss: 1.3950 - val_loss: 1.4205 - 190ms/epoch - 19ms/step Epoch 185/300 10/10 - 0s - loss: 1.3815 - val_loss: 1.4100 - 187ms/epoch - 19ms/step Epoch 186/300 10/10 - 0s - loss: 1.3722 - val_loss: 1.3939 - 163ms/epoch - 16ms/step Epoch 187/300 10/10 - 0s - loss: 1.3379 - val_loss: 1.3922 - 194ms/epoch - 19ms/step Epoch 188/300 10/10 - 0s - loss: 1.3406 - val_loss: 1.3874 - 189ms/epoch - 19ms/step Epoch 189/300 10/10 - 0s - loss: 1.4787 - val_loss: 1.5603 - 190ms/epoch - 19ms/step Epoch 190/300 10/10 - 0s - loss: 1.4652 - val_loss: 1.4490 - 163ms/epoch - 16ms/step Epoch 191/300 10/10 - 0s - loss: 1.3868 - val_loss: 1.4725 - 179ms/epoch - 18ms/step Epoch 192/300 10/10 - 0s - loss: 1.3470 - val_loss: 1.4088 - 186ms/epoch - 19ms/step Epoch 193/300 10/10 - 0s - loss: 1.3576 - val_loss: 1.3549 - 193ms/epoch - 19ms/step Epoch 194/300 10/10 - 0s - loss: 1.3574 - val_loss: 1.4884 - 188ms/epoch - 19ms/step Epoch 195/300 10/10 - 0s - loss: 1.4376 - val_loss: 1.4794 - 172ms/epoch - 17ms/step Epoch 196/300 10/10 - 0s - loss: 1.4110 - val_loss: 1.5064 - 175ms/epoch - 18ms/step Epoch 197/300 10/10 - 0s - loss: 1.3597 - val_loss: 1.3742 - 159ms/epoch - 16ms/step Epoch 198/300 10/10 - 0s - loss: 1.3897 - val_loss: 1.4465 - 188ms/epoch - 19ms/step Epoch 199/300 10/10 - 0s - loss: 1.3710 - val_loss: 1.3469 - 175ms/epoch - 18ms/step Epoch 200/300 10/10 - 0s - loss: 1.3613 - val_loss: 1.4129 - 183ms/epoch - 18ms/step Epoch 201/300 10/10 - 0s - loss: 1.3581 - val_loss: 1.4100 - 189ms/epoch - 19ms/step Epoch 202/300 10/10 - 0s - loss: 1.3047 - val_loss: 1.3460 - 164ms/epoch - 16ms/step Epoch 203/300 10/10 - 0s - loss: 1.3133 - val_loss: 1.3942 - 185ms/epoch - 18ms/step Epoch 204/300 10/10 - 0s - loss: 1.3880 - val_loss: 1.4730 - 179ms/epoch - 18ms/step Epoch 205/300 10/10 - 0s - loss: 1.4233 - val_loss: 1.5020 - 196ms/epoch - 20ms/step Epoch 206/300 10/10 - 0s - loss: 1.3696 - val_loss: 1.4541 - 188ms/epoch - 19ms/step Epoch 207/300 10/10 - 0s - loss: 1.3189 - val_loss: 1.4825 - 181ms/epoch - 18ms/step Epoch 208/300 10/10 - 0s - loss: 1.7335 - val_loss: 1.7628 - 170ms/epoch - 17ms/step Epoch 209/300 10/10 - 0s - loss: 1.6927 - val_loss: 1.6906 - 180ms/epoch - 18ms/step Epoch 210/300 10/10 - 0s - loss: 1.6293 - val_loss: 1.6065 - 191ms/epoch - 19ms/step Epoch 211/300 10/10 - 0s - loss: 1.5564 - val_loss: 1.5873 - 179ms/epoch - 18ms/step Epoch 212/300 10/10 - 0s - loss: 1.5258 - val_loss: 1.5561 - 188ms/epoch - 19ms/step Epoch 213/300 10/10 - 0s - loss: 1.4918 - val_loss: 1.5715 - 175ms/epoch - 17ms/step Epoch 214/300 10/10 - 0s - loss: 1.4800 - val_loss: 1.5373 - 166ms/epoch - 17ms/step Epoch 215/300 10/10 - 0s - loss: 1.4669 - val_loss: 1.5262 - 171ms/epoch - 17ms/step Epoch 216/300 10/10 - 0s - loss: 1.4492 - val_loss: 1.4965 - 168ms/epoch - 17ms/step Epoch 217/300 10/10 - 0s - loss: 1.4169 - val_loss: 1.4874 - 160ms/epoch - 16ms/step Epoch 218/300 10/10 - 0s - loss: 1.4192 - val_loss: 1.4848 - 175ms/epoch - 18ms/step Epoch 219/300 10/10 - 0s - loss: 1.4072 - val_loss: 1.4776 - 167ms/epoch - 17ms/step Epoch 220/300 10/10 - 0s - loss: 1.3936 - val_loss: 1.4824 - 163ms/epoch - 16ms/step Epoch 221/300 10/10 - 0s - loss: 1.3813 - val_loss: 1.4814 - 190ms/epoch - 19ms/step Epoch 222/300 10/10 - 0s - loss: 1.3821 - val_loss: 1.4344 - 192ms/epoch - 19ms/step Epoch 223/300 10/10 - 0s - loss: 1.3724 - val_loss: 1.4691 - 197ms/epoch - 20ms/step Epoch 224/300 10/10 - 0s - loss: 1.3818 - val_loss: 1.4371 - 186ms/epoch - 19ms/step Epoch 225/300 10/10 - 0s - loss: 1.3986 - val_loss: 1.4602 - 174ms/epoch - 17ms/step Epoch 226/300 10/10 - 0s - loss: 1.3620 - val_loss: 1.4268 - 162ms/epoch - 16ms/step Epoch 227/300 10/10 - 0s - loss: 1.3658 - val_loss: 1.5127 - 162ms/epoch - 16ms/step Epoch 228/300 10/10 - 0s - loss: 1.3994 - val_loss: 1.4251 - 182ms/epoch - 18ms/step Epoch 229/300 10/10 - 0s - loss: 1.3674 - val_loss: 1.4542 - 181ms/epoch - 18ms/step Epoch 230/300 10/10 - 0s - loss: 1.3453 - val_loss: 1.4165 - 178ms/epoch - 18ms/step Epoch 231/300 10/10 - 0s - loss: 1.3473 - val_loss: 1.4112 - 185ms/epoch - 19ms/step Epoch 232/300 10/10 - 0s - loss: 1.3373 - val_loss: 1.3559 - 193ms/epoch - 19ms/step Epoch 233/300 10/10 - 0s - loss: 1.3267 - val_loss: 1.4230 - 185ms/epoch - 19ms/step Epoch 234/300 10/10 - 0s - loss: 1.4402 - val_loss: 1.5016 - 194ms/epoch - 19ms/step Epoch 235/300 10/10 - 0s - loss: 1.4497 - val_loss: 1.5198 - 182ms/epoch - 18ms/step Epoch 236/300 10/10 - 0s - loss: 1.3724 - val_loss: 1.4116 - 174ms/epoch - 17ms/step Epoch 237/300 10/10 - 0s - loss: 1.3275 - val_loss: 1.4120 - 190ms/epoch - 19ms/step Epoch 238/300 10/10 - 0s - loss: 1.4089 - val_loss: 1.4978 - 180ms/epoch - 18ms/step Epoch 239/300 10/10 - 0s - loss: 1.4203 - val_loss: 1.4340 - 197ms/epoch - 20ms/step Epoch 240/300 10/10 - 0s - loss: 1.4002 - val_loss: 1.4535 - 181ms/epoch - 18ms/step Epoch 241/300 10/10 - 0s - loss: 1.3915 - val_loss: 1.4112 - 179ms/epoch - 18ms/step Epoch 242/300 10/10 - 0s - loss: 1.4050 - val_loss: 1.4437 - 173ms/epoch - 17ms/step Epoch 243/300 10/10 - 0s - loss: 1.3834 - val_loss: 1.3841 - 183ms/epoch - 18ms/step Epoch 244/300 10/10 - 0s - loss: 1.3550 - val_loss: 1.4028 - 185ms/epoch - 19ms/step Epoch 245/300 10/10 - 0s - loss: 1.3415 - val_loss: 1.4119 - 200ms/epoch - 20ms/step Epoch 246/300 10/10 - 0s - loss: 1.3579 - val_loss: 1.4416 - 188ms/epoch - 19ms/step Epoch 247/300 10/10 - 0s - loss: 1.3397 - val_loss: 1.4257 - 173ms/epoch - 17ms/step Epoch 248/300 10/10 - 0s - loss: 1.3353 - val_loss: 1.3809 - 188ms/epoch - 19ms/step Epoch 249/300 10/10 - 0s - loss: 1.3211 - val_loss: 1.3619 - 169ms/epoch - 17ms/step Epoch 250/300 10/10 - 0s - loss: 1.3052 - val_loss: 1.3735 - 168ms/epoch - 17ms/step Epoch 251/300 10/10 - 0s - loss: 1.3121 - val_loss: 1.3636 - 183ms/epoch - 18ms/step Epoch 252/300 10/10 - 0s - loss: 1.3121 - val_loss: 1.3741 - 177ms/epoch - 18ms/step Epoch 253/300 10/10 - 0s - loss: 1.3108 - val_loss: 1.3680 - 168ms/epoch - 17ms/step Epoch 254/300 10/10 - 0s - loss: 1.3188 - val_loss: 1.4326 - 184ms/epoch - 18ms/step Epoch 255/300 10/10 - 0s - loss: 1.3111 - val_loss: 1.3853 - 183ms/epoch - 18ms/step Epoch 256/300 10/10 - 0s - loss: 1.3036 - val_loss: 1.4108 - 195ms/epoch - 20ms/step Epoch 257/300 10/10 - 0s - loss: 1.2867 - val_loss: 1.3785 - 183ms/epoch - 18ms/step Epoch 258/300 10/10 - 0s - loss: 1.2768 - val_loss: 1.3614 - 165ms/epoch - 17ms/step Epoch 259/300 10/10 - 0s - loss: 1.3092 - val_loss: 1.3846 - 176ms/epoch - 18ms/step Epoch 260/300 10/10 - 0s - loss: 1.2845 - val_loss: 1.3970 - 169ms/epoch - 17ms/step Epoch 261/300 10/10 - 0s - loss: 1.3381 - val_loss: 1.3931 - 175ms/epoch - 18ms/step Epoch 262/300 10/10 - 0s - loss: 1.3067 - val_loss: 1.3953 - 176ms/epoch - 18ms/step Epoch 263/300 10/10 - 0s - loss: 1.2947 - val_loss: 1.3783 - 170ms/epoch - 17ms/step Epoch 264/300 10/10 - 0s - loss: 1.2947 - val_loss: 1.3805 - 187ms/epoch - 19ms/step Epoch 265/300 10/10 - 0s - loss: 1.3187 - val_loss: 1.3418 - 187ms/epoch - 19ms/step Epoch 266/300 10/10 - 0s - loss: 1.2830 - val_loss: 1.4077 - 197ms/epoch - 20ms/step Epoch 267/300 10/10 - 0s - loss: 1.3008 - val_loss: 1.3461 - 198ms/epoch - 20ms/step Epoch 268/300 10/10 - 0s - loss: 1.3230 - val_loss: 1.3495 - 183ms/epoch - 18ms/step Epoch 269/300 10/10 - 0s - loss: 1.3171 - val_loss: 1.3547 - 182ms/epoch - 18ms/step Epoch 270/300 10/10 - 0s - loss: 1.3216 - val_loss: 1.4041 - 191ms/epoch - 19ms/step Epoch 271/300 10/10 - 0s - loss: 1.3147 - val_loss: 1.4394 - 182ms/epoch - 18ms/step Epoch 272/300 10/10 - 0s - loss: 1.3062 - val_loss: 1.4410 - 196ms/epoch - 20ms/step Epoch 273/300 10/10 - 0s - loss: 1.3154 - val_loss: 1.4076 - 166ms/epoch - 17ms/step Epoch 274/300 10/10 - 0s - loss: 1.2999 - val_loss: 1.3703 - 161ms/epoch - 16ms/step Epoch 275/300 10/10 - 0s - loss: 1.2730 - val_loss: 1.3523 - 179ms/epoch - 18ms/step Epoch 276/300 10/10 - 0s - loss: 1.2773 - val_loss: 1.3488 - 188ms/epoch - 19ms/step Epoch 277/300 10/10 - 0s - loss: 1.3017 - val_loss: 1.3812 - 184ms/epoch - 18ms/step Epoch 278/300 10/10 - 0s - loss: 1.2857 - val_loss: 1.4040 - 184ms/epoch - 18ms/step Epoch 279/300 10/10 - 0s - loss: 1.3243 - val_loss: 1.3774 - 181ms/epoch - 18ms/step Epoch 280/300 10/10 - 0s - loss: 1.3258 - val_loss: 1.4166 - 161ms/epoch - 16ms/step Epoch 281/300 10/10 - 0s - loss: 1.3004 - val_loss: 1.3956 - 179ms/epoch - 18ms/step Epoch 282/300 10/10 - 0s - loss: 1.3407 - val_loss: 1.3529 - 182ms/epoch - 18ms/step Epoch 283/300 10/10 - 0s - loss: 1.3269 - val_loss: 1.3986 - 183ms/epoch - 18ms/step Epoch 284/300 10/10 - 0s - loss: 1.3138 - val_loss: 1.4302 - 187ms/epoch - 19ms/step Epoch 285/300 10/10 - 0s - loss: 1.2999 - val_loss: 1.3942 - 167ms/epoch - 17ms/step Epoch 286/300 10/10 - 0s - loss: 1.2871 - val_loss: 1.4190 - 161ms/epoch - 16ms/step Epoch 287/300 10/10 - 0s - loss: 1.3094 - val_loss: 1.3905 - 176ms/epoch - 18ms/step Epoch 288/300 10/10 - 0s - loss: 1.3072 - val_loss: 1.3681 - 168ms/epoch - 17ms/step Epoch 289/300 10/10 - 0s - loss: 1.2890 - val_loss: 1.3863 - 190ms/epoch - 19ms/step Epoch 290/300 10/10 - 0s - loss: 1.2861 - val_loss: 1.4039 - 183ms/epoch - 18ms/step Epoch 291/300 10/10 - 0s - loss: 1.2845 - val_loss: 1.4018 - 162ms/epoch - 16ms/step Epoch 292/300 10/10 - 0s - loss: 1.2747 - val_loss: 1.4085 - 184ms/epoch - 18ms/step Epoch 293/300 10/10 - 0s - loss: 1.2728 - val_loss: 1.3846 - 185ms/epoch - 19ms/step Epoch 294/300 10/10 - 0s - loss: 1.2567 - val_loss: 1.3465 - 180ms/epoch - 18ms/step Epoch 295/300 10/10 - 0s - loss: 1.2643 - val_loss: 1.3914 - 195ms/epoch - 20ms/step Epoch 296/300 10/10 - 0s - loss: 1.2747 - val_loss: 1.4068 - 182ms/epoch - 18ms/step Epoch 297/300 10/10 - 0s - loss: 1.3311 - val_loss: 1.5587 - 169ms/epoch - 17ms/step Epoch 298/300 10/10 - 0s - loss: 1.3347 - val_loss: 1.4132 - 181ms/epoch - 18ms/step Epoch 299/300 10/10 - 0s - loss: 1.3485 - val_loss: 1.4953 - 200ms/epoch - 20ms/step Epoch 300/300 10/10 - 0s - loss: 1.3156 - val_loss: 1.4378 - 203ms/epoch - 20ms/step
性能評価
plt.figure(figsize=(15, 10))
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("model loss")
plt.legend(["train", "validation"], loc="upper right")
plt.ylabel("loss")
plt.xlabel("epoch")
# From data to latent space.
z, _ = model(normalized_data)
# From latent space to data.
samples = model.distribution.sample(3000)
x, _ = model.predict(samples)
f, axes = plt.subplots(2, 2)
f.set_size_inches(20, 15)
axes[0, 0].scatter(normalized_data[:, 0], normalized_data[:, 1], color="r")
axes[0, 0].set(title="Inference data space X", xlabel="x", ylabel="y")
axes[0, 1].scatter(z[:, 0], z[:, 1], color="r")
axes[0, 1].set(title="Inference latent space Z", xlabel="x", ylabel="y")
axes[0, 1].set_xlim([-3.5, 4])
axes[0, 1].set_ylim([-4, 4])
axes[1, 0].scatter(samples[:, 0], samples[:, 1], color="g")
axes[1, 0].set(title="Generated latent space Z", xlabel="x", ylabel="y")
axes[1, 1].scatter(x[:, 0], x[:, 1], color="g")
axes[1, 1].set(title="Generated data space X", label="x", ylabel="y")
axes[1, 1].set_xlim([-2, 2])
axes[1, 1].set_ylim([-2, 2])
94/94 [==============================] - 0s 2ms/step
(-2.0, 2.0)
以上