Skip to content

ClasCat® AI Research

クラスキャット – 生成 AI, AI エージェント, MCP

Menu
  • ホーム
    • ClassCat® AI Research ホーム
    • クラスキャット・ホーム
  • OpenAI API
    • OpenAI Python ライブラリ 1.x : 概要
    • OpenAI ブログ
      • GPT の紹介
      • GPT ストアの紹介
      • ChatGPT Team の紹介
    • OpenAI platform 1.x
      • Get Started : イントロダクション
      • Get Started : クイックスタート (Python)
      • Get Started : クイックスタート (Node.js)
      • Get Started : モデル
      • 機能 : 埋め込み
      • 機能 : 埋め込み (ユースケース)
      • ChatGPT : アクション – イントロダクション
      • ChatGPT : アクション – Getting started
      • ChatGPT : アクション – アクション認証
    • OpenAI ヘルプ : ChatGPT
      • ChatGPTとは何ですか?
      • ChatGPT は真実を語っていますか?
      • GPT の作成
      • GPT FAQ
      • GPT vs アシスタント
      • GPT ビルダー
    • OpenAI ヘルプ : ChatGPT > メモリ
      • FAQ
    • OpenAI ヘルプ : GPT ストア
      • 貴方の GPT をフィーチャーする
    • OpenAI Python ライブラリ 0.27 : 概要
    • OpenAI platform
      • Get Started : イントロダクション
      • Get Started : クイックスタート
      • Get Started : モデル
      • ガイド : GPT モデル
      • ガイド : 画像生成 (DALL·E)
      • ガイド : GPT-3.5 Turbo 対応 微調整
      • ガイド : 微調整 1.イントロダクション
      • ガイド : 微調整 2. データセットの準備 / ケーススタディ
      • ガイド : 埋め込み
      • ガイド : 音声テキスト変換
      • ガイド : モデレーション
      • ChatGPT プラグイン : イントロダクション
    • OpenAI Cookbook
      • 概要
      • API 使用方法 : レート制限の操作
      • API 使用方法 : tiktoken でトークンを数える方法
      • GPT : ChatGPT モデルへの入力をフォーマットする方法
      • GPT : 補完をストリームする方法
      • GPT : 大規模言語モデルを扱う方法
      • 埋め込み : 埋め込みの取得
      • GPT-3 の微調整 : 分類サンプルの微調整
      • DALL-E : DALL·E で 画像を生成して編集する方法
      • DALL·E と Segment Anything で動的マスクを作成する方法
      • Whisper プロンプティング・ガイド
  • Gemini API
    • Tutorials : クイックスタート with Python (1) テキスト-to-テキスト生成
    • (2) マルチモーダル入力 / 日本語チャット
    • (3) 埋め込みの使用
    • (4) 高度なユースケース
    • クイックスタート with Node.js
    • クイックスタート with Dart or Flutter (1) 日本語動作確認
    • Gemma
      • 概要 (README)
      • Tutorials : サンプリング
      • Tutorials : KerasNLP による Getting Started
  • Keras 3
    • 新しいマルチバックエンド Keras
    • Keras 3 について
    • Getting Started : エンジニアのための Keras 入門
    • Google Colab 上のインストールと Stable Diffusion デモ
    • コンピュータビジョン – ゼロからの画像分類
    • コンピュータビジョン – 単純な MNIST convnet
    • コンピュータビジョン – EfficientNet を使用した微調整による画像分類
    • コンピュータビジョン – Vision Transformer による画像分類
    • コンピュータビジョン – 最新の MLPモデルによる画像分類
    • コンピュータビジョン – コンパクトな畳込み Transformer
    • Keras Core
      • Keras Core 0.1
        • 新しいマルチバックエンド Keras (README)
        • Keras for TensorFlow, JAX, & PyTorch
        • 開発者ガイド : Getting started with Keras Core
        • 開発者ガイド : 関数型 API
        • 開発者ガイド : シーケンシャル・モデル
        • 開発者ガイド : サブクラス化で新しい層とモデルを作成する
        • 開発者ガイド : 独自のコールバックを書く
      • Keras Core 0.1.1 & 0.1.2 : リリースノート
      • 開発者ガイド
      • Code examples
      • Keras Stable Diffusion
        • 概要
        • 基本的な使い方 (テキスト-to-画像 / 画像-to-画像変換)
        • 混合精度のパフォーマンス
        • インペインティングの簡易アプリケーション
        • (参考) KerasCV – Stable Diffusion を使用した高性能画像生成
  • TensorFlow
    • TF 2 : 初級チュートリアル
    • TF 2 : 上級チュートリアル
    • TF 2 : ガイド
    • TF 1 : チュートリアル
    • TF 1 : ガイド
  • その他
    • 🦜️🔗 LangChain ドキュメント / ユースケース
    • Stable Diffusion WebUI
      • Google Colab で Stable Diffusion WebUI 入門
      • HuggingFace モデル / VAE の導入
      • LoRA の利用
    • Diffusion Models / 拡散モデル
  • クラスキャット
    • 会社案内
    • お問合せ
    • Facebook
    • ClassCat® Blog
