Skip to content

ClasCat® AI Research

クラスキャット – 生成 AI, AI エージェント, MCP

Menu
  • ホーム
    • ClassCat® AI Research ホーム
    • クラスキャット・ホーム
  • OpenAI API
    • OpenAI Python ライブラリ 1.x : 概要
    • OpenAI ブログ
      • GPT の紹介
      • GPT ストアの紹介
      • ChatGPT Team の紹介
    • OpenAI platform 1.x
      • Get Started : イントロダクション
      • Get Started : クイックスタート (Python)
      • Get Started : クイックスタート (Node.js)
      • Get Started : モデル
      • 機能 : 埋め込み
      • 機能 : 埋め込み (ユースケース)
      • ChatGPT : アクション – イントロダクション
      • ChatGPT : アクション – Getting started
      • ChatGPT : アクション – アクション認証
    • OpenAI ヘルプ : ChatGPT
      • ChatGPTとは何ですか?
      • ChatGPT は真実を語っていますか?
      • GPT の作成
      • GPT FAQ
      • GPT vs アシスタント
      • GPT ビルダー
    • OpenAI ヘルプ : ChatGPT > メモリ
      • FAQ
    • OpenAI ヘルプ : GPT ストア
      • 貴方の GPT をフィーチャーする
    • OpenAI Python ライブラリ 0.27 : 概要
    • OpenAI platform
      • Get Started : イントロダクション
      • Get Started : クイックスタート
      • Get Started : モデル
      • ガイド : GPT モデル
      • ガイド : 画像生成 (DALL·E)
      • ガイド : GPT-3.5 Turbo 対応 微調整
      • ガイド : 微調整 1.イントロダクション
      • ガイド : 微調整 2. データセットの準備 / ケーススタディ
      • ガイド : 埋め込み
      • ガイド : 音声テキスト変換
      • ガイド : モデレーション
      • ChatGPT プラグイン : イントロダクション
    • OpenAI Cookbook
      • 概要
      • API 使用方法 : レート制限の操作
      • API 使用方法 : tiktoken でトークンを数える方法
      • GPT : ChatGPT モデルへの入力をフォーマットする方法
      • GPT : 補完をストリームする方法
      • GPT : 大規模言語モデルを扱う方法
      • 埋め込み : 埋め込みの取得
      • GPT-3 の微調整 : 分類サンプルの微調整
      • DALL-E : DALL·E で 画像を生成して編集する方法
      • DALL·E と Segment Anything で動的マスクを作成する方法
      • Whisper プロンプティング・ガイド
  • Gemini API
    • Tutorials : クイックスタート with Python (1) テキスト-to-テキスト生成
    • (2) マルチモーダル入力 / 日本語チャット
    • (3) 埋め込みの使用
    • (4) 高度なユースケース
    • クイックスタート with Node.js
    • クイックスタート with Dart or Flutter (1) 日本語動作確認
    • Gemma
      • 概要 (README)
      • Tutorials : サンプリング
      • Tutorials : KerasNLP による Getting Started
  • Keras 3
    • 新しいマルチバックエンド Keras
    • Keras 3 について
    • Getting Started : エンジニアのための Keras 入門
    • Google Colab 上のインストールと Stable Diffusion デモ
    • コンピュータビジョン – ゼロからの画像分類
    • コンピュータビジョン – 単純な MNIST convnet
    • コンピュータビジョン – EfficientNet を使用した微調整による画像分類
    • コンピュータビジョン – Vision Transformer による画像分類
    • コンピュータビジョン – 最新の MLPモデルによる画像分類
    • コンピュータビジョン – コンパクトな畳込み Transformer
    • Keras Core
      • Keras Core 0.1
        • 新しいマルチバックエンド Keras (README)
        • Keras for TensorFlow, JAX, & PyTorch
        • 開発者ガイド : Getting started with Keras Core
        • 開発者ガイド : 関数型 API
        • 開発者ガイド : シーケンシャル・モデル
        • 開発者ガイド : サブクラス化で新しい層とモデルを作成する
        • 開発者ガイド : 独自のコールバックを書く
      • Keras Core 0.1.1 & 0.1.2 : リリースノート
      • 開発者ガイド
      • Code examples
      • Keras Stable Diffusion
        • 概要
        • 基本的な使い方 (テキスト-to-画像 / 画像-to-画像変換)
        • 混合精度のパフォーマンス
        • インペインティングの簡易アプリケーション
        • (参考) KerasCV – Stable Diffusion を使用した高性能画像生成
  • TensorFlow
    • TF 2 : 初級チュートリアル
    • TF 2 : 上級チュートリアル
    • TF 2 : ガイド
    • TF 1 : チュートリアル
    • TF 1 : ガイド
  • その他
    • 🦜️🔗 LangChain ドキュメント / ユースケース
    • Stable Diffusion WebUI
      • Google Colab で Stable Diffusion WebUI 入門
      • HuggingFace モデル / VAE の導入
      • LoRA の利用
    • Diffusion Models / 拡散モデル
  • クラスキャット
    • 会社案内
    • お問合せ
    • Facebook
    • ClassCat® Blog
Menu

Keras 2 : examples : 生成深層学習 – 条件付き画像生成のための GauGAN

Posted on 07/09/202207/14/2022 by Sales Information

