Skip to content

ClasCat® AI Research

クラスキャット – 生成 AI, AI エージェント, MCP

Menu
  • ホーム
    • ClassCat® AI Research ホーム
    • クラスキャット・ホーム
  • OpenAI API
    • OpenAI Python ライブラリ 1.x : 概要
    • OpenAI ブログ
      • GPT の紹介
      • GPT ストアの紹介
      • ChatGPT Team の紹介
    • OpenAI platform 1.x
      • Get Started : イントロダクション
      • Get Started : クイックスタート (Python)
      • Get Started : クイックスタート (Node.js)
      • Get Started : モデル
      • 機能 : 埋め込み
      • 機能 : 埋め込み (ユースケース)
      • ChatGPT : アクション – イントロダクション
      • ChatGPT : アクション – Getting started
      • ChatGPT : アクション – アクション認証
    • OpenAI ヘルプ : ChatGPT
      • ChatGPTとは何ですか?
      • ChatGPT は真実を語っていますか?
      • GPT の作成
      • GPT FAQ
      • GPT vs アシスタント
      • GPT ビルダー
    • OpenAI ヘルプ : ChatGPT > メモリ
      • FAQ
    • OpenAI ヘルプ : GPT ストア
      • 貴方の GPT をフィーチャーする
    • OpenAI Python ライブラリ 0.27 : 概要
    • OpenAI platform
      • Get Started : イントロダクション
      • Get Started : クイックスタート
      • Get Started : モデル
      • ガイド : GPT モデル
      • ガイド : 画像生成 (DALL·E)
      • ガイド : GPT-3.5 Turbo 対応 微調整
      • ガイド : 微調整 1.イントロダクション
      • ガイド : 微調整 2. データセットの準備 / ケーススタディ
      • ガイド : 埋め込み
      • ガイド : 音声テキスト変換
      • ガイド : モデレーション
      • ChatGPT プラグイン : イントロダクション
    • OpenAI Cookbook
      • 概要
      • API 使用方法 : レート制限の操作
      • API 使用方法 : tiktoken でトークンを数える方法
      • GPT : ChatGPT モデルへの入力をフォーマットする方法
      • GPT : 補完をストリームする方法
      • GPT : 大規模言語モデルを扱う方法
      • 埋め込み : 埋め込みの取得
      • GPT-3 の微調整 : 分類サンプルの微調整
      • DALL-E : DALL·E で 画像を生成して編集する方法
      • DALL·E と Segment Anything で動的マスクを作成する方法
      • Whisper プロンプティング・ガイド
  • Gemini API
    • Tutorials : クイックスタート with Python (1) テキスト-to-テキスト生成
    • (2) マルチモーダル入力 / 日本語チャット
    • (3) 埋め込みの使用
    • (4) 高度なユースケース
    • クイックスタート with Node.js
    • クイックスタート with Dart or Flutter (1) 日本語動作確認
    • Gemma
      • 概要 (README)
      • Tutorials : サンプリング
      • Tutorials : KerasNLP による Getting Started
  • Keras 3
    • 新しいマルチバックエンド Keras
    • Keras 3 について
    • Getting Started : エンジニアのための Keras 入門
    • Google Colab 上のインストールと Stable Diffusion デモ
    • コンピュータビジョン – ゼロからの画像分類
    • コンピュータビジョン – 単純な MNIST convnet
    • コンピュータビジョン – EfficientNet を使用した微調整による画像分類
    • コンピュータビジョン – Vision Transformer による画像分類
    • コンピュータビジョン – 最新の MLPモデルによる画像分類
    • コンピュータビジョン – コンパクトな畳込み Transformer
    • Keras Core
      • Keras Core 0.1
        • 新しいマルチバックエンド Keras (README)
        • Keras for TensorFlow, JAX, & PyTorch
        • 開発者ガイド : Getting started with Keras Core
        • 開発者ガイド : 関数型 API
        • 開発者ガイド : シーケンシャル・モデル
        • 開発者ガイド : サブクラス化で新しい層とモデルを作成する
        • 開発者ガイド : 独自のコールバックを書く
      • Keras Core 0.1.1 & 0.1.2 : リリースノート
      • 開発者ガイド
      • Code examples
      • Keras Stable Diffusion
        • 概要
        • 基本的な使い方 (テキスト-to-画像 / 画像-to-画像変換)
        • 混合精度のパフォーマンス
        • インペインティングの簡易アプリケーション
        • (参考) KerasCV – Stable Diffusion を使用した高性能画像生成
  • TensorFlow
    • TF 2 : 初級チュートリアル
    • TF 2 : 上級チュートリアル
    • TF 2 : ガイド
    • TF 1 : チュートリアル
    • TF 1 : ガイド
  • その他
    • 🦜️🔗 LangChain ドキュメント / ユースケース
    • Stable Diffusion WebUI
      • Google Colab で Stable Diffusion WebUI 入門
      • HuggingFace モデル / VAE の導入
      • LoRA の利用
    • Diffusion Models / 拡散モデル
  • クラスキャット
    • 会社案内
    • お問合せ
    • Facebook
    • ClassCat® Blog
