
Keras 2 : examples : Computer Vision – Semantic Image Clustering

Posted on 12/13/2021 (updated 12/16/2021) by Sales Information

Keras 2 : examples : Semantic Image Clustering (translation/commentary)

Translation: ClassCat Co., Ltd. Sales Information
Created: 12/13/2021 (keras 2.7.0)

* This page is a translation of the following Keras documentation, with supplementary explanations added where appropriate:

  • Code examples : Computer Vision : Semantic Image Clustering (Author: Khalid Salama)

* We have verified that the sample code runs, but it has been modified or extended where necessary.
* Feel free to link to this page, but we would appreciate a note to sales-info@classcat.com.

 


Keras 2 : examples : Semantic Image Clustering

Description: Semantic Clustering by Adopting Nearest neighbors (SCAN) algorithm.

 

Introduction

This example demonstrates how to apply the Semantic Clustering by Adopting Nearest neighbors (SCAN) algorithm (Van Gansbeke et al., 2020) on the CIFAR-10 dataset. The algorithm consists of two phases:

  1. Self-supervised visual representation learning of the images, in which we use the SimCLR technique.
  2. Clustering of the learned visual representation vectors to maximize the agreement (= consistency) between the cluster assignments of neighboring vectors.

This example requires TensorFlow Addons, which can be installed using the following command:

pip install tensorflow-addons
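
For orientation, here is a compact outline of how the two SCAN phases implemented below fit together. This is a sketch only: it reuses names defined later in this example, and top_k_neighbours is a hypothetical placeholder for the nearest-neighbour search shown further down.

# Phase 1: self-supervised representation learning (SimCLR-style contrastive training)
#   representation_learner.fit(x_data)              # pretrain encoder + projection head
#   feature_vectors = encoder.predict(x_data)       # embeddings for all images
#   neighbours = top_k_neighbours(feature_vectors)  # k nearest neighbours per image
#
# Phase 2: clustering with the learned neighbours
#   clustering_learner.fit(
#       {"anchors": x_data, "neighbours": x_data[neighbours]}, ...
#   )  # consistency + entropy losses over the cluster assignments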

 

Setup

from collections import defaultdict
import random
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from tqdm import tqdm

 

Prepare the data

num_classes = 10
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_data = np.concatenate([x_train, x_test])
y_data = np.concatenate([y_train, y_test])

print("x_data shape:", x_data.shape, "- y_data shape:", y_data.shape)

classes = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
x_data shape: (60000, 32, 32, 3) - y_data shape: (60000, 1)

 

Define hyperparameters

target_size = 32  # Resize the input images.
representation_dim = 512  # The dimensions of the features vector.
projection_units = 128  # The projection head of the representation learner.
num_clusters = 20  # Number of clusters.
k_neighbours = 5  # Number of neighbours to consider during cluster learning.
tune_encoder_during_clustering = False  # Freeze the encoder in the cluster learning.

 

Implement data preprocessing

The data preprocessing step resizes the input images to the desired target_size and applies feature-wise normalization. Note that, when using keras.applications.ResNet50V2 as the visual encoder, resizing the images to 255 x 255 inputs leads to more accurate results, but requires a longer time to train.

data_preprocessing = keras.Sequential(
    [
        layers.Resizing(target_size, target_size),
        layers.Normalization(),
    ]
)
# Compute the mean and the variance from the data for normalization.
data_preprocessing.layers[-1].adapt(x_data)
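
As a quick sanity check (not part of the original example), you can run the preprocessing pipeline on a small batch and confirm that, after adapt(), the normalized images have roughly zero mean:

# Illustrative check: preprocess a small batch and inspect its statistics.
sample_batch = data_preprocessing(x_data[:256])
print(sample_batch.shape)                   # (256, 32, 32, 3)
print(float(tf.reduce_mean(sample_batch)))  # close to 0 after normalization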

 

Data augmentation

Unlike SimCLR, which randomly picks a single data augmentation function to apply to an input image, here we apply a set of data augmentation functions randomly to the input image. (You can experiment with other image augmentation techniques by following the data augmentation tutorial.)

data_augmentation = keras.Sequential(
    [
        layers.RandomTranslation(
            height_factor=(-0.2, 0.2), width_factor=(-0.2, 0.2), fill_mode="nearest"
        ),
        layers.RandomFlip(mode="horizontal"),
        layers.RandomRotation(
            factor=0.15, fill_mode="nearest"
        ),
        layers.RandomZoom(
            height_factor=(-0.3, 0.1), width_factor=(-0.3, 0.1), fill_mode="nearest"
        )
    ]
)

Display a random image.

image_idx = np.random.choice(range(x_data.shape[0]))
image = x_data[image_idx]
image_class = classes[y_data[image_idx][0]]
plt.figure(figsize=(3, 3))
plt.imshow(x_data[image_idx].astype("uint8"))
plt.title(image_class)
_ = plt.axis("off")

Display a sample of augmented versions of the image.

plt.figure(figsize=(10, 10))
for i in range(9):
    augmented_images = data_augmentation(np.array([image]))
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_images[0].numpy().astype("uint8"))
    plt.axis("off")

 

Self-supervised representation learning

Implement the vision encoder

def create_encoder(representation_dim):
    encoder = keras.Sequential(
        [
            keras.applications.ResNet50V2(
                include_top=False, weights=None, pooling="avg"
            ),
            layers.Dense(representation_dim),
        ]
    )
    return encoder
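
A quick shape check (illustrative only, not in the original example): the encoder maps a batch of preprocessed images to representation_dim-dimensional feature vectors.

# Build a throwaway encoder and verify the output dimensionality.
tmp_encoder = create_encoder(representation_dim)
dummy_images = tf.zeros((2, target_size, target_size, 3))
print(tmp_encoder(dummy_images).shape)  # (2, 512)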

 

Implement the unsupervised contrastive loss

class RepresentationLearner(keras.Model):
    def __init__(
        self,
        encoder,
        projection_units,
        num_augmentations,
        temperature=1.0,
        dropout_rate=0.1,
        l2_normalize=False,
        **kwargs
    ):
        super(RepresentationLearner, self).__init__(**kwargs)
        self.encoder = encoder
        # Create projection head.
        self.projector = keras.Sequential(
            [
                layers.Dropout(dropout_rate),
                layers.Dense(units=projection_units, use_bias=False),
                layers.BatchNormalization(),
                layers.ReLU(),
            ]
        )
        self.num_augmentations = num_augmentations
        self.temperature = temperature
        self.l2_normalize = l2_normalize
        self.loss_tracker = keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        return [self.loss_tracker]

    def compute_contrastive_loss(self, feature_vectors, batch_size):
        num_augmentations = tf.shape(feature_vectors)[0] // batch_size
        if self.l2_normalize:
            feature_vectors = tf.math.l2_normalize(feature_vectors, -1)
        # The logits shape is [num_augmentations * batch_size, num_augmentations * batch_size].
        logits = (
            tf.linalg.matmul(feature_vectors, feature_vectors, transpose_b=True)
            / self.temperature
        )
        # Apply log-max trick for numerical stability.
        logits_max = tf.math.reduce_max(logits, axis=1)
        logits = logits - logits_max
        # The shape of targets is [num_augmentations * batch_size, num_augmentations * batch_size].
        # targets is a matrix that consists of num_augmentations submatrices of shape [batch_size, batch_size].
        # Each [batch_size, batch_size] submatrix is an identity matrix (diagonal entries are ones).
        targets = tf.tile(tf.eye(batch_size), [num_augmentations, num_augmentations])
        # Compute cross entropy loss
        return keras.losses.categorical_crossentropy(
            y_true=targets, y_pred=logits, from_logits=True
        )

    def call(self, inputs):
        # Preprocess the input images.
        preprocessed = data_preprocessing(inputs)
        # Create augmented versions of the images.
        augmented = []
        for _ in range(self.num_augmentations):
            augmented.append(data_augmentation(preprocessed))
        augmented = layers.Concatenate(axis=0)(augmented)
        # Generate embedding representations of the images.
        features = self.encoder(augmented)
        # Apply projection head.
        return self.projector(features)

    def train_step(self, inputs):
        batch_size = tf.shape(inputs)[0]
        # Run the forward pass and compute the contrastive loss
        with tf.GradientTape() as tape:
            feature_vectors = self(inputs, training=True)
            loss = self.compute_contrastive_loss(feature_vectors, batch_size)
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update loss tracker metric
        self.loss_tracker.update_state(loss)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, inputs):
        batch_size = tf.shape(inputs)[0]
        feature_vectors = self(inputs, training=False)
        loss = self.compute_contrastive_loss(feature_vectors, batch_size)
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

 

Train the model

# Create vision encoder.
encoder = create_encoder(representation_dim)
# Create representation learner.
representation_learner = RepresentationLearner(
    encoder, projection_units, num_augmentations=2, temperature=0.1
)
# Create a cosine decay learning rate scheduler.
lr_scheduler = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.001, decay_steps=500, alpha=0.1
)
# Compile the model.
representation_learner.compile(
    optimizer=tfa.optimizers.AdamW(learning_rate=lr_scheduler, weight_decay=0.0001),
)
# Fit the model.
history = representation_learner.fit(
    x=x_data,
    batch_size=512,
    epochs=50,  # for better results, increase the number of epochs to 500.
)
Epoch 1/50
118/118 [==============================] - 29s 135ms/step - loss: 41.1135
Epoch 2/50
118/118 [==============================] - 15s 125ms/step - loss: 11.7141
Epoch 3/50
118/118 [==============================] - 15s 125ms/step - loss: 11.1728
Epoch 4/50
118/118 [==============================] - 15s 125ms/step - loss: 10.9717
Epoch 5/50
118/118 [==============================] - 15s 125ms/step - loss: 10.8574
Epoch 6/50
118/118 [==============================] - 15s 125ms/step - loss: 10.9496
Epoch 7/50
118/118 [==============================] - 15s 124ms/step - loss: 10.7493
Epoch 8/50
118/118 [==============================] - 15s 124ms/step - loss: 10.5979
Epoch 9/50
118/118 [==============================] - 15s 124ms/step - loss: 10.4613
Epoch 10/50
118/118 [==============================] - 15s 125ms/step - loss: 10.2900
Epoch 11/50
118/118 [==============================] - 15s 124ms/step - loss: 10.1303
Epoch 12/50
118/118 [==============================] - 15s 124ms/step - loss: 9.9608
Epoch 13/50
118/118 [==============================] - 15s 124ms/step - loss: 9.7788
Epoch 14/50
118/118 [==============================] - 15s 124ms/step - loss: 9.5830
Epoch 15/50
118/118 [==============================] - 15s 124ms/step - loss: 9.4038
Epoch 16/50
118/118 [==============================] - 15s 124ms/step - loss: 9.1887
Epoch 17/50
118/118 [==============================] - 15s 124ms/step - loss: 9.0000
Epoch 18/50
118/118 [==============================] - 15s 124ms/step - loss: 8.7764
Epoch 19/50
118/118 [==============================] - 15s 124ms/step - loss: 8.5784
Epoch 20/50
118/118 [==============================] - 15s 124ms/step - loss: 8.3592
Epoch 21/50
118/118 [==============================] - 15s 124ms/step - loss: 8.2545
Epoch 22/50
118/118 [==============================] - 15s 124ms/step - loss: 8.1171
Epoch 23/50
118/118 [==============================] - 15s 124ms/step - loss: 7.9598
Epoch 24/50
118/118 [==============================] - 15s 124ms/step - loss: 7.8623
Epoch 25/50
118/118 [==============================] - 15s 124ms/step - loss: 7.7169
Epoch 26/50
118/118 [==============================] - 15s 124ms/step - loss: 7.5100
Epoch 27/50
118/118 [==============================] - 15s 124ms/step - loss: 7.5887
Epoch 28/50
118/118 [==============================] - 15s 124ms/step - loss: 7.3511
Epoch 29/50
118/118 [==============================] - 15s 124ms/step - loss: 7.1647
Epoch 30/50
118/118 [==============================] - 15s 124ms/step - loss: 7.1549
Epoch 31/50
118/118 [==============================] - 15s 124ms/step - loss: 7.0462
Epoch 32/50
118/118 [==============================] - 15s 124ms/step - loss: 6.8149
Epoch 33/50
118/118 [==============================] - 15s 124ms/step - loss: 6.6954
Epoch 34/50
118/118 [==============================] - 15s 124ms/step - loss: 6.5354
Epoch 35/50
118/118 [==============================] - 15s 124ms/step - loss: 6.3982
Epoch 36/50
118/118 [==============================] - 15s 124ms/step - loss: 6.4175
Epoch 37/50
118/118 [==============================] - 15s 124ms/step - loss: 6.3820
Epoch 38/50
118/118 [==============================] - 15s 124ms/step - loss: 6.2560
Epoch 39/50
118/118 [==============================] - 15s 124ms/step - loss: 6.1237
Epoch 40/50
118/118 [==============================] - 15s 124ms/step - loss: 6.0485
Epoch 41/50
118/118 [==============================] - 15s 124ms/step - loss: 5.8846
Epoch 42/50
118/118 [==============================] - 15s 124ms/step - loss: 5.7548
Epoch 43/50
118/118 [==============================] - 15s 124ms/step - loss: 6.0794
Epoch 44/50
118/118 [==============================] - 15s 124ms/step - loss: 5.9023
Epoch 45/50
118/118 [==============================] - 15s 124ms/step - loss: 5.9548
Epoch 46/50
118/118 [==============================] - 15s 124ms/step - loss: 6.0809
Epoch 47/50
118/118 [==============================] - 15s 124ms/step - loss: 5.6123
Epoch 48/50
118/118 [==============================] - 15s 124ms/step - loss: 5.5667
Epoch 49/50
118/118 [==============================] - 15s 124ms/step - loss: 5.4573
Epoch 50/50
118/118 [==============================] - 15s 124ms/step - loss: 5.4597

(Translator's note: experiment results)

Epoch 1/50
118/118 [==============================] - 29s 87ms/step - loss: 29.9556
Epoch 2/50
118/118 [==============================] - 9s 77ms/step - loss: 11.5125
Epoch 3/50
118/118 [==============================] - 9s 78ms/step - loss: 11.0072
Epoch 4/50
118/118 [==============================] - 9s 78ms/step - loss: 10.7874
Epoch 5/50
118/118 [==============================] - 9s 78ms/step - loss: 10.6437
Epoch 6/50
118/118 [==============================] - 9s 78ms/step - loss: 10.5506
Epoch 7/50
118/118 [==============================] - 9s 78ms/step - loss: 10.4373
Epoch 8/50
118/118 [==============================] - 9s 78ms/step - loss: 10.2851
Epoch 9/50
118/118 [==============================] - 9s 78ms/step - loss: 10.1732
Epoch 10/50
118/118 [==============================] - 9s 78ms/step - loss: 9.9649
Epoch 11/50
118/118 [==============================] - 9s 77ms/step - loss: 9.8258
Epoch 12/50
118/118 [==============================] - 9s 78ms/step - loss: 9.5910
Epoch 13/50
118/118 [==============================] - 9s 78ms/step - loss: 9.3802
Epoch 14/50
118/118 [==============================] - 9s 78ms/step - loss: 9.1439
Epoch 15/50
118/118 [==============================] - 9s 78ms/step - loss: 8.9558
Epoch 16/50
118/118 [==============================] - 9s 78ms/step - loss: 8.7526
Epoch 17/50
118/118 [==============================] - 9s 78ms/step - loss: 8.6240
Epoch 18/50
118/118 [==============================] - 9s 78ms/step - loss: 8.4214
Epoch 19/50
118/118 [==============================] - 9s 78ms/step - loss: 8.2466
Epoch 20/50
118/118 [==============================] - 9s 78ms/step - loss: 8.1365
Epoch 21/50
118/118 [==============================] - 9s 77ms/step - loss: 7.9626
Epoch 22/50
118/118 [==============================] - 9s 78ms/step - loss: 7.9407
Epoch 23/50
118/118 [==============================] - 9s 78ms/step - loss: 7.7123
Epoch 24/50
118/118 [==============================] - 9s 78ms/step - loss: 7.6336
Epoch 25/50
118/118 [==============================] - 9s 77ms/step - loss: 7.4061
Epoch 26/50
118/118 [==============================] - 9s 77ms/step - loss: 7.3237
Epoch 27/50
118/118 [==============================] - 9s 77ms/step - loss: 7.1124
Epoch 28/50
118/118 [==============================] - 9s 78ms/step - loss: 6.9913
Epoch 29/50
118/118 [==============================] - 9s 78ms/step - loss: 6.8683
Epoch 30/50
118/118 [==============================] - 9s 78ms/step - loss: 7.0046
Epoch 31/50
118/118 [==============================] - 9s 78ms/step - loss: 6.6838
Epoch 32/50
118/118 [==============================] - 9s 78ms/step - loss: 6.5952
Epoch 33/50
118/118 [==============================] - 9s 78ms/step - loss: 6.4429
Epoch 34/50
118/118 [==============================] - 9s 78ms/step - loss: 6.8122
Epoch 35/50
118/118 [==============================] - 9s 78ms/step - loss: 6.4525
Epoch 36/50
118/118 [==============================] - 9s 78ms/step - loss: 6.3212
Epoch 37/50
118/118 [==============================] - 9s 78ms/step - loss: 6.5632
Epoch 38/50
118/118 [==============================] - 9s 78ms/step - loss: 6.1890
Epoch 39/50
118/118 [==============================] - 9s 77ms/step - loss: 5.9394
Epoch 40/50
118/118 [==============================] - 9s 78ms/step - loss: 5.8260
Epoch 41/50
118/118 [==============================] - 9s 77ms/step - loss: 5.6711
Epoch 42/50
118/118 [==============================] - 9s 78ms/step - loss: 5.7763
Epoch 43/50
118/118 [==============================] - 9s 78ms/step - loss: 5.8551
Epoch 44/50
118/118 [==============================] - 9s 78ms/step - loss: 5.7423
Epoch 45/50
118/118 [==============================] - 9s 78ms/step - loss: 5.9383
Epoch 46/50
118/118 [==============================] - 9s 78ms/step - loss: 5.5107
Epoch 47/50
118/118 [==============================] - 9s 78ms/step - loss: 5.7404
Epoch 48/50
118/118 [==============================] - 9s 78ms/step - loss: 5.3652
Epoch 49/50
118/118 [==============================] - 9s 78ms/step - loss: 5.2598
Epoch 50/50
118/118 [==============================] - 9s 78ms/step - loss: 5.2350
CPU times: user 9min 3s, sys: 30.9 s, total: 9min 34s
Wall time: 8min

Plot training loss

plt.plot(history.history["loss"])
plt.ylabel("loss")
plt.xlabel("epoch")
plt.show()

 

Compute the nearest neighbors

Generate the embeddings for the images

batch_size = 500
# Get the feature vector representations of the images.
feature_vectors = encoder.predict(x_data, batch_size=batch_size, verbose=1)
# Normalize the feature vectors.
feature_vectors = tf.math.l2_normalize(feature_vectors, -1)
120/120 [==============================] - 4s 18ms/step

 

Find the k nearest neighbours for each embedding

neighbours = []
num_batches = feature_vectors.shape[0] // batch_size
for batch_idx in tqdm(range(num_batches)):
    start_idx = batch_idx * batch_size
    end_idx = start_idx + batch_size
    current_batch = feature_vectors[start_idx:end_idx]
    # Compute the dot similarity.
    similarities = tf.linalg.matmul(current_batch, feature_vectors, transpose_b=True)
    # Get the indices of most similar vectors.
    _, indices = tf.math.top_k(similarities, k=k_neighbours + 1, sorted=True)
    # Add the indices to the neighbours, skipping the first index (the image itself).
    neighbours.append(indices[..., 1:])

neighbours = np.reshape(np.array(neighbours), (-1, k_neighbours))

Let us display a few neighbours in each row:

nrows = 4
ncols = k_neighbours + 1

plt.figure(figsize=(12, 12))
position = 1
for _ in range(nrows):
    anchor_idx = np.random.choice(range(x_data.shape[0]))
    neighbour_indicies = neighbours[anchor_idx]
    indices = [anchor_idx] + neighbour_indicies.tolist()
    for j in range(ncols):
        plt.subplot(nrows, ncols, position)
        plt.imshow(x_data[indices[j]].astype("uint8"))
        plt.title(classes[y_data[indices[j]][0]])
        plt.axis("off")
        position += 1

 


Notice that the images in each row are visually similar and belong to similar classes.

 

Semantic clustering with nearest neighbours

Implement the clusters consistency loss

This loss tries to make sure that neighbours have the same clustering assignments.

class ClustersConsistencyLoss(keras.losses.Loss):
    def __init__(self):
        super(ClustersConsistencyLoss, self).__init__()

    def __call__(self, target, similarity, sample_weight=None):
        # Set targets to be ones.
        target = tf.ones_like(similarity)
        # Compute cross entropy loss.
        loss = keras.losses.binary_crossentropy(
            y_true=target, y_pred=similarity, from_logits=True
        )
        return tf.math.reduce_mean(loss)
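
As a small numeric illustration (not from the original example): the similarity values fed to this loss are dot products between cluster probability vectors, so when an anchor and its neighbours agree, the similarities are close to 1 and the binary cross-entropy against a target of ones is small; when they disagree, the similarities are close to 0 and the loss is larger.

consistency_loss = ClustersConsistencyLoss()
agreeing = tf.constant([[0.95, 0.90]])     # anchor and neighbours mostly agree
disagreeing = tf.constant([[0.05, 0.10]])  # anchor and neighbours mostly disagree
print(float(consistency_loss(None, agreeing)))     # smaller loss
print(float(consistency_loss(None, disagreeing)))  # larger loss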

 

Implement the clusters entropy loss

This loss tries to make sure that the cluster distribution is roughly uniform, to avoid assigning most of the instances to a single cluster.

class ClustersEntropyLoss(keras.losses.Loss):
    def __init__(self, entropy_loss_weight=1.0):
        super(ClustersEntropyLoss, self).__init__()
        self.entropy_loss_weight = entropy_loss_weight

    def __call__(self, target, cluster_probabilities, sample_weight=None):
        # Ideal entropy = log(num_clusters).
        num_clusters = tf.cast(tf.shape(cluster_probabilities)[-1], tf.dtypes.float32)
        target = tf.math.log(num_clusters)
        # Compute the overall clusters distribution.
        cluster_probabilities = tf.math.reduce_mean(cluster_probabilities, axis=0)
        # Replacing zero probabilities - if any - with a very small value.
        cluster_probabilities = tf.clip_by_value(
            cluster_probabilities, clip_value_min=1e-8, clip_value_max=1.0
        )
        # Compute the entropy over the clusters.
        entropy = -tf.math.reduce_sum(
            cluster_probabilities * tf.math.log(cluster_probabilities)
        )
        # Compute the difference between the target and the actual.
        loss = target - entropy
        return loss
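
A quick numeric check (illustrative, not part of the original example): a uniform cluster distribution has entropy log(num_clusters), so the loss is close to 0, while a collapsed distribution (everything assigned to one cluster) has entropy near 0 and yields a loss close to log(num_clusters).

entropy_loss = ClustersEntropyLoss()
uniform_probs = tf.fill((4, num_clusters), 1.0 / num_clusters)
collapsed_probs = tf.tile(tf.one_hot([0], num_clusters), [4, 1])
print(float(entropy_loss(None, uniform_probs)))    # ~0.0
print(float(entropy_loss(None, collapsed_probs)))  # ~log(20) ≈ 3.0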

 

Implement the clustering model

This model takes a raw image as input, generates its feature vector using the trained encoder, and produces a probability distribution over the clusters, given the feature vector, as the cluster assignment.

def create_clustering_model(encoder, num_clusters, name=None):
    inputs = keras.Input(shape=input_shape)
    # Preprocess the input images.
    preprocessed = data_preprocessing(inputs)
    # Apply data augmentation to the images.
    augmented = data_augmentation(preprocessed)
    # Generate embedding representations of the images.
    features = encoder(augmented)
    # Assign the images to clusters.
    outputs = layers.Dense(units=num_clusters, activation="softmax")(features)
    # Create the model.
    model = keras.Model(inputs=inputs, outputs=outputs, name=name)
    return model

 

Implement the clustering learner

This model receives the input anchor image and its neighbours, produces their cluster assignments using the clustering_model, and generates two outputs:

  1. similarity: the similarity between the cluster assignments of the anchor image and its neighbours. This output is fed to the ClustersConsistencyLoss.
  2. anchor_clustering: the cluster assignments of the anchor image. This output is fed to the ClustersEntropyLoss.

def create_clustering_learner(clustering_model):
    anchor = keras.Input(shape=input_shape, name="anchors")
    neighbours = keras.Input(
        shape=tuple([k_neighbours]) + input_shape, name="neighbours"
    )
    # Changes neighbours shape to [batch_size * k_neighbours, width, height, channels]
    neighbours_reshaped = tf.reshape(neighbours, shape=tuple([-1]) + input_shape)
    # anchor_clustering shape: [batch_size, num_clusters]
    anchor_clustering = clustering_model(anchor)
    # neighbours_clustering shape: [batch_size * k_neighbours, num_clusters]
    neighbours_clustering = clustering_model(neighbours_reshaped)
    # Convert neighbours_clustering shape to [batch_size, k_neighbours, num_clusters]
    neighbours_clustering = tf.reshape(
        neighbours_clustering,
        shape=(-1, k_neighbours, tf.shape(neighbours_clustering)[-1]),
    )
    # similarity shape: [batch_size, 1, k_neighbours]
    similarity = tf.linalg.einsum(
        "bij,bkj->bik", tf.expand_dims(anchor_clustering, axis=1), neighbours_clustering
    )
    # similarity shape:  [batch_size, k_neighbours]
    similarity = layers.Lambda(lambda x: tf.squeeze(x, axis=1), name="similarity")(
        similarity
    )
    # Create the model.
    model = keras.Model(
        inputs=[anchor, neighbours],
        outputs=[similarity, anchor_clustering],
        name="clustering_learner",
    )
    return model
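
The key tensor manipulation above is the einsum: each entry of similarity is the dot product between the anchor's cluster probability vector and one of its neighbours' cluster probability vectors. Here is a minimal shape sketch (not part of the original example) using random tensors:

toy_batch = 3
anchor_probs = tf.random.uniform((toy_batch, num_clusters))
neighbour_probs = tf.random.uniform((toy_batch, k_neighbours, num_clusters))
toy_similarity = tf.linalg.einsum(
    "bij,bkj->bik", tf.expand_dims(anchor_probs, axis=1), neighbour_probs
)
print(toy_similarity.shape)  # (3, 1, 5), squeezed to (3, 5) inside the model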

 

Train the model

# If tune_encoder_during_clustering is set to False,
# then freeze the encoder weights.
for layer in encoder.layers:
    layer.trainable = tune_encoder_during_clustering
# Create the clustering model and learner.
clustering_model = create_clustering_model(encoder, num_clusters, name="clustering")
clustering_learner = create_clustering_learner(clustering_model)
# Instantiate the model losses.
losses = [ClustersConsistencyLoss(), ClustersEntropyLoss(entropy_loss_weight=5)]
# Create the model inputs and labels.
inputs = {"anchors": x_data, "neighbours": tf.gather(x_data, neighbours)}
labels = tf.ones(shape=(x_data.shape[0]))
# Compile the model.
clustering_learner.compile(
    optimizer=tfa.optimizers.AdamW(learning_rate=0.0005, weight_decay=0.0001),
    loss=losses,
)

# Begin training the model.
history = clustering_learner.fit(x=inputs, y=labels, batch_size=512, epochs=50)
Epoch 1/50
118/118 [==============================] - 20s 95ms/step - loss: 0.6655 - similarity_loss: 0.6642 - clustering_loss: 0.0013
Epoch 2/50
118/118 [==============================] - 10s 86ms/step - loss: 0.6361 - similarity_loss: 0.6325 - clustering_loss: 0.0036
Epoch 3/50
118/118 [==============================] - 10s 85ms/step - loss: 0.6129 - similarity_loss: 0.6070 - clustering_loss: 0.0059
Epoch 4/50
118/118 [==============================] - 10s 85ms/step - loss: 0.6005 - similarity_loss: 0.5930 - clustering_loss: 0.0075
Epoch 5/50
118/118 [==============================] - 10s 85ms/step - loss: 0.5923 - similarity_loss: 0.5849 - clustering_loss: 0.0074
Epoch 6/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5879 - similarity_loss: 0.5795 - clustering_loss: 0.0084
Epoch 7/50
118/118 [==============================] - 10s 85ms/step - loss: 0.5841 - similarity_loss: 0.5754 - clustering_loss: 0.0087
Epoch 8/50
118/118 [==============================] - 10s 85ms/step - loss: 0.5817 - similarity_loss: 0.5733 - clustering_loss: 0.0084
Epoch 9/50
118/118 [==============================] - 10s 85ms/step - loss: 0.5811 - similarity_loss: 0.5717 - clustering_loss: 0.0094
Epoch 10/50
118/118 [==============================] - 10s 85ms/step - loss: 0.5797 - similarity_loss: 0.5697 - clustering_loss: 0.0100
Epoch 11/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5767 - similarity_loss: 0.5676 - clustering_loss: 0.0091
Epoch 12/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5771 - similarity_loss: 0.5667 - clustering_loss: 0.0104
Epoch 13/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5755 - similarity_loss: 0.5661 - clustering_loss: 0.0094
Epoch 14/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5746 - similarity_loss: 0.5653 - clustering_loss: 0.0093
Epoch 15/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5743 - similarity_loss: 0.5640 - clustering_loss: 0.0103
Epoch 16/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5738 - similarity_loss: 0.5636 - clustering_loss: 0.0102
Epoch 17/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5732 - similarity_loss: 0.5627 - clustering_loss: 0.0106
Epoch 18/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5723 - similarity_loss: 0.5621 - clustering_loss: 0.0102
Epoch 19/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5711 - similarity_loss: 0.5615 - clustering_loss: 0.0096
Epoch 20/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5693 - similarity_loss: 0.5596 - clustering_loss: 0.0097
Epoch 21/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5699 - similarity_loss: 0.5600 - clustering_loss: 0.0099
Epoch 22/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5694 - similarity_loss: 0.5592 - clustering_loss: 0.0102
Epoch 23/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5703 - similarity_loss: 0.5595 - clustering_loss: 0.0108
Epoch 24/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5687 - similarity_loss: 0.5587 - clustering_loss: 0.0101
Epoch 25/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5688 - similarity_loss: 0.5585 - clustering_loss: 0.0103
Epoch 26/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5690 - similarity_loss: 0.5583 - clustering_loss: 0.0108
Epoch 27/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5679 - similarity_loss: 0.5572 - clustering_loss: 0.0107
Epoch 28/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5681 - similarity_loss: 0.5573 - clustering_loss: 0.0108
Epoch 29/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5682 - similarity_loss: 0.5572 - clustering_loss: 0.0111
Epoch 30/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5675 - similarity_loss: 0.5571 - clustering_loss: 0.0104
Epoch 31/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5679 - similarity_loss: 0.5562 - clustering_loss: 0.0116
Epoch 32/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5663 - similarity_loss: 0.5554 - clustering_loss: 0.0109
Epoch 33/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5665 - similarity_loss: 0.5556 - clustering_loss: 0.0109
Epoch 34/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5679 - similarity_loss: 0.5568 - clustering_loss: 0.0111
Epoch 35/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5680 - similarity_loss: 0.5563 - clustering_loss: 0.0117
Epoch 36/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5665 - similarity_loss: 0.5553 - clustering_loss: 0.0112
Epoch 37/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5674 - similarity_loss: 0.5556 - clustering_loss: 0.0118
Epoch 38/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5648 - similarity_loss: 0.5543 - clustering_loss: 0.0105
Epoch 39/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5653 - similarity_loss: 0.5549 - clustering_loss: 0.0103
Epoch 40/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5656 - similarity_loss: 0.5544 - clustering_loss: 0.0113
Epoch 41/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5644 - similarity_loss: 0.5542 - clustering_loss: 0.0102
Epoch 42/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5658 - similarity_loss: 0.5540 - clustering_loss: 0.0118
Epoch 43/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5655 - similarity_loss: 0.5539 - clustering_loss: 0.0116
Epoch 44/50
118/118 [==============================] - 10s 87ms/step - loss: 0.5662 - similarity_loss: 0.5543 - clustering_loss: 0.0119
Epoch 45/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5651 - similarity_loss: 0.5537 - clustering_loss: 0.0114
Epoch 46/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5635 - similarity_loss: 0.5534 - clustering_loss: 0.0101
Epoch 47/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5633 - similarity_loss: 0.5529 - clustering_loss: 0.0103
Epoch 48/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5643 - similarity_loss: 0.5526 - clustering_loss: 0.0117
Epoch 49/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5653 - similarity_loss: 0.5532 - clustering_loss: 0.0121
Epoch 50/50
118/118 [==============================] - 10s 86ms/step - loss: 0.5641 - similarity_loss: 0.5525 - clustering_loss: 0.0117

<tensorflow.python.keras.callbacks.History at 0x7f1da373ea10>
(Translator's note: experiment results)

Epoch 1/50
118/118 [==============================] - 16s 70ms/step - loss: 0.6686 - similarity_loss: 0.6683 - clustering_loss: 3.3759e-04
Epoch 2/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6681 - similarity_loss: 0.6679 - clustering_loss: 2.2749e-04
Epoch 3/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6674 - similarity_loss: 0.6670 - clustering_loss: 4.8134e-04
Epoch 4/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6668 - similarity_loss: 0.6660 - clustering_loss: 7.9841e-04
Epoch 5/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6659 - similarity_loss: 0.6651 - clustering_loss: 8.5452e-04
Epoch 6/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6656 - similarity_loss: 0.6644 - clustering_loss: 0.0012
Epoch 7/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6651 - similarity_loss: 0.6637 - clustering_loss: 0.0014
Epoch 8/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6649 - similarity_loss: 0.6633 - clustering_loss: 0.0016
Epoch 9/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6647 - similarity_loss: 0.6629 - clustering_loss: 0.0017
Epoch 10/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6641 - similarity_loss: 0.6626 - clustering_loss: 0.0015
Epoch 11/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6641 - similarity_loss: 0.6623 - clustering_loss: 0.0018
Epoch 12/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6638 - similarity_loss: 0.6619 - clustering_loss: 0.0018
Epoch 13/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6636 - similarity_loss: 0.6617 - clustering_loss: 0.0019
Epoch 14/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6635 - similarity_loss: 0.6615 - clustering_loss: 0.0020
Epoch 15/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6633 - similarity_loss: 0.6613 - clustering_loss: 0.0020
Epoch 16/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6632 - similarity_loss: 0.6612 - clustering_loss: 0.0020
Epoch 17/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6634 - similarity_loss: 0.6609 - clustering_loss: 0.0025
Epoch 18/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6630 - similarity_loss: 0.6608 - clustering_loss: 0.0022
Epoch 19/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6633 - similarity_loss: 0.6607 - clustering_loss: 0.0026
Epoch 20/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6632 - similarity_loss: 0.6606 - clustering_loss: 0.0026
Epoch 21/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6628 - similarity_loss: 0.6605 - clustering_loss: 0.0023
Epoch 22/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6628 - similarity_loss: 0.6603 - clustering_loss: 0.0025
Epoch 23/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6630 - similarity_loss: 0.6603 - clustering_loss: 0.0027
Epoch 24/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6626 - similarity_loss: 0.6602 - clustering_loss: 0.0024
Epoch 25/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6626 - similarity_loss: 0.6601 - clustering_loss: 0.0025
Epoch 26/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6627 - similarity_loss: 0.6600 - clustering_loss: 0.0027
Epoch 27/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6626 - similarity_loss: 0.6600 - clustering_loss: 0.0026
Epoch 28/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6627 - similarity_loss: 0.6599 - clustering_loss: 0.0028
Epoch 29/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6623 - similarity_loss: 0.6599 - clustering_loss: 0.0025
Epoch 30/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6623 - similarity_loss: 0.6597 - clustering_loss: 0.0026
Epoch 31/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6622 - similarity_loss: 0.6597 - clustering_loss: 0.0025
Epoch 32/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6623 - similarity_loss: 0.6597 - clustering_loss: 0.0026
Epoch 33/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6624 - similarity_loss: 0.6596 - clustering_loss: 0.0029
Epoch 34/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6626 - similarity_loss: 0.6596 - clustering_loss: 0.0031
Epoch 35/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6622 - similarity_loss: 0.6595 - clustering_loss: 0.0027
Epoch 36/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6622 - similarity_loss: 0.6594 - clustering_loss: 0.0028
Epoch 37/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6622 - similarity_loss: 0.6594 - clustering_loss: 0.0029
Epoch 38/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6623 - similarity_loss: 0.6594 - clustering_loss: 0.0029
Epoch 39/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6623 - similarity_loss: 0.6594 - clustering_loss: 0.0029
Epoch 40/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6623 - similarity_loss: 0.6594 - clustering_loss: 0.0029
Epoch 41/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6620 - similarity_loss: 0.6593 - clustering_loss: 0.0027
Epoch 42/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6621 - similarity_loss: 0.6593 - clustering_loss: 0.0028
Epoch 43/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6620 - similarity_loss: 0.6593 - clustering_loss: 0.0028
Epoch 44/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6621 - similarity_loss: 0.6592 - clustering_loss: 0.0029
Epoch 45/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6621 - similarity_loss: 0.6592 - clustering_loss: 0.0029
Epoch 46/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6621 - similarity_loss: 0.6592 - clustering_loss: 0.0029
Epoch 47/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6619 - similarity_loss: 0.6591 - clustering_loss: 0.0028
Epoch 48/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6624 - similarity_loss: 0.6592 - clustering_loss: 0.0032
Epoch 49/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6621 - similarity_loss: 0.6591 - clustering_loss: 0.0030
Epoch 50/50
118/118 [==============================] - 7s 63ms/step - loss: 0.6621 - similarity_loss: 0.6591 - clustering_loss: 0.0030
CPU times: user 7min 18s, sys: 19.6 s, total: 7min 37s
Wall time: 6min 23s

Plot training loss

plt.plot(history.history["loss"])
plt.ylabel("loss")
plt.xlabel("epoch")
plt.show()

 

Cluster analysis

Assign images to clusters

# Get the cluster probability distribution of the input images.
clustering_probs = clustering_model.predict(x_data, batch_size=batch_size, verbose=1)
# Get the cluster of the highest probability.
cluster_assignments = tf.math.argmax(clustering_probs, axis=-1).numpy()
# Store the clustering confidence.
# Images with the highest clustering confidence are considered the 'prototypes'
# of the clusters.
cluster_confidence = tf.math.reduce_max(clustering_probs, axis=-1).numpy()

Let us compute the cluster sizes.

clusters = defaultdict(list)
for idx, c in enumerate(cluster_assignments):
    clusters[c].append((idx, cluster_confidence[idx]))

for c in range(num_clusters):
    print("cluster", c, ":", len(clusters[c]))
cluster 0 : 4132
cluster 1 : 4057
cluster 2 : 1713
cluster 3 : 2801
cluster 4 : 2511
cluster 5 : 2655
cluster 6 : 2517
cluster 7 : 4493
cluster 8 : 3687
cluster 9 : 1716
cluster 10 : 3397
cluster 11 : 3606
cluster 12 : 3325
cluster 13 : 4010
cluster 14 : 2188
cluster 15 : 3278
cluster 16 : 1902
cluster 17 : 1858
cluster 18 : 3828
cluster 19 : 2326
(Translator's note: experiment results)

cluster 0 : 5102
cluster 1 : 3477
cluster 2 : 1382
cluster 3 : 5087
cluster 4 : 2589
cluster 5 : 1893
cluster 6 : 2010
cluster 7 : 5885
cluster 8 : 3072
cluster 9 : 1013
cluster 10 : 2860
cluster 11 : 4134
cluster 12 : 5173
cluster 13 : 2267
cluster 14 : 2237
cluster 15 : 2864
cluster 16 : 1281
cluster 17 : 1335
cluster 18 : 3970
cluster 19 : 2369

Notice that the clusters have roughly balanced sizes.

 

Visualize cluster images

Display the prototypes of each cluster, i.e. the instances with the highest clustering confidence.

num_images = 8
plt.figure(figsize=(15, 15))
position = 1
for c in range(num_clusters):
    cluster_instances = sorted(clusters[c], key=lambda kv: kv[1], reverse=True)

    for j in range(num_images):
        image_idx = cluster_instances[j][0]
        plt.subplot(num_clusters, num_images, position)
        plt.imshow(x_data[image_idx].astype("uint8"))
        plt.title(classes[y_data[image_idx][0]])
        plt.axis("off")
        position += 1

 


 

Compute the clustering accuracy

First, we assign a label to each cluster based on the majority label of its images. Then, we compute the accuracy of each cluster by dividing the number of images with the majority label by the cluster's size.

cluster_label_counts = dict()

for c in range(num_clusters):
    cluster_label_counts[c] = [0] * num_classes
    instances = clusters[c]
    for i, _ in instances:
        cluster_label_counts[c][y_data[i][0]] += 1

    cluster_label_idx = np.argmax(cluster_label_counts[c])
    correct_count = np.max(cluster_label_counts[c])
    cluster_size = len(clusters[c])
    accuracy = (
        np.round((correct_count / cluster_size) * 100, 2) if cluster_size > 0 else 0
    )
    cluster_label = classes[cluster_label_idx]
    print("cluster", c, "label is:", cluster_label, " -  accuracy:", accuracy, "%")
cluster 0 label is: frog  -  accuracy: 23.11 %
cluster 1 label is: truck  -  accuracy: 23.56 %
cluster 2 label is: bird  -  accuracy: 29.01 %
cluster 3 label is: dog  -  accuracy: 16.67 %
cluster 4 label is: truck  -  accuracy: 27.8 %
cluster 5 label is: ship  -  accuracy: 36.91 %
cluster 6 label is: deer  -  accuracy: 27.89 %
cluster 7 label is: dog  -  accuracy: 23.84 %
cluster 8 label is: airplane  -  accuracy: 21.7 %
cluster 9 label is: bird  -  accuracy: 22.38 %
cluster 10 label is: automobile  -  accuracy: 24.76 %
cluster 11 label is: automobile  -  accuracy: 24.15 %
cluster 12 label is: cat  -  accuracy: 17.44 %
cluster 13 label is: truck  -  accuracy: 23.44 %
cluster 14 label is: ship  -  accuracy: 31.67 %
cluster 15 label is: airplane  -  accuracy: 41.06 %
cluster 16 label is: deer  -  accuracy: 22.77 %
cluster 17 label is: airplane  -  accuracy: 15.18 %
cluster 18 label is: frog  -  accuracy: 33.31 %
cluster 19 label is: deer  -  accuracy: 18.7 %
(Translator's note: experiment results)

cluster 0 label is: truck  -  accuracy: 24.07 %
cluster 1 label is: ship  -  accuracy: 25.77 %
cluster 2 label is: dog  -  accuracy: 16.35 %
cluster 3 label is: airplane  -  accuracy: 29.15 %
cluster 4 label is: automobile  -  accuracy: 25.61 %
cluster 5 label is: bird  -  accuracy: 31.06 %
cluster 6 label is: deer  -  accuracy: 14.23 %
cluster 7 label is: dog  -  accuracy: 18.23 %
cluster 8 label is: airplane  -  accuracy: 14.29 %
cluster 9 label is: bird  -  accuracy: 19.74 %
cluster 10 label is: airplane  -  accuracy: 30.94 %
cluster 11 label is: frog  -  accuracy: 30.14 %
cluster 12 label is: dog  -  accuracy: 16.61 %
cluster 13 label is: dog  -  accuracy: 14.91 %
cluster 14 label is: frog  -  accuracy: 28.25 %
cluster 15 label is: horse  -  accuracy: 16.97 %
cluster 16 label is: airplane  -  accuracy: 27.01 %
cluster 17 label is: deer  -  accuracy: 24.72 %
cluster 18 label is: truck  -  accuracy: 31.64 %
cluster 19 label is: dog  -  accuracy: 21.44 %
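
If you also want a single aggregate number (an addition, not part of the original example), you can weight each cluster's majority count by its size to obtain an overall clustering accuracy:

total_correct = sum(np.max(cluster_label_counts[c]) for c in range(num_clusters))
overall_accuracy = np.round((total_correct / x_data.shape[0]) * 100, 2)
print("Overall clustering accuracy:", overall_accuracy, "%")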

 

Conclusion

To improve the accuracy results, you can:

  1. increase the number of epochs for the representation learning and the clustering phases;
  2. allow the encoder weights to be tuned during the clustering phase; and
  3. perform a final fine-tuning step through self-labeling, as described in the original SCAN paper (a rough sketch follows below).
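
Below is a rough sketch of what such a self-labeling step could look like. It is an assumption-laden outline, not the exact procedure from the SCAN paper: the 0.99 confidence threshold, the optimizer settings, and the number of epochs are all hypothetical.

# Keep only images whose cluster assignment is highly confident and use
# those assignments as pseudo-labels for supervised fine-tuning.
confidence_threshold = 0.99  # hypothetical value
confident_idx = np.where(cluster_confidence >= confidence_threshold)[0]
pseudo_labels = cluster_assignments[confident_idx]

# Fine-tune the clustering model (softmax outputs) on the pseudo-labels.
clustering_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(),
)
clustering_model.fit(
    x=x_data[confident_idx],
    y=pseudo_labels,
    batch_size=512,
    epochs=10,
)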

Note that unsupervised image clustering techniques are not expected to outperform the accuracy of supervised image classification techniques; rather, this shows that they can learn the semantics of the images and group them into clusters that are similar to their original classes.

 

That's all.