Menu

Keras 2 : examples : コンピュータビジョン – AdaMatch による半教師あり学習とドメイン適応

Posted on 11/07/202111/12/2021 by Sales Information

Keras 2 : examples : AdaMatch による半教師あり学習とドメイン適応 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 11/07/2021 (keras 2.6.0)

* 本ページは、Keras の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

  • Code examples : Computer Vision : Semi-supervision and domain adaptation with AdaMatch (Author: Sayak Paul)

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス ★ 無料 Web セミナー開催中 ★

◆ クラスキャットは人工知能・テレワークに関する各種サービスを提供しております。お気軽にご相談ください :

  • 人工知能研究開発支援
    1. 人工知能研修サービス(経営者層向けオンサイト研修)
    2. テクニカルコンサルティングサービス
    3. 実証実験(プロトタイプ構築)
    4. アプリケーションへの実装

  • 人工知能研修サービス

  • PoC(概念実証)を失敗させないための支援

  • テレワーク & オンライン授業を支援
◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。

◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com  ;  WebSite: https://www.classcat.com/  ;  Facebook

 

 

Keras 2 : examples : AdaMatch による半教師あり学習とドメイン適応

イントロダクション

このサンプルでは、AdaMatch: A Unified Approach to Semi-Supervised Learning and Domain Adaptation (AdaMatch: 半教師あり学習とドメイン適応への統一的なアプローチ) by Berthelot et al で提案された、AdaMatch アルゴリズムを実装します。それは教師なしドメイン適応の新しい最先端技術です (2021/6 の時点で)。AdaMatch は特に興味深いです、何故ならばそれは一つのフレームワークの下で半教師あり学習 (SSL) と教師なしドメイン適応 (UDA) を統合するからです。従って半教師ありドメイン適応 (SSDA) を実行する方法を提供します。

このサンプルは TensorFlow 2.5 またはそれ以上、及び TensorFlow Models を必要とします、これは次のコマンドでインストールできます :

!pip install -q tf-models-official

先に進む前に、このサンプルの基礎となる幾つかの予備的なコンセプトをレビューしましょう。

 

準備

半教師あり学習 (SSL) では、より大きなラベル付けられていないデータセット上のモデルを訓練するために少量のラベル付きデータを使用します。コンピュータビジョンのためのポピュラーな半教師あり学習法は FixMatch, MixMatch, Noisy Student Training 等を含みます。標準的な SSL ワークフローがどのようなものか考えを得るために このサンプル を参考にできます。

教師なしドメイン適応 では、ソースのラベル付けられたデータセットとターゲットのラベル付けされていないデータセットへのアクセスを持ちます。そしてタスクはターゲット・データセットに上手く一般化できるモデルを学習することです。ソースとターゲット・データセットは分布の観点から変化します。次の図はこの考えの図示しています。現在のサンプルでは、ソース・データセットとして MNIST を、ターゲット・データセットとして SVHN を使用しています、これは家のナンバーの画像から成ります。両者のデータセットはテクスチャ, 視点, 外観 等の観点から様々な変化する要因を持ちます : それらのドメイン、分布は互いに異なります。

深層学習のポピュラーなドメイン適応アルゴリズムは Deep CORAL, Moment Matching 等を含みます。

 

セットアップ

import tensorflow as tf

tf.random.set_seed(42)

import numpy as np

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import regularizers
from official.vision.image_classification.augment import RandAugment

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

 

データの準備

# MNIST
(
    (mnist_x_train, mnist_y_train),
    (mnist_x_test, mnist_y_test),
) = keras.datasets.mnist.load_data()

