Keras 2 : examples : 生成深層学習 – Textual Inversion で StableDiffusion に新コンセプトを教える (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 12/26/2022 (keras 2.11.0)
* 本ページは、Keras の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
- Code examples : Generative Deep Learning : Teach StableDiffusion new concepts via Textual Inversion (Author: Ian Stenbit, lukewood : Created : 2022/12/09)
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Web: www.classcat.com ; ClassCatJP
Keras 2 : examples : 生成深層学習 – Textual Inversion で StableDiffusion に新コンセプトを教える
Description : KerasCV の StableDiffusion 実装で新しい視覚的コンセプトを学習する。
Textual Inversion
そのリリースから、StableDiffusion はジェネラティブ (生成的) 機械学習コミュニティの中で素早くお気に入りになりました。高いボリュームのトラフィックはオープンソースが寄与した改良、ヘビーなプロンプトエンジニアリング、そして新規のアルゴリズムの考案にさえ繋がりました。
おそらく使用されている最も印象的な新しいアルゴリズムは Textual Inversion で、これは An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion で提案されました。
Textual Inversion は再調整を利用して画像生成器に特定の視覚的コンセプトを教える過程です。下の図では、この過程の例を見ることができます、そこでは作成者はモデルに “S_*” と呼称する新しいコンセプトを教えています。
概念的には、textual inversion は新しいテキストトークンに対するトークン埋め込みを、StableDiffusion の残りのコンポーネントは凍結したままで学習することで動作します。
このガイドは KerasCV でリリースされた StableDiffusion モデルを Textual-Inversion アルゴリズムを使用して再調整する方法を示します。このガイドの最後までには、”Gandalf the Gray as a ” を書くことができるようになります。
最初に、必要なパッケージをインストールして StableDiffusion インスタンスを作成しましょう、するとそのサブコンポーネントの幾つかを再調整のために利用できます。
!pip install -q git+https://github.com/keras-team/keras-cv.git
!pip install -q tensorflow==2.11.0
import math
import random
import keras_cv
import numpy as np
import tensorflow as tf
from keras_cv import layers as cv_layers
from keras_cv.models.stable_diffusion import NoiseScheduler
from tensorflow import keras
import matplotlib.pyplot as plt
stable_diffusion = keras_cv.models.StableDiffusion()
By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE
次に、生成画像を表示するための可視化ユティリティを定義しましょう :
def plot_images(images):
plt.figure(figsize=(20, 20))
for i in range(len(images)):
ax = plt.subplot(1, len(images), i + 1)
plt.imshow(images[i])
plt.axis("off")
テキスト-画像ペアのデータセットを作成する
新しいトークンの埋め込みを訓練するために、最初にテキスト-画像ペアから構成されるデータセットを作成する必要があります。データセットの各サンプルは StableDiffusion に教えようとしているコンセプトの画像と画像のコンテンツを正確に表すキャプションを含んでいなければなりません。このチュートリアルでは、StableDiffusion に Luke と Ian の GitHub アバターのコンセプトを教えまます :
First, let’s construct an image dataset of cat dolls:
def assemble_image_dataset(urls):
# Fetch all remote files
files = [tf.keras.utils.get_file(origin=url) for url in urls]
# Resize images
resize = keras.layers.Resizing(height=512, width=512, crop_to_aspect_ratio=True)
images = [keras.utils.load_img(img) for img in files]
images = [keras.utils.img_to_array(img) for img in images]
images = np.array([resize(img) for img in images])
# The StableDiffusion image encoder requires images to be normalized to the
# [-1, 1] pixel value range
images = images / 127.5 - 1
# Create the tf.data.Dataset
image_dataset = tf.data.Dataset.from_tensor_slices(images)
# Shuffle and introduce random noise
image_dataset = image_dataset.shuffle(50, reshuffle_each_iteration=True)
image_dataset = image_dataset.map(
cv_layers.RandomCropAndResize(
target_size=(512, 512),
crop_area_factor=(0.8, 1.0),
aspect_ratio_factor=(1.0, 1.0),
),
num_parallel_calls=tf.data.AUTOTUNE,
)
image_dataset = image_dataset.map(
cv_layers.RandomFlip(mode="horizontal"),
num_parallel_calls=tf.data.AUTOTUNE,
)
return image_dataset
Next, we assemble a text dataset:
MAX_PROMPT_LENGTH = 77
placeholder_token = ""
def pad_embedding(embedding):
return embedding + (
[stable_diffusion.tokenizer.end_of_text] * (MAX_PROMPT_LENGTH - len(embedding))
)
stable_diffusion.tokenizer.add_tokens(placeholder_token)
def assemble_text_dataset(prompts):
prompts = [prompt.format(placeholder_token) for prompt in prompts]
embeddings = [stable_diffusion.tokenizer.encode(prompt) for prompt in prompts]
embeddings = [np.array(pad_embedding(embedding)) for embedding in embeddings]
text_dataset = tf.data.Dataset.from_tensor_slices(embeddings)
text_dataset = text_dataset.shuffle(100, reshuffle_each_iteration=True)
return text_dataset
最後に、テキスト-画像ペアのデータセットを作成するためにデータセットをひとつに zip します。
def assemble_dataset(urls, prompts):
image_dataset = assemble_image_dataset(urls)
text_dataset = assemble_text_dataset(prompts)
# the image dataset is quite short, so we repeat it to match the length of the
# text prompt dataset
image_dataset = image_dataset.repeat()
# we use the text prompt dataset to determine the length of the dataset. Due to
# the fact that there are relatively few prompts we repeat the dataset 5 times.
# we have found that this anecdotally improves results.
text_dataset = text_dataset.repeat(5)
return tf.data.Dataset.zip((image_dataset, text_dataset))
プロンプトが説明的 (descriptive) であることを保証するため、極めて一般的なプロンプトを使用します。
Let’s try this out with some sample images and prompts.
train_ds = assemble_dataset(
urls=[
"https://i.imgur.com/VIedH1X.jpg",
"https://i.imgur.com/eBw13hE.png",
"https://i.imgur.com/oJ3rSg7.png",
"https://i.imgur.com/5mCL6Df.jpg",
"https://i.imgur.com/4Q6WWyI.jpg",
],
prompts=[
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
],
)
プロンプトの正確さの重要性について
このガイドを書く時最初の試みの際、データセットにこれらの猫の人形 (cat dolls) のグループの画像を含めましたが、上に列挙された一般的なプロンプトを使い続けました。結果は逸話的に (anecdotally) 貧弱なものでした。例えば、ここにこの手法を使用した猫の人形 gandalf があります :
それはコンセプト的には近いですが、それほど素晴らしいわけではありません。
これを是正するため、画像を単一の猫の人形と猫の人形のグループに分割して実験し始めました。この分割に続いて、グループのショットのために新しいプロンプトを考え出しました。
コンテンツを正確に表現するテキスト-to-画像ペアでの訓練は結果の品質を大幅にブーストしました。これはプロンプトの正確性の重要さを証明しています。
画像を単一画像とグループ画像に分離することに加えて、”a dark photo of the {}” のような幾つかの不正確なプロンプトも除去しました。
これを念頭に置いて、最終的な訓練データセットを以下のように作成しました :
single_ds = assemble_dataset(
urls=[
"https://i.imgur.com/VIedH1X.jpg",
"https://i.imgur.com/eBw13hE.png",
"https://i.imgur.com/oJ3rSg7.png",
"https://i.imgur.com/5mCL6Df.jpg",
"https://i.imgur.com/4Q6WWyI.jpg",
],
prompts=[
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
],
)
Looks great!
次に、GitHub アバターのグループのデータセットを作成します :
group_ds = assemble_dataset(
urls=[
"https://i.imgur.com/yVmZ2Qa.jpg",
"https://i.imgur.com/JbyFbZJ.jpg",
"https://i.imgur.com/CCubd3q.jpg",
],
prompts=[
"a photo of a group of {}",
"a rendering of a group of {}",
"a cropped photo of the group of {}",
"the photo of a group of {}",
"a photo of a clean group of {}",
"a photo of my group of {}",
"a photo of a cool group of {}",
"a close-up photo of a group of {}",
"a bright photo of the group of {}",
"a cropped photo of a group of {}",
"a photo of the group of {}",
"a good photo of the group of {}",
"a photo of one group of {}",
"a close-up photo of the group of {}",
"a rendition of the group of {}",
"a photo of the clean group of {}",
"a rendition of a group of {}",
"a photo of a nice group of {}",
"a good photo of a group of {}",
"a photo of the nice group of {}",
"a photo of the small group of {}",
"a photo of the weird group of {}",
"a photo of the large group of {}",
"a photo of a cool group of {}",
"a photo of a small group of {}",
],
)
最後に、2 つのデータセットを連結します :
train_ds = single_ds.concatenate(group_ds)
train_ds = train_ds.batch(1).shuffle(
train_ds.cardinality(), reshuffle_each_iteration=True
)
新しいトークンをテキストエンコーダに追加する
次に、StableDiffusion のための新しいテキストエンコーダを作成して ” のための新しい埋め込みをモデルに追加します。
tokenized_initializer = stable_diffusion.tokenizer.encode("cat")[1]
new_weights = stable_diffusion.text_encoder.layers[2].token_embedding(
tf.constant(tokenized_initializer)
)
# Get len of .vocab instead of tokenizer
new_vocab_size = len(stable_diffusion.tokenizer.vocab)
# The embedding layer is the 2nd layer in the text encoder
old_token_weights = stable_diffusion.text_encoder.layers[
2
].token_embedding.get_weights()
old_position_weights = stable_diffusion.text_encoder.layers[
2
].position_embedding.get_weights()
old_token_weights = old_token_weights[0]
new_weights = np.expand_dims(new_weights, axis=0)
new_weights = np.concatenate([old_token_weights, new_weights], axis=0)
Let’s construct a new TextEncoder and prepare it.
# Have to set download_weights False so we can init (otherwise tries to load weights)
new_encoder = keras_cv.models.stable_diffusion.TextEncoder(
keras_cv.models.stable_diffusion.stable_diffusion.MAX_PROMPT_LENGTH,
vocab_size=new_vocab_size,
download_weights=False,
)
for index, layer in enumerate(stable_diffusion.text_encoder.layers):
# Layer 2 is the embedding layer, so we omit it from our weight-copying
if index == 2:
continue
new_encoder.layers[index].set_weights(layer.get_weights())
new_encoder.layers[2].token_embedding.set_weights([new_weights])
new_encoder.layers[2].position_embedding.set_weights(old_position_weights)
stable_diffusion._text_encoder = new_encoder
stable_diffusion._text_encoder.compile(jit_compile=True)
訓練
さてエキサイティングなパート: 訓練に移ることができます!
TextualInversion では、モデルの訓練される唯一のピースは埋め込みベクトルです。モデルの残りは凍結しましょう。
stable_diffusion.diffusion_model.trainable = False
stable_diffusion.decoder.trainable = False
stable_diffusion.text_encoder.trainable = True
stable_diffusion.text_encoder.layers[2].trainable = True
def traverse_layers(layer):
if hasattr(layer, "layers"):
for layer in layer.layers:
yield layer
if hasattr(layer, "token_embedding"):
yield layer.token_embedding
if hasattr(layer, "position_embedding"):
yield layer.position_embedding
for layer in traverse_layers(stable_diffusion.text_encoder):
if isinstance(layer, keras.layers.Embedding) or "clip_embedding" in layer.name:
layer.trainable = True
else:
layer.trainable = False
new_encoder.layers[2].position_embedding.trainable = False
Let’s confirm the proper weights are set to trainable.
all_models = [
stable_diffusion.text_encoder,
stable_diffusion.diffusion_model,
stable_diffusion.decoder,
]
print([[w.shape for w in model.trainable_weights] for model in all_models])
[[TensorShape([49409, 768])], [], []]
新しい埋め込みの訓練
埋め込みを訓練するためには、幾つかのユティリティが必要です。KerasCV から NoiseScheduler をインポートして、以下のユティリティを下で定義します :
- sample_from_encoder_outputs はベース StableDiffusion 画像エンコーダのラッパーで、それは (他の多くの SD アプリケーションのように) 単なる平均を取るのではなく、画像エンコーダにより生成される統計的な分布からサンプリングします。
- get_timestep_embedding は拡散モデルの特定の時間ステップに対する埋め込みを生成します。
- get_position_ids はテキストエンコーダに対する位置 ID のテンソルを生成します (これは [1, MAX_PROMPT_LENGTH] からの単なる系列です)。
# Remove the top layer from the encoder, which cuts off the variance and only returns
# the mean
training_image_encoder = keras.Model(
stable_diffusion.image_encoder.input,
stable_diffusion.image_encoder.layers[-2].output,
)
def sample_from_encoder_outputs(outputs):
mean, logvar = tf.split(outputs, 2, axis=-1)
logvar = tf.clip_by_value(logvar, -30.0, 20.0)
std = tf.exp(0.5 * logvar)
sample = tf.random.normal(tf.shape(mean))
return mean + std * sample
def get_timestep_embedding(timestep, dim=320, max_period=10000):
half = dim // 2
freqs = tf.math.exp(
-math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
)
args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
return embedding
def get_position_ids():
return tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)
次に、StableDiffusionFineTuner を実装します、これは keras.Model のサブクラスでテキストエンコーダのトークン埋め込みを訓練するために train_step をオーバーライドします。これは Textual Inversion アルゴリズムの中核です。
抽象的に言えば、訓練ステップは訓練画像に対する凍結された SD 画像エンコーダの潜在的分布の出力からサンプルを取り、そのサンプルにノイズを加えてから、ノイズのあるサンプルを凍結された拡散モデルに渡します。拡散モデルの隠れ状態は画像に対応するプロンプトのテキストエンコーダの出力です。
最終的な目標状態は、拡散モデルが隠れ状態としてテキストエンコーディングを使用してサンプルからノイズを分離できるようにすることですので、損失はノイズと拡散モデル (理想的には画像潜在変数からノイズを取り除きます) の出力の平均二乗誤差です。
テキストエンコーダのトークン埋め込みだけに対する勾配を計算し、訓練ステップでは学習しているトークン以外のすべてのトークンに対する勾配はゼロにします。
訓練ステップについての詳細はインラインのコードコメントをご覧ください。
class StableDiffusionFineTuner(keras.Model):
def __init__(self, stable_diffusion, noise_scheduler, **kwargs):
super().__init__(**kwargs)
self.stable_diffusion = stable_diffusion
self.noise_scheduler = noise_scheduler
def train_step(self, data):
images, embeddings = data
with tf.GradientTape() as tape:
# Sample from the predicted distribution for the training image
latents = sample_from_encoder_outputs(training_image_encoder(images))
# The latents must be downsampled to match the scale of the latents used
# in the training of StableDiffusion. This number is truly just a "magic"
# constant that they chose when training the model.
latents = latents * 0.18215
# Produce random noise in the same shape as the latent sample
noise = tf.random.normal(tf.shape(latents))
batch_dim = tf.shape(latents)[0]
# Pick a random timestep for each sample in the batch
timesteps = tf.random.uniform(
(batch_dim,),
minval=0,
maxval=noise_scheduler.train_timesteps,
dtype=tf.int64,
)
# Add noise to the latents based on the timestep for each sample
noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
# Encode the text in the training samples to use as hidden state in the
# diffusion model
encoder_hidden_state = self.stable_diffusion.text_encoder(
[embeddings, get_position_ids()]
)
# Compute timestep embeddings for the randomly-selected timesteps for each
# sample in the batch
timestep_embeddings = tf.map_fn(
fn=get_timestep_embedding,
elems=timesteps,
fn_output_signature=tf.float32,
)
# Call the diffusion model
noise_pred = self.stable_diffusion.diffusion_model(
[noisy_latents, timestep_embeddings, encoder_hidden_state]
)
# Compute the mean-squared error loss and reduce it.
loss = self.compiled_loss(noise_pred, noise)
loss = tf.reduce_mean(loss, axis=2)
loss = tf.reduce_mean(loss, axis=1)
loss = tf.reduce_mean(loss)
# Load the trainable weights and compute the gradients for them
trainable_weights = self.stable_diffusion.text_encoder.trainable_weights
grads = tape.gradient(loss, trainable_weights)
# Gradients are stored in indexed slices, so we have to find the index
# of the slice(s) which contain the placeholder token.
index_of_placeholder_token = tf.reshape(tf.where(grads[0].indices == 49408), ())
condition = grads[0].indices == 49408
condition = tf.expand_dims(condition, axis=-1)
# Override the gradients, zeroing out the gradients for all slices that
# aren't for the placeholder token, effectively freezing the weights for
# all other tokens.
grads[0] = tf.IndexedSlices(
values=tf.where(condition, grads[0].values, 0),
indices=grads[0].indices,
dense_shape=grads[0].dense_shape,
)
self.optimizer.apply_gradients(zip(grads, trainable_weights))
return {"loss": loss}
訓練を始める前に、私たちのトークンに対して StableDiffusion が何を生成するか見てみましょう。
generated = stable_diffusion.text_to_image(
f"an oil painting of {placeholder_token}", seed=1337, batch_size=3
)
plot_images(generated)
25/25 [==============================] - 19s 314ms/step
ご覧のように、モデルは依然として私たちのトークンを猫として考えています、これはカスタムトークンを初期化するために使用したシード・トークンであるためです。
さて、訓練を開始するには、任意の他の Keras モデルのようにモデルを compile() するだけです。それを行なう前に、訓練用のノイズスケジューラもインスタンス化して学習率と optimizer のような訓練パラメータを設定します。
noise_scheduler = NoiseScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
train_timesteps=1000,
)
trainer = StableDiffusionFineTuner(stable_diffusion, noise_scheduler, name="trainer")
EPOCHS = 50
learning_rate = keras.optimizers.schedules.CosineDecay(
initial_learning_rate=1e-4, decay_steps=train_ds.cardinality() * EPOCHS
)
optimizer = keras.optimizers.Adam(
weight_decay=0.004, learning_rate=learning_rate, epsilon=1e-8, global_clipnorm=10
)
trainer.compile(
optimizer=optimizer,
# We are performing reduction manually in our train step, so none is required here.
loss=keras.losses.MeanSquaredError(reduction="none"),
)
訓練を監視するため、エポック毎にカスタムトークンを使用して 2,3 の画像を生成するために keras.callbacks.Callback を作成することができます。
異なるプロンプトで 3 つのコールバックを作成し、訓練の過程でそれらがどのように進捗するかを見ることができるようにします。固定シードを使用しますので、学習されたトークンの進捗を簡単に見ることができます。
class GenerateImages(keras.callbacks.Callback):
def __init__(
self, stable_diffusion, prompt, steps=50, frequency=10, seed=None, **kwargs
):
super().__init__(**kwargs)
self.stable_diffusion = stable_diffusion
self.prompt = prompt
self.seed = seed
self.frequency = frequency
self.steps = steps
def on_epoch_end(self, epoch, logs):
if epoch % self.frequency == 0:
images = self.stable_diffusion.text_to_image(
self.prompt, batch_size=3, num_steps=self.steps, seed=self.seed
)
plot_images(
images,
)
cbs = [
GenerateImages(
stable_diffusion, prompt=f"an oil painting of {placeholder_token}", seed=1337
),
GenerateImages(
stable_diffusion, prompt=f"gandalf the gray as a {placeholder_token}", seed=1337
),
GenerateImages(
stable_diffusion,
prompt=f"two {placeholder_token} getting married, photorealistic, high quality",
seed=1337,
),
]
Now, all that is left to do is to call model.fit()!
trainer.fit(
train_ds,
epochs=EPOCHS,
callbacks=cbs,
)
Epoch 1/50 50/50 [==============================] - 16s 318ms/step 50/50 [==============================] - 16s 318ms/step 50/50 [==============================] - 16s 318ms/step 250/250 [==============================] - 194s 469ms/step - loss: 0.1533 Epoch 2/50 250/250 [==============================] - 68s 269ms/step - loss: 0.1557 Epoch 3/50 250/250 [==============================] - 68s 269ms/step - loss: 0.1359 Epoch 4/50 250/250 [==============================] - 68s 269ms/step - loss: 0.1693 Epoch 5/50 250/250 [==============================] - 68s 269ms/step - loss: 0.1475 Epoch 6/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1472 Epoch 7/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1533 Epoch 8/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1450 Epoch 9/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1639 Epoch 10/50 250/250 [==============================] - 68s 269ms/step - loss: 0.1351 Epoch 11/50 50/50 [==============================] - 16s 316ms/step 50/50 [==============================] - 16s 316ms/step 50/50 [==============================] - 16s 317ms/step 250/250 [==============================] - 116s 464ms/step - loss: 0.1474 Epoch 12/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1737 Epoch 13/50 250/250 [==============================] - 68s 269ms/step - loss: 0.1427 Epoch 14/50 250/250 [==============================] - 68s 269ms/step - loss: 0.1698 Epoch 15/50 250/250 [==============================] - 68s 270ms/step - loss: 0.1424 Epoch 16/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1339 Epoch 17/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1397 Epoch 18/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1469 Epoch 19/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1649 Epoch 20/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1582 Epoch 21/50 50/50 [==============================] - 16s 315ms/step 50/50 [==============================] - 16s 316ms/step 50/50 [==============================] - 16s 316ms/step 250/250 [==============================] - 116s 462ms/step - loss: 0.1331 Epoch 22/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1319 Epoch 23/50 250/250 [==============================] - 68s 267ms/step - loss: 0.1521 Epoch 24/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1486 Epoch 25/50 250/250 [==============================] - 68s 267ms/step - loss: 0.1449 Epoch 26/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1349 Epoch 27/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1454 Epoch 28/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1394 Epoch 29/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1489 Epoch 30/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1338 Epoch 31/50 50/50 [==============================] - 16s 315ms/step 50/50 [==============================] - 16s 320ms/step 50/50 [==============================] - 16s 315ms/step 250/250 [==============================] - 116s 462ms/step - loss: 0.1328 Epoch 32/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1693 Epoch 33/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1420 Epoch 34/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1255 Epoch 35/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1239 Epoch 36/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1558 Epoch 37/50 250/250 [==============================] - 68s 267ms/step - loss: 0.1527 Epoch 38/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1461 Epoch 39/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1555 Epoch 40/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1515 Epoch 41/50 50/50 [==============================] - 16s 315ms/step 50/50 [==============================] - 16s 315ms/step 50/50 [==============================] - 16s 315ms/step 250/250 [==============================] - 116s 461ms/step - loss: 0.1291 Epoch 42/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1474 Epoch 43/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1908 Epoch 44/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1506 Epoch 45/50 250/250 [==============================] - 68s 267ms/step - loss: 0.1424 Epoch 46/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1601 Epoch 47/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1312 Epoch 48/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1524 Epoch 49/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1477 Epoch 50/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1397 <keras.callbacks.History at 0x7f183aea3eb8>
モデルが時間につれて新しいトークンをどのように学習するかを見るのは非常に楽しいです。いろいろ試して、ベストな画像を生成するために訓練パラメータと訓練データセットをどのように調整できるかを見てください。
Taking the Fine Tuned Model for a Spin
Now for the really fun part. 私たちはカスタムトークンに対するトークン埋め込みを学習しましたので、今では他のトークンに対するのと同じ方法で StableDiffusion で画像を生成することができます。
ここに、cat doll (猫の人形) トークンからのサンプル出力とともに、幾つかの楽しいプロンプトの例があります!
generated = stable_diffusion.text_to_image(
f"Gandalf as a {placeholder_token} fantasy art drawn by disney concept artists, "
"golden colour, high quality, highly detailed, elegant, sharp focus, concept art, "
"character concepts, digital painting, mystery, adventure",
batch_size=3,
)
plot_images(generated)
25/25 [==============================] - 8s 316ms/step
generated = stable_diffusion.text_to_image(
f"A masterpiece of a {placeholder_token} crying out to the heavens. "
f"Behind the {placeholder_token}, an dark, evil shade looms over it - sucking the "
"life right out of it.",
batch_size=3,
)
plot_images(generated)
25/25 [==============================] - 8s 314ms/step
generated = stable_diffusion.text_to_image(
f"An evil {placeholder_token}.", batch_size=3
)
plot_images(generated)
25/25 [==============================] - 8s 322ms/step
generated = stable_diffusion.text_to_image(
f"A mysterious {placeholder_token} approaches the great pyramids of egypt.",
batch_size=3,
)
plot_images(generated)
25/25 [==============================] - 8s 315ms/step
まとめ
Textual Inversion アルゴリズムを使用して StableDiffusion に新しいコンセプトを教えることができます!
Some possible next steps to follow:
- 貴方自身のプロンプトを試す。
- モデルにスタイルを教える。
- Gather a dataset of your favorite pet cat or dog and teach the model about it
以上