Menu

Keras 2 : examples : コンピュータビジョン – SimSiam による自己教師あり対照学習

Posted on 12/17/202112/20/2021 by Sales Information

Keras 2 : examples : SimSiam による自己教師あり対照学習 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 12/17/2021 (keras 2.7.0)

* 本ページは、Keras の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

  • Code examples : Computer Vision : Self-supervised contrastive learning with SimSiam (Author: Sayak Paul)

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス

◆ クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :

  • 人工知能研究開発支援
    1. 人工知能研修サービス(経営者層向けオンサイト研修)
    2. テクニカルコンサルティングサービス
    3. 実証実験(プロトタイプ構築)
    4. アプリケーションへの実装

  • 人工知能研修サービス

  • PoC(概念実証)を失敗させないための支援
◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。

◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

  • 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
  • sales-info@classcat.com  ;  Web: www.classcat.com  ;   ClassCatJP

 

 

Keras 2 : examples : SimSiam による自己教師あり対照学習

Description: コンピュータビジョンのための自己教師あり学習法の実装。

 

自己教師あり学習 (SSL) は表現学習の領域の興味深い研究分野です。SSL システムはラベルのないデータポイントのコーパスから教師あり信号を定式化しようとします。例えば、与えられた単語のセットから次の単語を予測するために深層ニューラルネットワークを訓練します。文献によれば、これらのタスクは pretext タスク or auxiliary (補助的) タスクとして知られています。(Wikipedia テキストコーパス のような) 巨大なデータセット上で そのようなネットワークを訓練する 場合、それは下流タスクに上手く転移できる非常に効果的な表現を学習します。BERT, GPT-3, ELMo のような言語モデルの総てはこれから恩恵を受けます。

言語モデルのように、同様のアプローチを用いてコンピュータビジョンモデルを訓練できます。コンピュータビジョンで上手く運ぶため、基礎モデル (深層ニューラルネットワーク) がビジョンデータに存在する意味情報を理解できるように学習タスクを定式化する必要があります。そのようなタスクの一つはモデルが同じ画像の 2 つの異なるバージョンを対比させることです。このようにしてモデルが、類似の画像はできる限り一緒に、一方で似ていない画像は離れてグループ分けされるような表現を学習することが期待されます。