# Add a channel dimension
mnist_x_train = tf.expand_dims(mnist_x_train, -1)
mnist_x_test = tf.expand_dims(mnist_x_test, -1)

# Convert the labels to one-hot encoded vectors
mnist_y_train = tf.one_hot(mnist_y_train, 10).numpy()

# SVHN
svhn_train, svhn_test = tfds.load(
    "svhn_cropped", split=["train", "test"], as_supervised=True
)

 

定数とハイパーパラメータを定義する

RESIZE_TO = 32

SOURCE_BATCH_SIZE = 64
TARGET_BATCH_SIZE = 3 * SOURCE_BATCH_SIZE  # Reference: Section 3.2
EPOCHS = 10
STEPS_PER_EPOCH = len(mnist_x_train) // SOURCE_BATCH_SIZE
TOTAL_STEPS = EPOCHS * STEPS_PER_EPOCH

AUTO = tf.data.AUTOTUNE
LEARNING_RATE = 0.03

WEIGHT_DECAY = 0.0005
INIT = "he_normal"
DEPTH = 28
WIDTH_MULT = 2

 

データ増強ユティリティ

SSL アルゴリズムの標準的な要素は、 学習モデルに予測に一貫性を持たせるために同じ画像の弱くそして強く増強されたバージョンを供給することです。 強い増強については、RandAugment が標準的な選択です。弱い増強については、水平反転とランダム・クロッピングを使用します。

# Initialize `RandAugment` object with 2 layers of
# augmentation transforms and strength of 5.
augmenter = RandAugment(num_layers=2, magnitude=5)


def weak_augment(image, source=True):
    if image.dtype != tf.float32:
        image = tf.cast(image, tf.float32)

    # MNIST images are grayscale, this is why we first convert them to
    # RGB images.
    if source:
        image = tf.image.resize_with_pad(image, RESIZE_TO, RESIZE_TO)
        image = tf.tile(image, [1, 1, 3])
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_crop(image, (RESIZE_TO, RESIZE_TO, 3))
    return image


def strong_augment(image, source=True):
    if image.dtype != tf.float32:
        image = tf.cast(image, tf.float32)

    if source:
        image = tf.image.resize_with_pad(image, RESIZE_TO, RESIZE_TO)
        image = tf.tile(image, [1, 1, 3])
    image = augmenter.distort(image)
    return image

 

データ・ローディング・ユティリティ

def create_individual_ds(ds, aug_func, source=True):
    if source:
        batch_size = SOURCE_BATCH_SIZE
    else:
        # During training 3x more target unlabeled samples are shown
        # to the model in AdaMatch (Section 3.2 of the paper).
        batch_size = TARGET_BATCH_SIZE
    ds = ds.shuffle(batch_size * 10, seed=42)

    if source:
        ds = ds.map(lambda x, y: (aug_func(x), y), num_parallel_calls=AUTO)
    else:
        ds = ds.map(lambda x, y: (aug_func(x, False), y), num_parallel_calls=AUTO)

    ds = ds.batch(batch_size).prefetch(AUTO)
    return ds

_w と _s のサフィックスはそれぞれ弱いと強いを表します。

source_ds = tf.data.Dataset.from_tensor_slices((mnist_x_train, mnist_y_train))
source_ds_w = create_individual_ds(source_ds, weak_augment)
source_ds_s = create_individual_ds(source_ds, strong_augment)
final_source_ds = tf.data.Dataset.zip((source_ds_w, source_ds_s))

target_ds_w = create_individual_ds(svhn_train, weak_augment, source=False)
target_ds_s = create_individual_ds(svhn_train, strong_augment, source=False)
final_target_ds = tf.data.Dataset.zip((target_ds_w, target_ds_s))

ここにシングル画像バッチがどのように見えるかがあります :

 

損失計算ユティリティ

def compute_loss_source(source_labels, logits_source_w, logits_source_s):
    loss_func = keras.losses.CategoricalCrossentropy(from_logits=True)
    # First compute the losses between original source labels and
    # predictions made on the weakly and strongly augmented versions
    # of the same images.
    w_loss = loss_func(source_labels, logits_source_w)
    s_loss = loss_func(source_labels, logits_source_s)
    return w_loss + s_loss


