Keras 2 : examples : 生成深層学習 – Stable Diffusion の潜在的空間の探索 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 12/24/2022 (keras 2.11.0)
* 本ページは、Keras の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
- Code examples : Generative Deep Learning : A walk through latent space with Stable Diffusion (Author: Ian Stenbit, fchollet, lukewood : Created : 2022/09/28)
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Web: www.classcat.com ; ClassCatJP
Keras 2 : examples : 生成深層学習 – Stable Diffusion の潜在的空間の探索
Description : Stable Diffusion の潜在的多様体 (= manifold) を探求する。
概要
画像生成モデルは視覚世界の「潜在的多様体 (latent manifold)」: 各点が画像にマップされる低次元ベクトル空間を学習します。そのような多様体上の点から表示可能な画像を得ることを「デコード」と呼びます — Stable Diffusion モデルでは、これは「デコーダ」モデルにより処理されます。
画像の潜在的多様体は連続的で補間的です、これはつまり :
- 多様体上の少しの移動は対応する画像を少し変更するだけです (連続性)。
- 多様体上の任意の 2 点 A と B に対して (i.e. 任意の 2 つの画像)、各中間点も多様体上にある (i.e これもまた正当な画像) ようなパスを経由して A から B に移動することが可能です。中間点は 2 つの開始画像の間の「補間 (interpolations)」と呼ばれます。
けれども Stable Diffusion は単なる画像モデルではありません、それはまた自然言語モデルでもあります。それは 2 つの潜在的空間を持ちます : 訓練の間に使用されたエンコーダで学習された 画像表現空間と、事前訓練と訓練時再調整の組み合わせを使用して学習される プロンプト潜在的空間 です。
潜在的空間ウォーキング、あるいは 潜在的空間探索、は潜在的空間の点をサンプリングして潜在的表現を段階的に変化させる過程です。最も一般的な応用はアニメーションを生成するもので、そこではサンプリングされた各点がデコーダに供給されて最終的なアニメーションのフレームとしてストアされます。高品質な潜在的表現については、これは一貫性があるようなニメーションを生成します。これらのアニメーションは潜在的空間の特徴マップへの洞察を提供し、訓練過程の改良に最終的にはつながる可能性があります。そのような GIF の一つが下に表示されます :
このガイドでは、Stable Diffusion の視覚的潜在的多様体とテキストエンコーダの潜在的多様体を通して、プロンプト補間と循環 (circular) ウォークを遂行するために KerasCV の Stable Diffusion API を利用する方法を示します。
このガイドは読者が Stable Diffusion の高いレベルの理解を持っていることを仮定しています。まだ持っていない場合は、Stable Diffusion チュートリアル を読むことから始めるべきです。
はじめに、KerasCV をインポートしてチュートリアル Generate images with Stable Diffusion で説明した最適化を使用して Stable Diffusion モデルをロードします。M1 Mac GPU で実行している場合には混合精度を有効にするべきではないことに注意してください。
!pip install keras-cv --upgrade --quiet
import keras_cv
from tensorflow import keras
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import math
from PIL import Image
# Enable mixed precision
# (only do this if you have a recent NVIDIA GPU)
keras.mixed_precision.set_global_policy("mixed_float16")
# Instantiate the Stable Diffusion model
model = keras_cv.models.StableDiffusion(jit_compile=True)
INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK Your GPUs will likely run quickly with dtype policy mixed_float16 as they all have compute capability of at least 7.0 By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE
テキストプロンプト間の補間
Stable Diffusion では、テキストプロンプトは最初にベクトルにエンコードされて、そのエンコーディングは拡散過程をガイドするために使用されます。潜在的エンコーディング・ベクトルは shape 77×768 (that’s huge!) を持ち、そして Stable Diffusion にテキストプロンプトを与えるとき、潜在的多様体上のそのような 1 点だけから画像を生成しています。
この多様体のより多くを探求するために、2 つのテキストエンコーディング間を補間してそれらの補間された点で画像を生成することができます :
prompt_1 = "A watercolor painting of a Golden Retriever at the beach"
prompt_2 = "A still life DSLR photo of a bowl of fruit"
interpolation_steps = 5
encoding_1 = tf.squeeze(model.encode_text(prompt_1))
encoding_2 = tf.squeeze(model.encode_text(prompt_2))
interpolated_encodings = tf.linspace(encoding_1, encoding_2, interpolation_steps)
# Show the size of the latent manifold
print(f"Encoding shape: {encoding_1.shape}")
Encoding shape: (77, 768)
エンコーディングを補間したら、各点から画像を生成することができます。結果としての画像の間である程度の安定性を保持するために、画像間の拡散ノイズを一定に維持します。
seed = 12345
noise = tf.random.normal((512 // 8, 512 // 8, 4), seed=seed)
images = model.generate_image(
interpolated_encodings,
batch_size=interpolation_steps,
diffusion_noise=noise,
)
25/25 [==============================] - 50s 340ms/step
幾つかの補間された画像を生成したので、それらを見てみましょう!
このチュートリアルを通して、一連の画像を gif としてエクスポートしていきますので、それらはある時間的なコンテキストにより簡単に見ることができます。最初と最後の画像が概念的に一致しない画像のシークエンスについては、gif を束ねます (rubber-band)。
Colab で実行している場合、以下を実行することで貴方自身の GIF を見ることができます :
from IPython.display import Image as IImage
IImage("doggo-and-fruit-5.gif")
def export_as_gif(filename, images, frames_per_second=10, rubber_band=False):
if rubber_band:
images += images[2:-1][::-1]
images[0].save(
filename,
save_all=True,
append_images=images[1:],
duration=1000 // frames_per_second,
loop=0,
)
export_as_gif(
"doggo-and-fruit-5.gif",
[Image.fromarray(img) for img in images],
frames_per_second=2,
rubber_band=True,
)
この結果は意外に見えるかもしれません。一般にプロンプト間の補間は一貫した外観の画像を生成し、2 つのプロンプトのコンテンツ間の漸進的なコンセプトシフトを示すことが多いです。これは高品質な表現空間であることを暗示していて、視覚世界の自然な構造を密接に反映しています。
これをベストに可視化するためには、数百ステップを使用した、遥かに極め細かい補間を行なう必要があります。(GPU を OOM させないように) バッチサイズを小さく保持するため、これは補間されたエンコーディングを手動でバッチ処理する必要があります。
interpolation_steps = 150
batch_size = 3
batches = interpolation_steps // batch_size
interpolated_encodings = tf.linspace(encoding_1, encoding_2, interpolation_steps)
batched_encodings = tf.split(interpolated_encodings, batches)
images = []
for batch in range(batches):
images += [
Image.fromarray(img)
for img in model.generate_image(
batched_encodings[batch],
batch_size=batch_size,
num_steps=25,
diffusion_noise=noise,
)
]
export_as_gif("doggo-and-fruit-150.gif", images, rubber_band=True)
25/25 [==============================] - 49s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 245ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 244ms/step
結果としての gif は 2 つのプロンプト間の遥かに明瞭で一貫したシフトを示します。幾つかの貴方自身のプロンプトで試して実験してください!
この概念を 1 つより多い画像に対して拡張することさえできます。例えば、4 つのプロンプト間で補間できます。
prompt_1 = "A watercolor painting of a Golden Retriever at the beach"
prompt_2 = "A still life DSLR photo of a bowl of fruit"
prompt_3 = "The eiffel tower in the style of starry night"
prompt_4 = "An architectural sketch of a skyscraper"
interpolation_steps = 6
batch_size = 3
batches = (interpolation_steps**2) // batch_size
encoding_1 = tf.squeeze(model.encode_text(prompt_1))
encoding_2 = tf.squeeze(model.encode_text(prompt_2))
encoding_3 = tf.squeeze(model.encode_text(prompt_3))
encoding_4 = tf.squeeze(model.encode_text(prompt_4))
interpolated_encodings = tf.linspace(
tf.linspace(encoding_1, encoding_2, interpolation_steps),
tf.linspace(encoding_3, encoding_4, interpolation_steps),
interpolation_steps,
)
interpolated_encodings = tf.reshape(
interpolated_encodings, (interpolation_steps**2, 77, 768)
)
batched_encodings = tf.split(interpolated_encodings, batches)
images = []
for batch in range(batches):
images.append(
model.generate_image(
batched_encodings[batch],
batch_size=batch_size,
diffusion_noise=noise,
)
)
def plot_grid(
images,
path,
grid_size,
scale=2,
):
fig = plt.figure(figsize=(grid_size * scale, grid_size * scale))
fig.tight_layout()
plt.subplots_adjust(wspace=0, hspace=0)
plt.margins(x=0, y=0)
plt.axis("off")
images = images.astype(int)
for row in range(grid_size):
for col in range(grid_size):
index = row * grid_size + col
plt.subplot(grid_size, grid_size, index + 1)
plt.imshow(images[index].astype("uint8"))
plt.axis("off")
plt.margins(x=0, y=0)
plt.savefig(
fname=path,
pad_inches=0,
bbox_inches="tight",
transparent=False,
dpi=60,
)
images = np.concatenate(images)
plot_grid(images, "4-way-interpolation.jpg", interpolation_steps)
25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 245ms/step 25/25 [==============================] - 6s 245ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step
diffusion_noise パラメータをドロップすることで拡散ノイズが変化することを可能にしながら補間することもできます :
images = []
for batch in range(batches):
images.append(model.generate_image(batched_encodings[batch], batch_size=batch_size))
images = np.concatenate(images)
plot_grid(images, "4-way-interpolation-varying-noise.jpg", interpolation_steps)
25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step
Next up — let’s go for some walks!
テキストプロンプト周りの散策
次の実験は特定のプロンプトから生成された点からはじめる潜在的多様体周りへの散策です。
walk_steps = 150
batch_size = 3
batches = walk_steps // batch_size
step_size = 0.005
encoding = tf.squeeze(
model.encode_text("The Eiffel Tower in the style of starry night")
)
# Note that (77, 768) is the shape of the text encoding.
delta = tf.ones_like(encoding) * step_size
walked_encodings = []
for step_index in range(walk_steps):
walked_encodings.append(encoding)
encoding += delta
walked_encodings = tf.stack(walked_encodings)
batched_encodings = tf.split(walked_encodings, batches)
images = []
for batch in range(batches):
images += [
Image.fromarray(img)
for img in model.generate_image(
batched_encodings[batch],
batch_size=batch_size,
num_steps=25,
diffusion_noise=noise,
)
]
export_as_gif("eiffel-tower-starry-night.gif", images, rubber_band=True)
25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 241ms/step 25/25 [==============================] - 6s 241ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 244ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 241ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 241ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 242ms/step 25/25 [==============================] - 6s 242ms/step
たぶん驚く必要もなく、エンコーダの潜在的多様体から離れすぎて歩けば、支離滅裂な画像を生成します。貴方自身のプロンプトを設定し、歩く幅を増減させるために step_size を調整して貴方自身で試してください。歩幅が大きくなるとき、ウォークは極端にノイズが多い画像を生成する領域に繋がることが多いことに注意してください。
単一プロンプトに対する拡散ノイズ空間の循環ウォーク
最後の実験は、一つのプロンプトに固執して、拡散モデルがそのプロンプトから生成可能な様々な画像を探索することです。拡散過程をシードするために使用されるノイズを制御することでこれを行ないます。
2 つのノイズ成分, x と y を作成して、0 から 2π まで歩いて、x 成分のコサインと y 成分の sin を足してノイズを生成します。このアプローチを使用すれば、ウォークの最後は (ウォークを始めたのと) 同じノイズ入力に到達しますので、「ループ可能な」結果を得られます!
prompt = "An oil paintings of cows in a field next to a windmill in Holland"
encoding = tf.squeeze(model.encode_text(prompt))
walk_steps = 150
batch_size = 3
batches = walk_steps // batch_size
walk_noise_x = tf.random.normal(noise.shape, dtype=tf.float64)
walk_noise_y = tf.random.normal(noise.shape, dtype=tf.float64)
walk_scale_x = tf.cos(tf.linspace(0, 2, walk_steps) * math.pi)
walk_scale_y = tf.sin(tf.linspace(0, 2, walk_steps) * math.pi)
noise_x = tf.tensordot(walk_scale_x, walk_noise_x, axes=0)
noise_y = tf.tensordot(walk_scale_y, walk_noise_y, axes=0)
noise = tf.add(noise_x, noise_y)
batched_noise = tf.split(noise, batches)
images = []
for batch in range(batches):
images += [
Image.fromarray(img)
for img in model.generate_image(
encoding,
batch_size=batch_size,
num_steps=25,
diffusion_noise=batched_noise[batch],
)
]
export_as_gif("cows.gif", images)
25/25 [==============================] - 38s 240ms/step 25/25 [==============================] - 6s 240ms/step 25/25 [==============================] - 6s 240ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 240ms/step 25/25 [==============================] - 6s 240ms/step 25/25 [==============================] - 6s 240ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 240ms/step 25/25 [==============================] - 6s 241ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 240ms/step 25/25 [==============================] - 6s 240ms/step 25/25 [==============================] - 6s 240ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 238ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 239ms/step 25/25 [==============================] - 6s 241ms/step 25/25 [==============================] - 6s 240ms/step 25/25 [==============================] - 6s 240ms/step 25/25 [==============================] - 6s 240ms/step 25/25 [==============================] - 6s 240ms/step 25/25 [==============================] - 6s 240ms/step 25/25 [==============================] - 6s 240ms/step 25/25 [==============================] - 6s 241ms/step 25/25 [==============================] - 6s 241ms/step 25/25 [==============================] - 6s 240ms/step 25/25 [==============================] - 6s 240ms/step 25/25 [==============================] - 6s 241ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 243ms/step 25/25 [==============================] - 6s 241ms/step 25/25 [==============================] - 6s 240ms/step 25/25 [==============================] - 6s 241ms/step 25/25 [==============================] - 6s 241ms/step
Experiment with your own prompts and with different values of unconditional_guidance_scale!
まとめ
Stable Diffusion は単なるテキスト-to-画像生成以上の多くのものを提供します。テキストエンコーダの潜在的多様体と拡散モデルのノイズ空間の探索は、このモデルのパワーを体験するための 2 つの楽しい方法です、そして KerasCV はそれを簡単にします!
以上