このサンプルでは、Exploring Simple Siamese Representation Learning で提案された、SimSiam と呼ばれるそのようなシステムの一つを実装していきます。それは以下のように実装されます :

  1. 確率的データ増強パイプラインにより同じデータセットの 2 つの異なるバージョンを作成します。これらのバージョンを作成する間、ランダム初期化シードは同じである必要があることに注意してください。

  2. 分類ヘッドのない ResNet ( バックボーン ) を取り、その上に浅い完全結合ネットワーク ( 投影ヘッド ) を追加します。合わせて、これは エンコーダ として知られています。

  3. エンコーダの出力を 予測器 (= predictor) に渡します、これは再度、浅い完全結合ネットワークで AutoEncoder ライクな構造を持ちます。

  4. そして 2 つの異なるバージョンのデータセット間のコサイン類似度を最大にするようにエンコーダを訓練します。

このサンプルは TensorFlow 2.4 またはそれ以上を必要とします。

 

セットアップ

from tensorflow.keras import layers
from tensorflow.keras import regularizers
import tensorflow as tf

import matplotlib.pyplot as plt
import numpy as np

 

ハイパーパラメータの定義

AUTO = tf.data.AUTOTUNE
BATCH_SIZE = 128
EPOCHS = 5
CROP_TO = 32
SEED = 26

PROJECT_DIM = 2048
LATENT_DIM = 512
WEIGHT_DECAY = 0.0005

 

CIFAR-10 データセットのロード

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
print(f"Total training examples: {len(x_train)}")
print(f"Total test examples: {len(x_test)}")
Total training examples: 50000
Total test examples: 10000

 

データ増強パイプラインの定義

SimCLR で研究されたように、適切なデータ増強パイプラインを持つことは SSL システムがコンピュータビジョンで効果的に動作するために重要です。最も重要であるように思われる 2 つの特別な増強変換は :

  1. ランダムにリサイズされたクロップと、
  2. カラー distortion (歪み) です。(BYOL, MoCoV2, SwAV 等のような) コンピュータビジョンのための他の SSL システムの殆どはこれらを訓練パイプラインに含みます。
def flip_random_crop(image):
    # With random crops we also apply horizontal flipping.
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_crop(image, (CROP_TO, CROP_TO, 3))
    return image


def color_jitter(x, strength=[0.4, 0.4, 0.4, 0.1]):
    x = tf.image.random_brightness(x, max_delta=0.8 * strength[0])
    x = tf.image.random_contrast(
        x, lower=1 - 0.8 * strength[1], upper=1 + 0.8 * strength[1]
    )
    x = tf.image.random_saturation(
        x, lower=1 - 0.8 * strength[2], upper=1 + 0.8 * strength[2]
    )
    x = tf.image.random_hue(x, max_delta=0.2 * strength[3])
    # Affine transformations can disturb the natural range of
    # RGB images, hence this is needed.
    x = tf.clip_by_value(x, 0, 255)
    return x


def color_drop(x):
    x = tf.image.rgb_to_grayscale(x)
    x = tf.tile(x, [1, 1, 3])
    return x


def random_apply(func, x, p):
    if tf.random.uniform([], minval=0, maxval=1) < p:
        return func(x)
    else:
        return x


def custom_augment(image):
    # As discussed in the SimCLR paper, the series of augmentation
    # transformations (except for random crops) need to be applied
    # randomly to impose translational invariance.
    image = flip_random_crop(image)
    image = random_apply(color_jitter, image, p=0.8)
    image = random_apply(color_drop, image, p=0.2)
    return image

増強パイプラインは一般に扱っているデータセットの様々な特性に依存していることは留意されるべきです。例えば、データセットの画像が過度に物体中心であるならば、非常に高い確率でランダムなクロップを取ることは訓練性能を損なうかもしれません。

次に私達の増強パイプラインをデータセットに適用して幾つかの出力を可視化しましょう。

 

データを TensorFlow Dataset オブジェクトに変換する

ここでは正解ラベルなしにデータセットの 2 つの異なるバージョンを作成します。