def compute_loss_target(target_pseudo_labels_w, logits_target_s, mask):
    loss_func = keras.losses.CategoricalCrossentropy(from_logits=True, reduction="none")
    target_pseudo_labels_w = tf.stop_gradient(target_pseudo_labels_w)
    # For calculating loss for the target samples, we treat the pseudo labels
    # as the ground-truth. These are not considered during backpropagation
    # which is a standard SSL practice.
    target_loss = loss_func(target_pseudo_labels_w, logits_target_s)

    # More on `mask` later.
    mask = tf.cast(mask, target_loss.dtype)
    target_loss *= mask
    return tf.reduce_mean(target_loss, 0)

 

AdaMatch 訓練のためのサブクラス化モデル

下図は AdaMatch の全体的なワークフローを表しています (元の論文 からの引用) :

ここにワークフローの簡潔なステップ毎の分解があります :

  1. 最初にソースとターゲットのデータセットから画像の弱いそして強い増強のペアを取得します。

  2. 2 つの連結されたコピーを準備します : i. 両者のペアが連結されたもの。ii. ソースデータ画像ペアだけが連結されたもの。

  3. モデルを通して 2 つの forward パスを実行します : i. 最初の forwrad パスは 2.i から得られた連結されたコピーを使用します。この forward パスでは、Batch Normalization 統計が更新されます。ii. 2 番目の forward パスでは、2.ii で得られた連結されたコピーだけを使用します。Batch Normalization 層は推論モードで実行されます。

  4. 両方の forward パスに対してそれぞれのロジットが計算されます。

  5. ロジットは論文で紹介されている、変換のシリーズを通り抜けます (これを短く説明します)。

  6. 損失を計算して基礎となるモデルの勾配を更新します。