Keras 2 : examples : 生成深層学習 – 条件付き画像生成のための GauGAN (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 07/09/2022 (keras 2.9.0)

* 本ページは、Keras の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

  • Code examples : Generative Deep Learning : GauGAN for conditional image generation (Author: Soumik Rakshit, Sayak Paul)

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス

◆ クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :

  • 人工知能研究開発支援
    1. 人工知能研修サービス(経営者層向けオンサイト研修)
    2. テクニカルコンサルティングサービス
    3. 実証実験(プロトタイプ構築)
    4. アプリケーションへの実装

  • 人工知能研修サービス

  • PoC(概念実証)を失敗させないための支援
◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。

◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

  • 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
  • sales-info@classcat.com  ;  Web: www.classcat.com  ;   ClassCatJP

 

 

Keras 2 : examples : 生成深層学習 – 条件付き画像生成のための GauGAN

Description : 条件付き画像生成のための GauGAN の実装。

 

イントロダクション

このサンプルでは、Semantic Image Synthesis with Spatially-Adaptive Normalization で提案された GauGAN アーキテクチャの実装を提示します。簡潔には、GauGAN は、以下に示されるように、キュー (= cue) 画像とセグメンテーション・マップにより条件付けられたリアルな画像を生成するために敵対的生成ネットワーク (GAN) を使用します (画像ソース) :

GauGAN の主要コンポーネントは :

  • SPADE (aka spatially-adaptive 正規化) : GauGAN の著者らは、(バッチ正規化 のような) 従来の正規化層は、入力として提供されたセグメンテーション・マップから得られた意味的情報を破壊することを主張しています。この問題に対処するため、著者らは SPADE, 空間的に適応可能なアフィン・パラメータ (スケールとバイアス) を学習するために特に適している正規化層を導入しました。これは、各意味的ラベルに対してスケーリングとバイアス・パラメータの異なるセットを学習することにより成されます。

  • 変分エンコーダ : 変分オートエンコーダ にインスパイアされ、GauGAN は、cue 画像から正規 (ガウス) 分布の平均と分散を学習する、変分定式化を利用しています。そこから GauGAN は名前を得ました。GauGAN の generator は入力として、ガウス分布からサンプリングされた潜在値 (= latent) と one-hot エンコードされたセマンティックセグメンテーションのラベルマップを受けとります。cue 画像は generator をスタイルの生成へガイドするスタイル画像として機能します。この変分定式化は GauGAN が画像の多様性と忠実性を獲得するのに役立ちます。

  • マルチスケール・パッチ discriminator : PatchGAN にインスパイアされ、GauGAN は与えられた画像をパッチベースで評価して平均スコアを生成する discriminator を使用します。

サンプルを進めながら、様々なコンポーネントの各々を詳細に議論します。

GauGAN の徹底的なレビューについては、この記事 を参照してください。また 公式 GauGAN web サイト をチェックすることも勧めます、これは GauGAN の多くの創造的なアプリケーションを持っています。このサンプルは読者が GAN の基礎的な概念に既に馴染みがあることを仮定しています。refresher が必要であれば、以下のリソースが有用であるかもしれません :

  • Chapter on GANs from the Deep Learning with Python book by François Chollet.
  • GAN implementations on keras.io:
    • [Data efficient GANs](https://keras.io/examples/generative/gan_ada)
    • [CycleGAN](https://keras.io/examples/generative/cyclegan)
    • [Conditional GAN](https://keras.io/examples/generative/conditional_gan)

 

データ・コレクション

GauGAN モデルを訓練するために Facades データセット を使用していきます。まずはそれをダウンロードしましょう。また TensorFlow Addons もインストールします。

!gdown https://drive.google.com/uc?id=1q4FEjQg1YSb4mPx2VdxL7LXKYu3voTMj
!unzip -q facades_data.zip
!pip install -qqq tensorflow_addons
Downloading...
From: https://drive.google.com/uc?id=1q4FEjQg1YSb4mPx2VdxL7LXKYu3voTMj
To: /content/keras-io/scripts/tmp_2820468/facades_data.zip
100% 26.0M/26.0M [00:00<00:00, 261MB/s]
     |████████████████████████████████| 1.1 MB 8.5 MB/s 
[?25h

 

インポート

import os
import random
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers

from glob import glob
from PIL import Image

 

データ分割

PATH = "./facades_data/"
SPLIT = 0.2

files = glob(PATH + "*.jpg")
np.random.shuffle(files)

split_index = int(len(files) * (1 - SPLIT))
train_files = files[:split_index]
val_files = files[split_index:]

print(f"Total samples: {len(files)}.")
print(f"Total training samples: {len(train_files)}.")
print(f"Total validation samples: {len(val_files)}.")
Total samples: 378.
Total training samples: 302.
Total validation samples: 76.

 

データローダ

BATCH_SIZE = 4
IMG_HEIGHT = IMG_WIDTH = 256
NUM_CLASSES = 12
AUTOTUNE = tf.data.AUTOTUNE


def load(image_files, batch_size, is_train=True):
    def _random_crop(
        segmentation_map, image, labels, crop_size=(IMG_HEIGHT, IMG_WIDTH),
    ):
        crop_size = tf.convert_to_tensor(crop_size)
        image_shape = tf.shape(image)[:2]
        margins = image_shape - crop_size
        y1 = tf.random.uniform(shape=(), maxval=margins[0], dtype=tf.int32)
        x1 = tf.random.uniform(shape=(), maxval=margins[1], dtype=tf.int32)
        y2 = y1 + crop_size[0]
        x2 = x1 + crop_size[1]

        cropped_images = []
        images = [segmentation_map, image, labels]
        for img in images:
            cropped_images.append(img[y1:y2, x1:x2])
        return cropped_images

    def _load_data_tf(image_file, segmentation_map_file, label_file):
        image = tf.image.decode_png(tf.io.read_file(image_file), channels=3)
        segmentation_map = tf.image.decode_png(
            tf.io.read_file(segmentation_map_file), channels=3
        )
        labels = tf.image.decode_bmp(tf.io.read_file(label_file), channels=0)
        labels = tf.squeeze(labels)

        image = tf.cast(image, tf.float32) / 127.5 - 1
        segmentation_map = tf.cast(segmentation_map, tf.float32) / 127.5 - 1
        return segmentation_map, image, labels

    segmentation_map_files = [
        image_file.replace("images", "segmentation_map").replace("jpg", "png")
        for image_file in image_files
    ]
    label_files = [
        image_file.replace("images", "segmentation_labels").replace("jpg", "bmp")
        for image_file in image_files
    ]
    dataset = tf.data.Dataset.from_tensor_slices(
        (image_files, segmentation_map_files, label_files)
    )

    dataset = dataset.shuffle(batch_size * 10) if is_train else dataset
    dataset = dataset.map(_load_data_tf, num_parallel_calls=AUTOTUNE)
    dataset = dataset.map(_random_crop, num_parallel_calls=AUTOTUNE)
    dataset = dataset.map(
        lambda x, y, z: (x, y, tf.one_hot(z, NUM_CLASSES)), num_parallel_calls=AUTOTUNE
    )
    return dataset.batch(batch_size, drop_remainder=True)


train_dataset = load(train_files, batch_size=BATCH_SIZE, is_train=True)
val_dataset = load(val_files, batch_size=BATCH_SIZE, is_train=False)

次に、訓練セットから幾つかサンプルを可視化しましょう。

sample_train_batch = next(iter(train_dataset))
print(f"Segmentation map batch shape: {sample_train_batch[0].shape}.")
print(f"Image batch shape: {sample_train_batch[1].shape}.")
print(f"One-hot encoded label map shape: {sample_train_batch[2].shape}.")

# Plot a view samples from the training set.
for segmentation_map, real_image in zip(sample_train_batch[0], sample_train_batch[1]):
    fig = plt.figure(figsize=(10, 10))
    fig.add_subplot(1, 2, 1).set_title("Segmentation Map")
    plt.imshow((segmentation_map + 1) / 2)
    fig.add_subplot(1, 2, 2).set_title("Real Image")
    plt.imshow((real_image + 1) / 2)
    plt.show()
Segmentation map batch shape: (4, 256, 256, 3).
Image batch shape: (4, 256, 256, 3).
One-hot encoded label map shape: (4, 256, 256, 12).

この例の残りでは、便宜上、オリジナルの GauGAN 論文 からの幾つかの図を使用します。

 

カスタム層

以下のセクションで、以下の層を実装します :

  • SPADE
  • SPADE を含む残差ブロック
  • Gaussian サンプラー

 

Some more notes on SPADE

SPatially-Adaptive (DE) 正規化 or SPADE は、入力セマンティック・レイアウトが与えられたときに写真のようにリアルな画像を合成するための単純ですが効果的な層です。Pix2Pix (Isola et al.) や Pix2PixHD (Wang et al.) のようなセマンティック入力から条件付き画像生成するための以前の方法は深層ネットワークへの入力としてセマンティック・レイアウトを直接供給し、それから畳み込み、正規化、非線形層のスタックを通して処理します。これは、正規化層が意味的情報を洗い流す傾向があるために最適ではない場合が多いです。

SPADE では、セグメンテーション・マスクが最初に埋め込み空間上に射影されてから、変調 (= modulation) パラメータ γ と β を生成するために畳み込まれます。前の条件付き正規化法と違い、γ と β はベクトルではなく、空間次元を持つテンソルです。生成された γ と β は正規化された活性に要素毎に乗算されて加算されます。modulation パラメータは入力セグメンテーション・マスクに適応可能ですので、SPADE は意味的画像合成に対してより適切です。

class SPADE(layers.Layer):
    def __init__(self, filters, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        self.conv = layers.Conv2D(128, 3, padding="same", activation="relu")
        self.conv_gamma = layers.Conv2D(filters, 3, padding="same")
        self.conv_beta = layers.Conv2D(filters, 3, padding="same")

    def build(self, input_shape):
        self.resize_shape = input_shape[1:3]

    def call(self, input_tensor, raw_mask):
        mask = tf.image.resize(raw_mask, self.resize_shape, method="nearest")
        x = self.conv(mask)
        gamma = self.conv_gamma(x)
        beta = self.conv_beta(x)
        mean, var = tf.nn.moments(input_tensor, axes=(0, 1, 2), keepdims=True)
        std = tf.sqrt(var + self.epsilon)
        normalized = (input_tensor - mean) / std
        output = gamma * normalized + beta
        return output


class ResBlock(layers.Layer):
    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters

    def build(self, input_shape):
        input_filter = input_shape[-1]
        self.spade_1 = SPADE(input_filter)
        self.spade_2 = SPADE(self.filters)
        self.conv_1 = layers.Conv2D(self.filters, 3, padding="same")
        self.conv_2 = layers.Conv2D(self.filters, 3, padding="same")
        self.learned_skip = False

        if self.filters != input_filter:
            self.learned_skip = True
            self.spade_3 = SPADE(input_filter)
            self.conv_3 = layers.Conv2D(self.filters, 3, padding="same")

    def call(self, input_tensor, mask):
        x = self.spade_1(input_tensor, mask)
        x = self.conv_1(tf.nn.leaky_relu(x, 0.2))
        x = self.spade_2(x, mask)
        x = self.conv_2(tf.nn.leaky_relu(x, 0.2))
        skip = (
            self.conv_3(tf.nn.leaky_relu(self.spade_3(input_tensor, mask), 0.2))
            if self.learned_skip
            else input_tensor
        )
        output = skip + x
        return output


class GaussianSampler(layers.Layer):
    def __init__(self, batch_size, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.latent_dim = latent_dim

    def call(self, inputs):
        means, variance = inputs
        epsilon = tf.random.normal(
            shape=(self.batch_size, self.latent_dim), mean=0.0, stddev=1.0
        )
        samples = means + tf.exp(0.5 * variance) * epsilon
        return samples

次に、エンコーダのためのダウンサンプリング・ブロックを実装します。

def downsample(
    channels,
    kernels,
    strides=2,
    apply_norm=True,
    apply_activation=True,
    apply_dropout=False,
):
    block = keras.Sequential()
    block.add(
        layers.Conv2D(
            channels,
            kernels,
            strides=strides,
            padding="same",
            use_bias=False,
            kernel_initializer=keras.initializers.GlorotNormal(),
        )
    )
    if apply_norm:
        block.add(tfa.layers.InstanceNormalization())
    if apply_activation:
        block.add(layers.LeakyReLU(0.2))
    if apply_dropout:
        block.add(layers.Dropout(0.5))
    return block

GauGAN エンコーダは幾つかのダウンサンプリング・ブロックから構成されます。それは分布の平均と分散を出力します。

def build_encoder(image_shape, encoder_downsample_factor=64, latent_dim=256):
    input_image = keras.Input(shape=image_shape)
    x = downsample(encoder_downsample_factor, 3, apply_norm=False)(input_image)
    x = downsample(2 * encoder_downsample_factor, 3)(x)
    x = downsample(4 * encoder_downsample_factor, 3)(x)
    x = downsample(8 * encoder_downsample_factor, 3)(x)
    x = downsample(8 * encoder_downsample_factor, 3)(x)
    x = layers.Flatten()(x)
    mean = layers.Dense(latent_dim, name="mean")(x)
    variance = layers.Dense(latent_dim, name="variance")(x)
    return keras.Model(input_image, [mean, variance], name="encoder")

次に、generator を実装します、これは変更された残差ブロックとアップサンプリング・ブロックから構成されます。それは潜在ベクトルと one-hot エンコードされたセグメンテーション・ラベルと取り、新しい画像を生成します。

SPADE では、generator の最初の層にセグメンテーション・マップを供給する必要はありません、何故ならば潜在入力は generator にエミュレートすることを望むスタイルについて十分な構造的情報を持つからです。以前のアーキテクチャで一般に使用される、generator のエンコーダ部も捨てます。これはより軽量な generator ネットワークという結果になります、これはまた入力としてランダム・ベクトルを取り、多様な合成への単純で自然なパスを有効にします。

def build_generator(mask_shape, latent_dim=256):
    latent = keras.Input(shape=(latent_dim))
    mask = keras.Input(shape=mask_shape)
    x = layers.Dense(16384)(latent)
    x = layers.Reshape((4, 4, 1024))(x)
    x = ResBlock(filters=1024)(x, mask)
    x = layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=1024)(x, mask)
    x = layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=1024)(x, mask)
    x = layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=512)(x, mask)
    x = layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=256)(x, mask)
    x = layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=128)(x, mask)
    x = layers.UpSampling2D((2, 2))(x)
    x = tf.nn.leaky_relu(x, 0.2)
    output_image = tf.nn.tanh(layers.Conv2D(3, 4, padding="same")(x))
    return keras.Model([latent, mask], output_image, name="generator")

discriminator はセグメンテーション・マップと画像を取り、それらを連結します。そして連結された画像のパッチが本物か偽物かを予測します。

def build_discriminator(image_shape, downsample_factor=64):
    input_image_A = keras.Input(shape=image_shape, name="discriminator_image_A")
    input_image_B = keras.Input(shape=image_shape, name="discriminator_image_B")
    x = layers.Concatenate()([input_image_A, input_image_B])
    x1 = downsample(downsample_factor, 4, apply_norm=False)(x)
    x2 = downsample(2 * downsample_factor, 4)(x1)
    x3 = downsample(4 * downsample_factor, 4)(x2)
    x4 = downsample(8 * downsample_factor, 4, strides=1)(x3)
    x5 = layers.Conv2D(1, 4)(x4)
    outputs = [x1, x2, x3, x4, x5]
    return keras.Model([input_image_A, input_image_B], outputs)

 

損失関数

GauGAN は以下の損失関数を使用します :

  • Generator:
    • discriminator 予測に渡る期待値。
    • KL divergence エンコーダにより予測された平均と分散を学習するため
    • generator の特徴空間をアラインするためのオリジナルと生成画像の discriminator 予測間の最小化。
    • Perceptual loss for encouraging the generated images to have perceptual quality.

  • Discriminator:
    • def generator_loss(y):
          return -tf.reduce_mean(y)
      
      
      def kl_divergence_loss(mean, variance):
          return -0.5 * tf.reduce_sum(1 + variance - tf.square(mean) - tf.exp(variance))
      
      
      class FeatureMatchingLoss(keras.losses.Loss):
          def __init__(self, **kwargs):
              super().__init__(**kwargs)
              self.mae = keras.losses.MeanAbsoluteError()
      
          def call(self, y_true, y_pred):
              loss = 0
              for i in range(len(y_true) - 1):
                  loss += self.mae(y_true[i], y_pred[i])
              return loss
      
      
      class VGGFeatureMatchingLoss(keras.losses.Loss):
          def __init__(self, **kwargs):
              super().__init__(**kwargs)
              self.encoder_layers = [
                  "block1_conv1",
                  "block2_conv1",
                  "block3_conv1",
                  "block4_conv1",
                  "block5_conv1",
              ]
              self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
              vgg = keras.applications.VGG19(include_top=False, weights="imagenet")
              layer_outputs = [vgg.get_layer(x).output for x in self.encoder_layers]
              self.vgg_model = keras.Model(vgg.input, layer_outputs, name="VGG")
              self.mae = keras.losses.MeanAbsoluteError()
      
          def call(self, y_true, y_pred):
              y_true = keras.applications.vgg19.preprocess_input(127.5 * (y_true + 1))
              y_pred = keras.applications.vgg19.preprocess_input(127.5 * (y_pred + 1))
              real_features = self.vgg_model(y_true)
              fake_features = self.vgg_model(y_pred)
              loss = 0
              for i in range(len(real_features)):
                  loss += self.weights[i] * self.mae(real_features[i], fake_features[i])
              return loss
      
      
      class DiscriminatorLoss(keras.losses.Loss):
          def __init__(self, **kwargs):
              super().__init__(**kwargs)
              self.hinge_loss = keras.losses.Hinge()
      
          def call(self, y, is_real):
              label = 1.0 if is_real else -1.0
              return self.hinge_loss(label, y)
      
    • Hinge loss

 

GAN モニタ・コールバック

次に、訓練の間に GauGAN の結果をモニタするコールバックを実装します。

class GanMonitor(keras.callbacks.Callback):
    def __init__(self, val_dataset, n_samples, epoch_interval=5):
        self.val_images = next(iter(val_dataset))
        self.n_samples = n_samples
        self.epoch_interval = epoch_interval

    def infer(self):
        latent_vector = tf.random.normal(
            shape=(self.model.batch_size, self.model.latent_dim), mean=0.0, stddev=2.0
        )
        return self.model.predict([latent_vector, self.val_images[2]])

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.epoch_interval == 0:
            generated_images = self.infer()
            for _ in range(self.n_samples):
                grid_row = min(generated_images.shape[0], 3)
                f, axarr = plt.subplots(grid_row, 3, figsize=(18, grid_row * 6))
                for row in range(grid_row):
                    ax = axarr if grid_row == 1 else axarr[row]
                    ax[0].imshow((self.val_images[0][row] + 1) / 2)
                    ax[0].axis("off")
                    ax[0].set_title("Mask", fontsize=20)
                    ax[1].imshow((self.val_images[1][row] + 1) / 2)
                    ax[1].axis("off")
                    ax[1].set_title("Ground Truth", fontsize=20)
                    ax[2].imshow((generated_images[row] + 1) / 2)
                    ax[2].axis("off")
                    ax[2].set_title("Generated", fontsize=20)
                plt.show()

 

サブクラス化された GauGAN モデル

最後に、train_step() メソッドをオーバライドして総てを (from tf.keras.Model から) サブクラス化されたモデルにまとめます。

class GauGAN(keras.Model):
    def __init__(
        self,
        image_size,
        num_classes,
        batch_size,
        latent_dim,
        feature_loss_coeff=10,
        vgg_feature_loss_coeff=0.1,
        kl_divergence_loss_coeff=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.latent_dim = latent_dim
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.image_shape = (image_size, image_size, 3)
        self.mask_shape = (image_size, image_size, num_classes)
        self.feature_loss_coeff = feature_loss_coeff
        self.vgg_feature_loss_coeff = vgg_feature_loss_coeff
        self.kl_divergence_loss_coeff = kl_divergence_loss_coeff

        self.discriminator = build_discriminator(self.image_shape)
        self.generator = build_generator(self.mask_shape)
        self.encoder = build_encoder(self.image_shape)
        self.sampler = GaussianSampler(batch_size, latent_dim)
        self.patch_size, self.combined_model = self.build_combined_generator()

        self.disc_loss_tracker = tf.keras.metrics.Mean(name="disc_loss")
        self.gen_loss_tracker = tf.keras.metrics.Mean(name="gen_loss")
        self.feat_loss_tracker = tf.keras.metrics.Mean(name="feat_loss")
        self.vgg_loss_tracker = tf.keras.metrics.Mean(name="vgg_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.disc_loss_tracker,
            self.gen_loss_tracker,
            self.feat_loss_tracker,
            self.vgg_loss_tracker,
            self.kl_loss_tracker,
        ]

    def build_combined_generator(self):
        # This method builds a model that takes as inputs the following:
        # latent vector, one-hot encoded segmentation label map, and
        # a segmentation map. It then (i) generates an image with the generator,
        # (ii) passes the generated images and segmentation map to the discriminator.
        # Finally, the model produces the following outputs: (a) discriminator outputs,
        # (b) generated image.
        # We will be using this model to simplify the implementation.
        self.discriminator.trainable = False
        mask_input = keras.Input(shape=self.mask_shape, name="mask")
        image_input = keras.Input(shape=self.image_shape, name="image")
        latent_input = keras.Input(shape=(self.latent_dim), name="latent")
        generated_image = self.generator([latent_input, mask_input])
        discriminator_output = self.discriminator([image_input, generated_image])
        patch_size = discriminator_output[-1].shape[1]
        combined_model = keras.Model(
            [latent_input, mask_input, image_input],
            [discriminator_output, generated_image],
        )
        return patch_size, combined_model

    def compile(self, gen_lr=1e-4, disc_lr=4e-4, **kwargs):
        super().compile(**kwargs)
        self.generator_optimizer = keras.optimizers.Adam(
            gen_lr, beta_1=0.0, beta_2=0.999
        )
        self.discriminator_optimizer = keras.optimizers.Adam(
            disc_lr, beta_1=0.0, beta_2=0.999
        )
        self.discriminator_loss = DiscriminatorLoss()
        self.feature_matching_loss = FeatureMatchingLoss()
        self.vgg_loss = VGGFeatureMatchingLoss()

    def train_discriminator(self, latent_vector, segmentation_map, real_image, labels):
        fake_images = self.generator([latent_vector, labels])
        with tf.GradientTape() as gradient_tape:
            pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
            pred_real = self.discriminator([segmentation_map, real_image])[-1]
            loss_fake = self.discriminator_loss(pred_fake, False)
            loss_real = self.discriminator_loss(pred_real, True)
            total_loss = 0.5 * (loss_fake + loss_real)

        self.discriminator.trainable = True
        gradients = gradient_tape.gradient(
            total_loss, self.discriminator.trainable_variables
        )
        self.discriminator_optimizer.apply_gradients(
            zip(gradients, self.discriminator.trainable_variables)
        )
        return total_loss

    def train_generator(
        self, latent_vector, segmentation_map, labels, image, mean, variance
    ):
        # Generator learns through the signal provided by the discriminator. During
        # backpropagation, we only update the generator parameters.
        self.discriminator.trainable = False
        with tf.GradientTape() as tape:
            real_d_output = self.discriminator([segmentation_map, image])
            fake_d_output, fake_image = self.combined_model(
                [latent_vector, labels, segmentation_map]
            )
            pred = fake_d_output[-1]

            # Compute generator losses.
            g_loss = generator_loss(pred)
            kl_loss = self.kl_divergence_loss_coeff * kl_divergence_loss(mean, variance)
            vgg_loss = self.vgg_feature_loss_coeff * self.vgg_loss(image, fake_image)
            feature_loss = self.feature_loss_coeff * self.feature_matching_loss(
                real_d_output, fake_d_output
            )
            total_loss = g_loss + kl_loss + vgg_loss + feature_loss

        all_trainable_variables = (
            self.combined_model.trainable_variables + self.encoder.trainable_variables
        )

        gradients = tape.gradient(total_loss, all_trainable_variables)
        self.generator_optimizer.apply_gradients(
            zip(gradients, all_trainable_variables)
        )
        return total_loss, feature_loss, vgg_loss, kl_loss

    def train_step(self, data):
        segmentation_map, image, labels = data
        mean, variance = self.encoder(image)
        latent_vector = self.sampler([mean, variance])
        discriminator_loss = self.train_discriminator(
            latent_vector, segmentation_map, image, labels
        )
        (generator_loss, feature_loss, vgg_loss, kl_loss) = self.train_generator(
            latent_vector, segmentation_map, labels, image, mean, variance
        )

        # Report progress.
        self.disc_loss_tracker.update_state(discriminator_loss)
        self.gen_loss_tracker.update_state(generator_loss)
        self.feat_loss_tracker.update_state(feature_loss)
        self.vgg_loss_tracker.update_state(vgg_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        results = {m.name: m.result() for m in self.metrics}
        return results

    def test_step(self, data):
        segmentation_map, image, labels = data
        # Obtain the learned moments of the real image distribution.
        mean, variance = self.encoder(image)

        # Sample a latent from the distribution defined by the learned moments.
        latent_vector = self.sampler([mean, variance])

        # Generate the fake images.
        fake_images = self.generator([latent_vector, labels])

        # Calculate the losses.
        pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
        pred_real = self.discriminator([segmentation_map, image])[-1]
        loss_fake = self.discriminator_loss(pred_fake, False)
        loss_real = self.discriminator_loss(pred_real, True)
        total_discriminator_loss = 0.5 * (loss_fake + loss_real)
        real_d_output = self.discriminator([segmentation_map, image])
        fake_d_output, fake_image = self.combined_model(
            [latent_vector, labels, segmentation_map]
        )
        pred = fake_d_output[-1]
        g_loss = generator_loss(pred)
        kl_loss = self.kl_divergence_loss_coeff * kl_divergence_loss(mean, variance)
        vgg_loss = self.vgg_feature_loss_coeff * self.vgg_loss(image, fake_image)
        feature_loss = self.feature_loss_coeff * self.feature_matching_loss(
            real_d_output, fake_d_output
        )
        total_generator_loss = g_loss + kl_loss + vgg_loss + feature_loss

        # Report progress.
        self.disc_loss_tracker.update_state(total_discriminator_loss)
        self.gen_loss_tracker.update_state(total_generator_loss)
        self.feat_loss_tracker.update_state(feature_loss)
        self.vgg_loss_tracker.update_state(vgg_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        results = {m.name: m.result() for m in self.metrics}
        return results

    def call(self, inputs):
        latent_vectors, labels = inputs
        return self.generator([latent_vectors, labels])

 

GauGAN 訓練

gaugan = GauGAN(IMG_HEIGHT, NUM_CLASSES, BATCH_SIZE, latent_dim=256)
gaugan.compile()
history = gaugan.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=15,
    callbacks=[GanMonitor(val_dataset, BATCH_SIZE)],
)


def plot_history(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_history("disc_loss")
plot_history("gen_loss")
plot_history("feat_loss")
plot_history("vgg_loss")
plot_history("kl_loss")
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5
80142336/80134624 [==============================] - 0s 0us/step
80150528/80134624 [==============================] - 0s 0us/step
Epoch 1/15
WARNING:tensorflow:Gradients do not exist for variables ['conv2d_6/kernel:0', 'conv2d_7/kernel:0', 'instance_normalization_3/gamma:0', 'instance_normalization_3/beta:0', 'conv2d_8/kernel:0', 'instance_normalization_4/gamma:0', 'instance_normalization_4/beta:0', 'conv2d_9/kernel:0', 'instance_normalization_5/gamma:0', 'instance_normalization_5/beta:0', 'conv2d_10/kernel:0', 'instance_normalization_6/gamma:0', 'instance_normalization_6/beta:0', 'mean/kernel:0', 'mean/bias:0', 'variance/kernel:0', 'variance/bias:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
WARNING:tensorflow:Gradients do not exist for variables ['conv2d_6/kernel:0', 'conv2d_7/kernel:0', 'instance_normalization_3/gamma:0', 'instance_normalization_3/beta:0', 'conv2d_8/kernel:0', 'instance_normalization_4/gamma:0', 'instance_normalization_4/beta:0', 'conv2d_9/kernel:0', 'instance_normalization_5/gamma:0', 'instance_normalization_5/beta:0', 'conv2d_10/kernel:0', 'instance_normalization_6/gamma:0', 'instance_normalization_6/beta:0', 'mean/kernel:0', 'mean/bias:0', 'variance/kernel:0', 'variance/bias:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
75/75 [==============================] - ETA: 0s - disc_loss: 1.1359 - gen_loss: 114.6762 - feat_loss: 9.6107 - vgg_loss: 17.5540 - kl_loss: 87.3495

75/75 [==============================] - 69s 620ms/step - disc_loss: 1.1359 - gen_loss: 114.6762 - feat_loss: 9.6107 - vgg_loss: 17.5540 - kl_loss: 87.3495 - val_disc_loss: 0.9339 - val_gen_loss: 115.9361 - val_feat_loss: 12.0053 - val_vgg_loss: 17.6576 - val_kl_loss: 86.4241
Epoch 2/15
75/75 [==============================] - 39s 522ms/step - disc_loss: 0.9191 - gen_loss: 115.4342 - feat_loss: 11.1384 - vgg_loss: 16.6970 - kl_loss: 87.1299 - val_disc_loss: 0.9073 - val_gen_loss: 117.0431 - val_feat_loss: 10.9293 - val_vgg_loss: 17.3108 - val_kl_loss: 87.1910
Epoch 3/15
75/75 [==============================] - 40s 530ms/step - disc_loss: 0.7783 - gen_loss: 116.0476 - feat_loss: 11.2018 - vgg_loss: 16.4767 - kl_loss: 87.7456 - val_disc_loss: 0.7877 - val_gen_loss: 115.6750 - val_feat_loss: 11.3406 - val_vgg_loss: 16.9246 - val_kl_loss: 87.8907
Epoch 4/15
75/75 [==============================] - 39s 521ms/step - disc_loss: 0.6915 - gen_loss: 115.8905 - feat_loss: 10.7578 - vgg_loss: 16.3213 - kl_loss: 88.0270 - val_disc_loss: 0.7651 - val_gen_loss: 115.5427 - val_feat_loss: 11.5930 - val_vgg_loss: 17.0086 - val_kl_loss: 87.3675
Epoch 5/15
75/75 [==============================] - 39s 521ms/step - disc_loss: 0.6652 - gen_loss: 115.3557 - feat_loss: 10.7736 - vgg_loss: 16.3333 - kl_loss: 87.4493 - val_disc_loss: 0.9139 - val_gen_loss: 115.3157 - val_feat_loss: 11.3612 - val_vgg_loss: 17.0591 - val_kl_loss: 87.6537
Epoch 6/15
75/75 [==============================] - ETA: 0s - disc_loss: 0.6541 - gen_loss: 115.2529 - feat_loss: 10.6386 - vgg_loss: 16.2342 - kl_loss: 87.5053

 

75/75 [==============================] - 43s 573ms/step - disc_loss: 0.6541 - gen_loss: 115.2529 - feat_loss: 10.6386 - vgg_loss: 16.2342 - kl_loss: 87.5053 - val_disc_loss: 0.3999 - val_gen_loss: 116.1638 - val_feat_loss: 11.1031 - val_vgg_loss: 16.9759 - val_kl_loss: 87.5684
Epoch 7/15
75/75 [==============================] - 40s 530ms/step - disc_loss: 0.6029 - gen_loss: 115.3866 - feat_loss: 10.6807 - vgg_loss: 16.2025 - kl_loss: 87.5448 - val_disc_loss: 0.4571 - val_gen_loss: 117.5491 - val_feat_loss: 10.7403 - val_vgg_loss: 16.8212 - val_kl_loss: 88.4036
Epoch 8/15
75/75 [==============================] - 39s 522ms/step - disc_loss: 0.5798 - gen_loss: 115.1903 - feat_loss: 10.5906 - vgg_loss: 16.1720 - kl_loss: 87.4315 - val_disc_loss: 0.4470 - val_gen_loss: 114.3039 - val_feat_loss: 10.8104 - val_vgg_loss: 16.9426 - val_kl_loss: 86.1938
Epoch 9/15
75/75 [==============================] - 39s 521ms/step - disc_loss: 0.5412 - gen_loss: 115.6245 - feat_loss: 10.5598 - vgg_loss: 16.1985 - kl_loss: 87.8032 - val_disc_loss: 0.3365 - val_gen_loss: 116.6229 - val_feat_loss: 10.9437 - val_vgg_loss: 16.9026 - val_kl_loss: 87.6305
Epoch 10/15
75/75 [==============================] - 39s 521ms/step - disc_loss: 0.5822 - gen_loss: 115.2743 - feat_loss: 10.5127 - vgg_loss: 16.1609 - kl_loss: 87.5487 - val_disc_loss: 0.4711 - val_gen_loss: 116.2798 - val_feat_loss: 11.4271 - val_vgg_loss: 16.7643 - val_kl_loss: 87.5068
Epoch 11/15
75/75 [==============================] - ETA: 0s - disc_loss: 0.5588 - gen_loss: 115.1784 - feat_loss: 10.5053 - vgg_loss: 16.0580 - kl_loss: 87.5601

 

75/75 [==============================] - 43s 571ms/step - disc_loss: 0.5588 - gen_loss: 115.1784 - feat_loss: 10.5053 - vgg_loss: 16.0580 - kl_loss: 87.5601 - val_disc_loss: 0.7086 - val_gen_loss: 116.6372 - val_feat_loss: 11.6103 - val_vgg_loss: 16.9735 - val_kl_loss: 88.3342
Epoch 12/15
75/75 [==============================] - 39s 521ms/step - disc_loss: 0.5461 - gen_loss: 115.2417 - feat_loss: 10.5498 - vgg_loss: 16.0916 - kl_loss: 87.5209 - val_disc_loss: 0.5056 - val_gen_loss: 116.8908 - val_feat_loss: 10.3318 - val_vgg_loss: 16.8074 - val_kl_loss: 87.9361
Epoch 13/15
75/75 [==============================] - 39s 521ms/step - disc_loss: 0.5482 - gen_loss: 114.9858 - feat_loss: 10.3671 - vgg_loss: 16.0029 - kl_loss: 87.5513 - val_disc_loss: 0.5819 - val_gen_loss: 116.6370 - val_feat_loss: 11.3866 - val_vgg_loss: 17.0710 - val_kl_loss: 88.0975
Epoch 14/15
75/75 [==============================] - 39s 522ms/step - disc_loss: 0.5596 - gen_loss: 114.5251 - feat_loss: 10.3841 - vgg_loss: 16.0361 - kl_loss: 86.9755 - val_disc_loss: 0.4472 - val_gen_loss: 115.9854 - val_feat_loss: 10.2750 - val_vgg_loss: 16.9934 - val_kl_loss: 87.2017
Epoch 15/15
75/75 [==============================] - 39s 521ms/step - disc_loss: 0.5417 - gen_loss: 114.3627 - feat_loss: 10.1977 - vgg_loss: 15.9871 - kl_loss: 87.1026 - val_disc_loss: 0.4202 - val_gen_loss: 115.1249 - val_feat_loss: 10.2537 - val_vgg_loss: 16.8011 - val_kl_loss: 86.5379

 

推論

val_iterator = iter(val_dataset)

for _ in range(5):
    val_images = next(val_iterator)
    # Sample latent from a normal distribution.
    latent_vector = tf.random.normal(
        shape=(gaugan.batch_size, gaugan.latent_dim), mean=0.0, stddev=2.0
    )
    # Generate fake images.
    fake_images = gaugan.predict([latent_vector, val_images[2]])

    real_images = val_images
    grid_row = min(fake_images.shape[0], 3)
    grid_col = 3
    f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col * 6, grid_row * 6))
    for row in range(grid_row):
        ax = axarr if grid_row == 1 else axarr[row]
        ax[0].imshow((real_images[0][row] + 1) / 2)
        ax[0].axis("off")
        ax[0].set_title("Mask", fontsize=20)
        ax[1].imshow((real_images[1][row] + 1) / 2)
        ax[1].axis("off")
        ax[1].set_title("Ground Truth", fontsize=20)
        ax[2].imshow((fake_images[row] + 1) / 2)
        ax[2].axis("off")
        ax[2].set_title("Generated", fontsize=20)
    plt.show()

 

終わりに

  • この例で使用したデータセットは小さいものです。より良い結果を得るためにはより大きなデータセットを使用することを勧めます。GauGAN の結果は COCO-Stuff と CityScapes データセットで実演されました。

  • このサンプルは Soon-Yau Cheong による Hands-On Image Generation with TensorFlow の 6 章と Divyansh Jha による Implementing SPADE using fastai にインスパイアされました。

  • If you found this example interesting and exciting, you might want to check out our repository which we are currently building. It will include reimplementations of popular GANs and pretrained models. Our focus will be on readibility and making the code as accessible as possible. Our plan is to first train our implementation of GauGAN (following the code of this example) on a bigger dataset and then make the repository public. We welcome contributions!

  • Recently GauGAN2 was also released. You can check it out here.

 

以上



クラスキャット

最近の投稿

  • LangGraph Platform : Get started : クイックスタート
  • LangGraph Platform : 概要
  • LangGraph : Prebuilt エージェント : ユーザインターフェイス
  • LangGraph : Prebuilt エージェント : 配備
  • LangGraph : Prebuilt エージェント : マルチエージェント

タグ

AutoGen (13) ClassCat Press Release (20) ClassCat TF/ONNX Hub (11) DGL 0.5 (14) Eager Execution (7) Edward (17) FLUX.1 (16) Gemini (20) HuggingFace Transformers 4.5 (10) HuggingFace Transformers 4.6 (7) HuggingFace Transformers 4.29 (9) Keras 2 Examples (98) Keras 2 Guide (16) Keras 3 (10) Keras Release Note (17) Kubeflow 1.0 (10) LangChain (45) LangGraph (20) MediaPipe 0.8 (11) Model Context Protocol (16) NNI 1.5 (16) OpenAI Agents SDK (8) OpenAI Cookbook (13) OpenAI platform (10) OpenAI platform 1.x (10) OpenAI ヘルプ (8) TensorFlow 2.0 Advanced Tutorials (33) TensorFlow 2.0 Advanced Tutorials (Alpha) (15) TensorFlow 2.0 Advanced Tutorials (Beta) (16) TensorFlow 2.0 Guide (10) TensorFlow 2.0 Guide (Alpha) (16) TensorFlow 2.0 Guide (Beta) (9) TensorFlow 2.0 Release Note (12) TensorFlow 2.0 Tutorials (20) TensorFlow 2.0 Tutorials (Alpha) (14) TensorFlow 2.0 Tutorials (Beta) (12) TensorFlow 2.4 Guide (24) TensorFlow Deploy (8) TensorFlow Get Started (7) TensorFlow Graphics (7) TensorFlow Probability (9) TensorFlow Programmer's Guide (22) TensorFlow Release Note (18) TensorFlow Tutorials (33) TF-Agents 0.4 (11)
2022年7月
月 火 水 木 金 土 日
 123
45678910
11121314151617
18192021222324
25262728293031
« 6月   8月 »
© 2025 ClasCat® AI Research | Powered by Minimalist Blog WordPress Theme