ssl_ds_one = tf.data.Dataset.from_tensor_slices(x_train)
ssl_ds_one = (
    ssl_ds_one.shuffle(1024, seed=SEED)
    .map(custom_augment, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

ssl_ds_two = tf.data.Dataset.from_tensor_slices(x_train)
ssl_ds_two = (
    ssl_ds_two.shuffle(1024, seed=SEED)
    .map(custom_augment, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

# We then zip both of these datasets.
ssl_ds = tf.data.Dataset.zip((ssl_ds_one, ssl_ds_two))

# Visualize a few augmented images.
sample_images_one = next(iter(ssl_ds_one))
plt.figure(figsize=(10, 10))
for n in range(25):
    ax = plt.subplot(5, 5, n + 1)
    plt.imshow(sample_images_one[n].numpy().astype("int"))
    plt.axis("off")
plt.show()

# Ensure that the different versions of the dataset actually contain
# identical images.
sample_images_two = next(iter(ssl_ds_two))
plt.figure(figsize=(10, 10))
for n in range(25):
    ax = plt.subplot(5, 5, n + 1)
    plt.imshow(sample_images_two[n].numpy().astype("int"))
    plt.axis("off")
plt.show()

samples_images_one と sample_images_two の画像は基本的に同じですが、異なって増強されていることに気づいてください。

 

エンコーダと予測器の定義

CIFAR10 データセットのために特に設定されたあ ResNet20 の実装を利用します。コードは keras-idiomatic-programmer レポジトリから引用されています。これらのアーキテクチャのハイパーパラメータは 原論文 の Section 3 と Appendix から参照されています。

!wget -q https://git.io/JYx2x -O resnet_cifar10_v2.py
import resnet_cifar10_v2

N = 2
DEPTH = N * 9 + 2
NUM_BLOCKS = ((DEPTH - 2) // 9) - 1


def get_encoder():
    # Input and backbone.
    inputs = layers.Input((CROP_TO, CROP_TO, 3))
    x = layers.Rescaling(scale=1.0 / 127.5, offset=-1)(
        inputs
    )
    x = resnet_cifar10_v2.stem(x)
    x = resnet_cifar10_v2.learner(x, NUM_BLOCKS)
    x = layers.GlobalAveragePooling2D(name="backbone_pool")(x)

    # Projection head.
    x = layers.Dense(
        PROJECT_DIM, use_bias=False, kernel_regularizer=regularizers.l2(WEIGHT_DECAY)
    )(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Dense(
        PROJECT_DIM, use_bias=False, kernel_regularizer=regularizers.l2(WEIGHT_DECAY)
    )(x)
    outputs = layers.BatchNormalization()(x)
    return tf.keras.Model(inputs, outputs, name="encoder")


def get_predictor():
    model = tf.keras.Sequential(
        [
            # Note the AutoEncoder-like structure.
            layers.Input((PROJECT_DIM,)),
            layers.Dense(
                LATENT_DIM,
                use_bias=False,
                kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
            ),
            layers.ReLU(),
            layers.BatchNormalization(),
            layers.Dense(PROJECT_DIM),
        ],
        name="predictor",
    )
    return model

 

(事前) 訓練ループの定義

これらの種類のアプローチを持つネットワークの訓練の裏にある主要な理由の一つは、分類のような下流タスクのために学習された表現を活用することです。これがこの特定の訓練段階が事前訓練としても参照される理由です。

損失関数を定義することから始めます。

def compute_loss(p, z):
    # The authors of SimSiam emphasize the impact of
    # the `stop_gradient` operator in the paper as it
    # has an important role in the overall optimization.
    z = tf.stop_gradient(z)
    p = tf.math.l2_normalize(p, axis=1)
    z = tf.math.l2_normalize(z, axis=1)
    # Negative cosine similarity (minimizing this is
    # equivalent to maximizing the similarity).
    return -tf.reduce_mean(tf.reduce_sum((p * z), axis=1))

次に tf.keras.Model クラスの train_step() 関数を override することで訓練ループを定義します。

class SimSiam(tf.keras.Model):
    def __init__(self, encoder, predictor):
        super(SimSiam, self).__init__()
        self.encoder = encoder
        self.predictor = predictor
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        return [self.loss_tracker]

    def train_step(self, data):
        # Unpack the data.
        ds_one, ds_two = data

        # Forward pass through the encoder and predictor.
        with tf.GradientTape() as tape:
            z1, z2 = self.encoder(ds_one), self.encoder(ds_two)
            p1, p2 = self.predictor(z1), self.predictor(z2)
            # Note that here we are enforcing the network to match
            # the representations of two differently augmented batches
            # of data.
            loss = compute_loss(p1, z2) / 2 + compute_loss(p2, z1) / 2

        # Compute gradients and update the parameters.
        learnable_params = (
            self.encoder.trainable_variables + self.predictor.trainable_variables
        )
        gradients = tape.gradient(loss, learnable_params)
        self.optimizer.apply_gradients(zip(gradients, learnable_params))

        # Monitor loss.
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

 

ネットワークの事前訓練

このサンプルのためには、モデルを 5 エポックだけ訓練します。実際には、これは少なくとも 100 エポックであるべきです。

# Create a cosine decay learning scheduler.
num_training_samples = len(x_train)
steps = EPOCHS * (num_training_samples // BATCH_SIZE)
lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.03, decay_steps=steps
)

# Create an early stopping callback.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="loss", patience=5, restore_best_weights=True
)

# Compile model and start training.
simsiam = SimSiam(get_encoder(), get_predictor())
simsiam.compile(optimizer=tf.keras.optimizers.SGD(lr_decayed_fn, momentum=0.6))
history = simsiam.fit(ssl_ds, epochs=EPOCHS, callbacks=[early_stopping])

# Visualize the training progress of the model.
plt.plot(history.history["loss"])
plt.grid()
plt.title("Negative Cosine Similairty")
plt.show()
Epoch 1/5
391/391 [==============================] - 33s 42ms/step - loss: -0.8973
Epoch 2/5
391/391 [==============================] - 16s 40ms/step - loss: -0.9129
Epoch 3/5
391/391 [==============================] - 16s 40ms/step - loss: -0.9165
Epoch 4/5
391/391 [==============================] - 16s 40ms/step - loss: -0.9176
Epoch 5/5
391/391 [==============================] - 16s 40ms/step - loss: -0.9182

(訳者注: 実験結果 - 100 epochs)

Epoch 1/100
391/391 [==============================] - 27s 35ms/step - loss: -0.8702
Epoch 2/100
391/391 [==============================] - 13s 34ms/step - loss: -0.8956
Epoch 3/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9062
Epoch 4/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9203
Epoch 5/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9327
Epoch 6/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9396
Epoch 7/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9434
Epoch 8/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9455
Epoch 9/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9483
Epoch 10/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9505
Epoch 11/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9522
Epoch 12/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9541
Epoch 13/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9562
Epoch 14/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9581
Epoch 15/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9600
Epoch 16/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9613
Epoch 17/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9629
Epoch 18/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9644
Epoch 19/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9655
Epoch 20/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9663
Epoch 21/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9672
Epoch 22/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9682
Epoch 23/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9691
Epoch 24/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9702
Epoch 25/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9713
Epoch 26/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9718
Epoch 27/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9729
Epoch 28/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9734
Epoch 29/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9742
Epoch 30/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9746
Epoch 31/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9756
Epoch 32/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9757
Epoch 33/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9762
Epoch 34/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9766
Epoch 35/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9767
Epoch 36/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9771
Epoch 37/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9774
Epoch 38/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9777
Epoch 39/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9779
Epoch 40/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9781
Epoch 41/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9783
Epoch 42/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9783
Epoch 43/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9787
Epoch 44/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9788
Epoch 45/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9789
Epoch 46/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9788
Epoch 47/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9791
Epoch 48/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9789
Epoch 49/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9792
Epoch 50/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9794
Epoch 51/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9796
Epoch 52/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9797
Epoch 53/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9797
Epoch 54/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9799
Epoch 55/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9799
Epoch 56/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9799
Epoch 57/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9798
Epoch 58/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9799
Epoch 59/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9797
Epoch 60/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9800
Epoch 61/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9798
Epoch 62/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9800
Epoch 63/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9800
Epoch 64/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9799
Epoch 65/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9801
Epoch 66/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9800
Epoch 67/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9801
Epoch 68/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9802
Epoch 69/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9803
Epoch 70/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9802
Epoch 71/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9803
Epoch 72/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9802
Epoch 73/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9803
Epoch 74/100
391/391 [==============================] - 13s 34ms/step - loss: -0.9803

異なるデータセットと異なるバックボーン・アーキテクチャで解が非常に素早く -1 (損失の最小値) に近づく場合、それは表現 collapse による傾向があります。それはエンコーダが総ての画像に対して同様の出力を生成する現象です。その場合には特に以下の領域で追加のハイパーパラメータ調整が必要です :

  • カラー distortion の強度とそれらの確率。
  • 学習率とそのスケジュール
  • バックボーンとそれらの投影ヘッドの両方のアーキテクチャ

 

SSL 法を評価する

コンピュータビジョンで SSL 法 (あるいはそのような別の事前訓練法) を評価するために最も一般的に使用される手法は、訓練されたバックボーンモデル (この場合は ResNet20) の凍結された特徴の上で線形分類器を学習して未見の画像で分類器を評価することです。他の手法はソースターゲットあるいは 5% か 10% のラベルが存在するターゲットデータセット上での 再調整 も含みます。実践的には、セマンティック・セグメンテーション, 物体検出, 等の任意の下流タスクに対してバックボーンモデルを使用できます、そこではバックボーンモデルは通常は純粋な教師あり学習で事前訓練されます。

# We first create labeled `Dataset` objects.
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))

# Then we shuffle, batch, and prefetch this dataset for performance. We
# also apply random resized crops as an augmentation but only to the
# training set.
train_ds = (
    train_ds.shuffle(1024)
    .map(lambda x, y: (flip_random_crop(x), y), num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)
test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)

# Extract the backbone ResNet20.
backbone = tf.keras.Model(
    simsiam.encoder.input, simsiam.encoder.get_layer("backbone_pool").output
)

# We then create our linear classifier and train it.
backbone.trainable = False
inputs = layers.Input((CROP_TO, CROP_TO, 3))
x = backbone(inputs, training=False)
outputs = layers.Dense(10, activation="softmax")(x)
linear_model = tf.keras.Model(inputs, outputs, name="linear_model")

# Compile model and start training.
linear_model.compile(
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
    optimizer=tf.keras.optimizers.SGD(lr_decayed_fn, momentum=0.9),
)
history = linear_model.fit(
    train_ds, validation_data=test_ds, epochs=EPOCHS, callbacks=[early_stopping]
)
_, test_acc = linear_model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc * 100))
Epoch 1/5
391/391 [==============================] - 7s 11ms/step - loss: 3.8072 - accuracy: 0.1527 - val_loss: 3.7449 - val_accuracy: 0.2046
Epoch 2/5
391/391 [==============================] - 3s 8ms/step - loss: 3.7356 - accuracy: 0.2107 - val_loss: 3.7055 - val_accuracy: 0.2308
Epoch 3/5
391/391 [==============================] - 3s 8ms/step - loss: 3.7036 - accuracy: 0.2228 - val_loss: 3.6874 - val_accuracy: 0.2329
Epoch 4/5
391/391 [==============================] - 3s 8ms/step - loss: 3.6893 - accuracy: 0.2276 - val_loss: 3.6808 - val_accuracy: 0.2334
Epoch 5/5
391/391 [==============================] - 3s 9ms/step - loss: 3.6845 - accuracy: 0.2305 - val_loss: 3.6798 - val_accuracy: 0.2339
79/79 [==============================] - 1s 7ms/step - loss: 3.6798 - accuracy: 0.2339
Test accuracy: 23.39%
Epoch 1/100
391/391 [==============================] - 7s 11ms/step - loss: 3.7784 - accuracy: 0.2017 - val_loss: 3.7345 - val_accuracy: 0.2350
Epoch 2/100
391/391 [==============================] - 4s 9ms/step - loss: 3.7007 - accuracy: 0.2508 - val_loss: 3.6701 - val_accuracy: 0.2599
Epoch 3/100
391/391 [==============================] - 3s 9ms/step - loss: 3.6461 - accuracy: 0.2650 - val_loss: 3.6240 - val_accuracy: 0.2728
Epoch 4/100
391/391 [==============================] - 3s 9ms/step - loss: 3.6060 - accuracy: 0.2728 - val_loss: 3.5892 - val_accuracy: 0.2798
Epoch 5/100
391/391 [==============================] - 3s 9ms/step - loss: 3.5752 - accuracy: 0.2787 - val_loss: 3.5624 - val_accuracy: 0.2872

...

Epoch 90/100
391/391 [==============================] - 3s 9ms/step - loss: 3.3543 - accuracy: 0.3317 - val_loss: 3.3558 - val_accuracy: 0.3357
Epoch 91/100
391/391 [==============================] - 3s 9ms/step - loss: 3.3545 - accuracy: 0.3311 - val_loss: 3.3558 - val_accuracy: 0.3355
79/79 [==============================] - 1s 7ms/step - loss: 3.3560 - accuracy: 0.3353
Test accuracy: 33.53%
 

以上



クラスキャット

最近の投稿

  • LangGraph on Colab : エージェント型 RAG
  • LangGraph : 例題 : エージェント型 RAG
  • LangGraph Platform : Get started : クイックスタート
  • LangGraph Platform : 概要
  • LangGraph : Prebuilt エージェント : ユーザインターフェイス

タグ

AutoGen (13) ClassCat Press Release (20) ClassCat TF/ONNX Hub (11) DGL 0.5 (14) Eager Execution (7) Edward (17) FLUX.1 (16) Gemini (20) HuggingFace Transformers 4.5 (10) HuggingFace Transformers 4.6 (7) HuggingFace Transformers 4.29 (9) Keras 2 Examples (98) Keras 2 Guide (16) Keras 3 (10) Keras Release Note (17) Kubeflow 1.0 (10) LangChain (45) LangGraph (22) MediaPipe 0.8 (11) Model Context Protocol (16) NNI 1.5 (16) OpenAI Agents SDK (8) OpenAI Cookbook (13) OpenAI platform (10) OpenAI platform 1.x (10) OpenAI ヘルプ (8) TensorFlow 2.0 Advanced Tutorials (33) TensorFlow 2.0 Advanced Tutorials (Alpha) (15) TensorFlow 2.0 Advanced Tutorials (Beta) (16) TensorFlow 2.0 Guide (10) TensorFlow 2.0 Guide (Alpha) (16) TensorFlow 2.0 Guide (Beta) (9) TensorFlow 2.0 Release Note (12) TensorFlow 2.0 Tutorials (20) TensorFlow 2.0 Tutorials (Alpha) (14) TensorFlow 2.0 Tutorials (Beta) (12) TensorFlow 2.4 Guide (24) TensorFlow Deploy (8) TensorFlow Get Started (7) TensorFlow Graphics (7) TensorFlow Probability (9) TensorFlow Programmer's Guide (22) TensorFlow Release Note (18) TensorFlow Tutorials (33) TF-Agents 0.4 (11)
2021年12月
月 火 水 木 金 土 日
 12345
6789101112
13141516171819
20212223242526
2728293031  
« 11月   3月 »
© 2025 ClasCat® AI Research | Powered by Minimalist Blog WordPress Theme