class AdaMatch(keras.Model):
    def __init__(self, model, total_steps, tau=0.9):
        super(AdaMatch, self).__init__()
        self.model = model
        self.tau = tau  # Denotes the confidence threshold
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.total_steps = total_steps
        self.current_step = tf.Variable(0, dtype="int64")

    @property
    def metrics(self):
        return [self.loss_tracker]

    # This is a warmup schedule to update the weight of the
    # loss contributed by the target unlabeled samples. More
    # on this in the text.
    def compute_mu(self):
        pi = tf.constant(np.pi, dtype="float32")
        step = tf.cast(self.current_step, dtype="float32")
        return 0.5 - tf.cos(tf.math.minimum(pi, (2 * pi * step) / self.total_steps)) / 2

    def train_step(self, data):
        ## Unpack and organize the data ##
        source_ds, target_ds = data
        (source_w, source_labels), (source_s, _) = source_ds
        (
            (target_w, _),
            (target_s, _),
        ) = target_ds  # Notice that we are NOT using any labels here.

        combined_images = tf.concat([source_w, source_s, target_w, target_s], 0)
        combined_source = tf.concat([source_w, source_s], 0)

        total_source = tf.shape(combined_source)[0]
        total_target = tf.shape(tf.concat([target_w, target_s], 0))[0]

        with tf.GradientTape() as tape:
            ## Forward passes ##
            combined_logits = self.model(combined_images, training=True)
            z_d_prime_source = self.model(
                combined_source, training=False
            )  # No BatchNorm update.
            z_prime_source = combined_logits[:total_source]

            ## 1. Random logit interpolation for the source images ##
            lambd = tf.random.uniform((total_source, 10), 0, 1)
            final_source_logits = (lambd * z_prime_source) + (
                (1 - lambd) * z_d_prime_source
            )

            ## 2. Distribution alignment (only consider weakly augmented images) ##
            # Compute softmax for logits of the WEAKLY augmented SOURCE images.
            y_hat_source_w = tf.nn.softmax(final_source_logits[: tf.shape(source_w)[0]])

            # Extract logits for the WEAKLY augmented TARGET images and compute softmax.
            logits_target = combined_logits[total_source:]
            logits_target_w = logits_target[: tf.shape(target_w)[0]]
            y_hat_target_w = tf.nn.softmax(logits_target_w)

            # Align the target label distribution to that of the source.
            expectation_ratio = tf.reduce_mean(y_hat_source_w) / tf.reduce_mean(
                y_hat_target_w
            )
            y_tilde_target_w = tf.math.l2_normalize(
                y_hat_target_w * expectation_ratio, 1
            )

            ## 3. Relative confidence thresholding ##
            row_wise_max = tf.reduce_max(y_hat_source_w, axis=-1)
            final_sum = tf.reduce_mean(row_wise_max, 0)
            c_tau = self.tau * final_sum
            mask = tf.reduce_max(y_tilde_target_w, axis=-1) >= c_tau

            ## Compute losses (pay attention to the indexing) ##
            source_loss = compute_loss_source(
                source_labels,
                final_source_logits[: tf.shape(source_w)[0]],
                final_source_logits[tf.shape(source_w)[0] :],
            )
            target_loss = compute_loss_target(
                y_tilde_target_w, logits_target[tf.shape(target_w)[0] :], mask
            )

            t = self.compute_mu()  # Compute weight for the target loss
            total_loss = source_loss + (t * target_loss)
            self.current_step.assign_add(
                1
            )  # Update current training step for the scheduler

        gradients = tape.gradient(total_loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        self.loss_tracker.update_state(total_loss)
        return {"loss": self.loss_tracker.result()}

著者は論文で 3 つの改良を紹介しています :

  • AdaMatch では、2 つの forward パスを実行してそれらの 1 つだけが Batch Normalization 統計の更新について責任を負います。これはターゲットデータセットにおける分布シフトを説明するために行なわれます。他方の forward パスでは、ソースサンプルだけを使用して、Batch Normalization 層は推論モードで実行されます。これら 2 つのパスからのソースサンプル (弱いそして強い増強バージョン) のためのロジットは Batch Normalization がどのように実行されるかにより互いに少し異なります。ソースサンプルのための最終的なロジットはこれら 2 つの異なるロジットペアの間の線形補間により計算されます。これは一貫性正則化の形式を誘導します。このステップは ランダム・ロジット補間 と呼ばれています。

  • 分布アラインメント (= Distribution alignment) がソースとターゲット・ラベル分布を揃える (= align) ために使用されます。これは基礎的なモデルがドメイン不変な表現を学習するのに更に役立ちます。教師なしドメイン適応の場合、ターゲットデータセットの任意のラベルへのアクセスを持ちません。これが疑似ラベルが基礎的なモデルから生成される理由です。

  • 基礎となるモデルはターゲットサンプルのために疑似ラベルを生成します。モデルが不完全な予測を作成することは可能性は高いです。それらは訓練が進むにつれて逆伝播し、全体のパフォーマンスに悪影響を与える可能性があります。それを補うため、閾値に基づいて高い信頼度の予測をフィルタリングします (そのため compute_loss_target() 内でマスクを使用しています)。AdaMatch では、この閾値は相対的に調整されますので、それが relative confidence thresholding (相対的信頼度閾値) と呼ばれる理由です。

これらの方法の詳細とそれらの各々がどのように寄与するかを知るには、論文 を参照してください。

About compute_mu():

AdaMatch では固定スカラー量を使用するのではなく、変換するスカラーが使用されます。それはターゲットサンプルにより寄与される損失の重みを表します。視覚的には、重みスケジューラは次のようなものです :

 

Wide-ResNet-28-2 のインスタンス化

著者はこのサンプルで使用するデータセット・ペアのために WideResNet-28-2 を使用しています。以下のコードの殆どは このスクリプト から参照されています。次のモデルはその内部にピクセル値を [0, 1] にスケールするスケーリング層を持つことに注意してください。

def wide_basic(x, n_input_plane, n_output_plane, stride):
    conv_params = [[3, 3, stride, "same"], [3, 3, (1, 1), "same"]]

    n_bottleneck_plane = n_output_plane

    # Residual block
    for i, v in enumerate(conv_params):
        if i == 0:
            if n_input_plane != n_output_plane:
                x = layers.BatchNormalization()(x)
                x = layers.Activation("relu")(x)
                convs = x
            else:
                convs = layers.BatchNormalization()(x)
                convs = layers.Activation("relu")(convs)
            convs = layers.Conv2D(
                n_bottleneck_plane,
                (v[0], v[1]),
                strides=v[2],
                padding=v[3],
                kernel_initializer=INIT,
                kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
                use_bias=False,
            )(convs)
        else:
            convs = layers.BatchNormalization()(convs)
            convs = layers.Activation("relu")(convs)
            convs = layers.Conv2D(
                n_bottleneck_plane,
                (v[0], v[1]),
                strides=v[2],
                padding=v[3],
                kernel_initializer=INIT,
                kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
                use_bias=False,
            )(convs)

    # Shortcut connection: identity function or 1x1
    # convolutional
    #  (depends on difference between input & output shape - this
    #   corresponds to whether we are using the first block in
    #   each
    #   group; see `block_series()`).
    if n_input_plane != n_output_plane:
        shortcut = layers.Conv2D(
            n_output_plane,
            (1, 1),
            strides=stride,
            padding="same",
            kernel_initializer=INIT,
            kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
            use_bias=False,
        )(x)
    else:
        shortcut = x

    return layers.Add()([convs, shortcut])


# Stacking residual units on the same stage
def block_series(x, n_input_plane, n_output_plane, count, stride):
    x = wide_basic(x, n_input_plane, n_output_plane, stride)
    for i in range(2, int(count + 1)):
        x = wide_basic(x, n_output_plane, n_output_plane, stride=1)
    return x


def get_network(image_size=32, num_classes=10):
    n = (DEPTH - 4) / 6
    n_stages = [16, 16 * WIDTH_MULT, 32 * WIDTH_MULT, 64 * WIDTH_MULT]

    inputs = keras.Input(shape=(image_size, image_size, 3))
    x = layers.Rescaling(scale=1.0 / 255)(inputs)

    conv1 = layers.Conv2D(
        n_stages[0],
        (3, 3),
        strides=1,
        padding="same",
        kernel_initializer=INIT,
        kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
        use_bias=False,
    )(x)

    ## Add wide residual blocks ##

    conv2 = block_series(
        conv1,
        n_input_plane=n_stages[0],
        n_output_plane=n_stages[1],
        count=n,
        stride=(1, 1),
    )  # Stage 1

    conv3 = block_series(
        conv2,
        n_input_plane=n_stages[1],
        n_output_plane=n_stages[2],
        count=n,
        stride=(2, 2),
    )  # Stage 2

    conv4 = block_series(
        conv3,
        n_input_plane=n_stages[2],
        n_output_plane=n_stages[3],
        count=n,
        stride=(2, 2),
    )  # Stage 3

    batch_norm = layers.BatchNormalization()(conv4)
    relu = layers.Activation("relu")(batch_norm)

    # Classifier
    trunk_outputs = layers.GlobalAveragePooling2D()(relu)
    outputs = layers.Dense(
        num_classes, kernel_regularizer=regularizers.l2(WEIGHT_DECAY)
    )(trunk_outputs)

    return keras.Model(inputs, outputs)

今ではそのようにして Wide ResNet をインスタンス化できます。ここで Wide ResNet を使用する目的は実装を元のものに出来る限り近付けるためであることに注意してください。

wrn_model = get_network()
print(f"Model has {wrn_model.count_params()/1e6} Million parameters.")
Model has 1.471226 Million parameters.

 

AdaMatch モデルをインスタンス化してそれをコンパイルする

reduce_lr = keras.optimizers.schedules.CosineDecay(LEARNING_RATE, TOTAL_STEPS, 0.25)
optimizer = keras.optimizers.Adam(reduce_lr)

adamatch_trainer = AdaMatch(model=wrn_model, total_steps=TOTAL_STEPS)
adamatch_trainer.compile(optimizer=optimizer)

 

モデル訓練

total_ds = tf.data.Dataset.zip((final_source_ds, final_target_ds))
adamatch_trainer.fit(total_ds, epochs=EPOCHS)
Epoch 1/10
382/382 [==============================] - 53s 96ms/step - loss: 117866954752.0000
Epoch 2/10
382/382 [==============================] - 36s 95ms/step - loss: 2.6231
Epoch 3/10
382/382 [==============================] - 36s 94ms/step - loss: 4.1699
Epoch 4/10
382/382 [==============================] - 36s 95ms/step - loss: 8.2748
Epoch 5/10
382/382 [==============================] - 36s 95ms/step - loss: 28.8679
Epoch 6/10
382/382 [==============================] - 36s 94ms/step - loss: 14.7112
Epoch 7/10
382/382 [==============================] - 36s 94ms/step - loss: 7.8206
Epoch 8/10
382/382 [==============================] - 36s 94ms/step - loss: 18.1182
Epoch 9/10
382/382 [==============================] - 36s 94ms/step - loss: 22.4258
Epoch 10/10
382/382 [==============================] - 36s 95ms/step - loss: 22.1107

<tensorflow.python.keras.callbacks.History at 0x7f9bc4990b50>

 

ターゲットとソース・テストセット上で評価

# Compile the AdaMatch model to yield accuracy.
adamatch_trained_model = adamatch_trainer.model
adamatch_trained_model.compile(metrics=keras.metrics.SparseCategoricalAccuracy())

# Score on the target test set.
svhn_test = svhn_test.batch(TARGET_BATCH_SIZE).prefetch(AUTO)
_, accuracy = adamatch_trained_model.evaluate(svhn_test)
print(f"Accuracy on target test set: {accuracy * 100:.2f}%")
136/136 [==============================] - 2s 10ms/step - loss: 572.9810 - sparse_categorical_accuracy: 0.1960
Accuracy on target test set: 19.11%

より訓練すれば、このスコアは向上します。この同じネットワークが標準的な分類目的 (関数) で訓練される場合、それは 7.20% の精度を生成し、これは AdaMatch で得たものよりも遥かに低いです。ハイパーパラメータと他の実験の詳細について学習するために このノートブック を確認できます。

# Utility function for preprocessing the source test set.
def prepare_test_ds_source(image, label):
    image = tf.image.resize_with_pad(image, RESIZE_TO, RESIZE_TO)
    image = tf.tile(image, [1, 1, 3])
    return image, label


source_test_ds = tf.data.Dataset.from_tensor_slices((mnist_x_test, mnist_y_test))
source_test_ds = (
    source_test_ds.map(prepare_test_ds_source, num_parallel_calls=AUTO)
    .batch(TARGET_BATCH_SIZE)
    .prefetch(AUTO)
)

# Evaluation on the source test set.
_, accuracy = adamatch_trained_model.evaluate(source_test_ds)
print(f"Accuracy on source test set: {accuracy * 100:.2f}%")
53/53 [==============================] - 1s 10ms/step - loss: 572.9810 - sparse_categorical_accuracy: 0.6532
Accuracy on source test set: 65.32%

これらの モデル重み を使用することにより結果を再現できます。

 

以上



クラスキャット

最近の投稿

  • LangGraph on Colab : エージェント型 RAG
  • LangGraph : 例題 : エージェント型 RAG
  • LangGraph Platform : Get started : クイックスタート
  • LangGraph Platform : 概要
  • LangGraph : Prebuilt エージェント : ユーザインターフェイス

タグ

AutoGen (13) ClassCat Press Release (20) ClassCat TF/ONNX Hub (11) DGL 0.5 (14) Eager Execution (7) Edward (17) FLUX.1 (16) Gemini (20) HuggingFace Transformers 4.5 (10) HuggingFace Transformers 4.6 (7) HuggingFace Transformers 4.29 (9) Keras 2 Examples (98) Keras 2 Guide (16) Keras 3 (10) Keras Release Note (17) Kubeflow 1.0 (10) LangChain (45) LangGraph (22) MediaPipe 0.8 (11) Model Context Protocol (16) NNI 1.5 (16) OpenAI Agents SDK (8) OpenAI Cookbook (13) OpenAI platform (10) OpenAI platform 1.x (10) OpenAI ヘルプ (8) TensorFlow 2.0 Advanced Tutorials (33) TensorFlow 2.0 Advanced Tutorials (Alpha) (15) TensorFlow 2.0 Advanced Tutorials (Beta) (16) TensorFlow 2.0 Guide (10) TensorFlow 2.0 Guide (Alpha) (16) TensorFlow 2.0 Guide (Beta) (9) TensorFlow 2.0 Release Note (12) TensorFlow 2.0 Tutorials (20) TensorFlow 2.0 Tutorials (Alpha) (14) TensorFlow 2.0 Tutorials (Beta) (12) TensorFlow 2.4 Guide (24) TensorFlow Deploy (8) TensorFlow Get Started (7) TensorFlow Graphics (7) TensorFlow Probability (9) TensorFlow Programmer's Guide (22) TensorFlow Release Note (18) TensorFlow Tutorials (33) TF-Agents 0.4 (11)
2021年11月
月 火 水 木 金 土 日
1234567
891011121314
15161718192021
22232425262728
2930  
« 10月   12月 »
© 2025 ClasCat® AI Research | Powered by Minimalist Blog WordPress Theme