Keras 2 : examples : TensorFlow Similarity による画像類似性検索のためのメトリック学習 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 11/28/2021 (keras 2.7.0)
* 本ページは、Keras の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
- Code examples : Computer Vision : Metric learning for image similarity search using TensorFlow Similarity (Author: Owen Vallis)
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- テレワーク & オンライン授業を支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- E-Mail:sales-info@classcat.com ; WebSite: www.classcat.com ; Facebook
Keras 2 : examples : TensorFlow Similarity による画像類似性検索のためのメトリック学習
Description: CIFAR-10 画像上の類似性メトリック学習を使用するサンプル。
概要
このサンプルは 「画像の類似性検索のためのメトリック学習」サンプル に基づいています。同じデータセットを使用しますが、TensorFlow Similarity を使用してモデルを実装することを目的としています。
メトリック学習は、「類似の」入力が互いにより近くに配置されて「似ていない」入力が遠くに離れて配置されるように、入力を高次元空間に埋め込めるモデルを訓練することを目的としています。一度訓練されれば、これらのモデルはそのような類似性が有用であるような下流システムのための埋め込みを生成できます、例えば検索のためのランキングシグナルや、別の教師あり問題のための事前訓練済みの埋め込みモデルの形式です。
メトリック学習の詳細な概要については以下を参照してください :
セットアップ
このチュートリアルは類似性埋め込みを学習して評価するために TensorFlow Similarity ライブラリを使用します。TensorFlow Similarity は以下のようなコンポーネントを提供します :
- 対照 (= contrastive) モデルの訓練を単純にかつ高速に行ないます。
- バッチがサンプルのペアを含むことを確実にすることを容易にします。
- 埋め込みの品質の評価を可能にします。
import random
from matplotlib import pyplot as plt
from mpl_toolkits import axes_grid1
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_similarity as tfsim
tfsim.utils.tf_cap_memory()
print("TensorFlow:", tf.__version__)
print("TensorFlow Similarity:", tfsim.__version__)
TensorFlow: 2.6.0 TensorFlow Similarity: 0.14
データセット・サンプラー
このチュートリアルのために CIFAR-10 データセットを使用していきます。
類似性モデルが効率的に学習するため、各バッチは各クラスの少なくとも 2 つのサンプルを含まなければなりません。
これを容易にするため、tf_similarity は Sampler オブジェクトを提供します、これはクラス数とバッチ毎の各クラスの最少サンプル数の両方を設定することが可能です。
訓練と検証データセットは TFDatasetMultiShotMemorySampler オブジェクトを使用して作成されます。これは、TensorFlow Datasets からデータセットをロードして、クラスのターゲット数とクラス毎のサンプルのターゲット数を含むバッチを生成するサンプラーを作成します。更に、サンプラーをclass_list で定義されているクラスのサブセットだけを生成するように制限して、クラスのサブセット上で訓練してから埋め込みが未見のクラスに対してどのように一般化するかテストすることを可能にします。これは few-shot 学習問題に取り組むときに有用であり得ます。
次のセルは以下のような train_ds サンプルを作成します :
- TFDS から CIFAR-10 データセットをロードしてから examples_per_class_per_batch を取得します。
- サンプラーがクラスを class_list で定義されたものに制限することを確実にします。
- 各バッチが各々 8 サンプルを持つ 10 の異なるクラスを含むことを確実にします。
同じ方法で検証データセットも作成しますが、クラス毎の合計サンプル数を 100 に制限し、そしてバッチ毎のクラス毎のサンプル数を 2 のデフォルトに設定します。
# This determines the number of classes used during training.
# Here we are using all the classes.
num_known_classes = 10
class_list = random.sample(population=range(10), k=num_known_classes)
classes_per_batch = 10
# Passing multiple examples per class per batch ensures that each example has
# multiple positive pairs. This can be useful when performing triplet mining or
# when using losses like `MultiSimilarityLoss` or `CircleLoss` as these can
# take a weighted mix of all the positive pairs. In general, more examples per
# class will lead to more information for the positive pairs, while more classes
# per batch will provide more varied information in the negative pairs. However,
# the losses compute the pairwise distance between the examples in a batch so
# the upper limit of the batch size is restricted by the memory.
examples_per_class_per_batch = 8
print(
"Batch size is: "
f"{min(classes_per_batch, num_known_classes) * examples_per_class_per_batch}"
)
print(" Create Training Data ".center(34, "#"))
train_ds = tfsim.samplers.TFDatasetMultiShotMemorySampler(
"cifar10",
classes_per_batch=min(classes_per_batch, num_known_classes),
splits="train",
steps_per_epoch=4000,
examples_per_class_per_batch=examples_per_class_per_batch,
class_list=class_list,
)
print("\n" + " Create Validation Data ".center(34, "#"))
val_ds = tfsim.samplers.TFDatasetMultiShotMemorySampler(
"cifar10",
classes_per_batch=classes_per_batch,
splits="test",
total_examples_per_class=100,
)
Batch size is: 80 ###### Create Training Data ###### 2021-10-07 22:48:06.609114: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags. converting train: 0%| | 0/50000 [00:00<?, ?it/s] 2021-10-07 22:48:06.692705: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2) The initial batch size is 80 (10 classes * 8 examples per class) with 0 augmenters filtering examples: 0%| | 0/50000 [00:00<?, ?it/s] selecting classes: 0%| | 0/10 [00:00<?, ?it/s] gather examples: 0%| | 0/50000 [00:00<?, ?it/s] indexing classes: 0%| | 0/50000 [00:00<?, ?it/s] ##### Create Validation Data ##### converting test: 0%| | 0/10000 [00:00<?, ?it/s] The initial batch size is 20 (10 classes * 2 examples per class) with 0 augmenters filtering examples: 0%| | 0/10000 [00:00<?, ?it/s] selecting classes: 0%| | 0/10 [00:00<?, ?it/s] gather examples: 0%| | 0/1000 [00:00<?, ?it/s] indexing classes: 0%| | 0/1000 [00:00<?, ?it/s]
データセットの可視化
サンプラーはデータセットをシャッフルしますので、最初の 25 画像をプロットしてデータセットの感覚を掴むことができます。
サンプラーは get_slice(begin, size) メソッドを提供します、これはサンプルのブロックを簡単に選択することを可能にします。
代わりに、バッチを生成するために generate_batch() メソッドも使用できます。これは、バッチが想定されるクラス数とクラス毎のサンプル数が含まれているか確認することを可能にします。
num_cols = num_rows = 5
# Get the first 25 examples.
x_slice, y_slice = train_ds.get_slice(begin=0, size=num_cols * num_rows)
fig = plt.figure(figsize=(6.0, 6.0))
grid = axes_grid1.ImageGrid(fig, 111, nrows_ncols=(num_cols, num_rows), axes_pad=0.1)
for ax, im, label in zip(grid, x_slice, y_slice):
ax.imshow(im)
ax.axis("off")
埋め込みモデル
次に Keras 関数型 API を使用して SimilarityModel を定義します。モデルは、L2 正規化を適用する MetricEmbedding 層が追加された標準的な convnet です。コサイン距離を使用するとき metric embedding 層は役立ちます、ベクトルの角度だけをケアしているからです。
更に、SimilarityModel は以下のための幾つかのヘルパー・メソッドを提供しています :
- 埋め込まれたサンプルのインデキシング
- サンプル検索の遂行
- 分類の評価
- 埋め込み空間の品質の評価
詳細は TensorFlow Similarity ドキュメント を参照してください。
embedding_size = 256
inputs = keras.layers.Input((32, 32, 3))
x = keras.layers.Rescaling(scale=1.0 / 255)(inputs)
x = keras.layers.Conv2D(64, 3, activation="relu")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Conv2D(128, 3, activation="relu")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.MaxPool2D((4, 4))(x)
x = keras.layers.Conv2D(256, 3, activation="relu")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Conv2D(256, 3, activation="relu")(x)
x = keras.layers.GlobalMaxPool2D()(x)
outputs = tfsim.layers.MetricEmbedding(embedding_size)(x)
# building model
model = tfsim.models.SimilarityModel(inputs, outputs)
model.summary()
Model: "similarity_model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, 32, 32, 3)] 0 _________________________________________________________________ rescaling (Rescaling) (None, 32, 32, 3) 0 _________________________________________________________________ conv2d (Conv2D) (None, 30, 30, 64) 1792 _________________________________________________________________ batch_normalization (BatchNo (None, 30, 30, 64) 256 _________________________________________________________________ conv2d_1 (Conv2D) (None, 28, 28, 128) 73856 _________________________________________________________________ batch_normalization_1 (Batch (None, 28, 28, 128) 512 _________________________________________________________________ max_pooling2d (MaxPooling2D) (None, 7, 7, 128) 0 _________________________________________________________________ conv2d_2 (Conv2D) (None, 5, 5, 256) 295168 _________________________________________________________________ batch_normalization_2 (Batch (None, 5, 5, 256) 1024 _________________________________________________________________ conv2d_3 (Conv2D) (None, 3, 3, 256) 590080 _________________________________________________________________ global_max_pooling2d (Global (None, 256) 0 _________________________________________________________________ metric_embedding (MetricEmbe (None, 256) 65792 ================================================================= Total params: 1,028,480 Trainable params: 1,027,584 Non-trainable params: 896 _________________________________________________________________
類似性損失
類似性損失 (= similarity loss) は各クラスの少なくとも 2 サンプルを含むバッチを想定していて、そこからペアワイズにポジティブとネガティブの距離について損失を計算します。ここでは MultiSimilarityLoss() ( 論文 ) を使用しています、TensorFlow Similarity の幾つかの損失の一つです。この損失は、self-similarity, positive-similarity と negative-similarity を考慮に入れて、バッチ内の情報を持つ総てのペアの使用を試みます。
epochs = 3
learning_rate = 0.002
val_steps = 50
# init similarity loss
loss = tfsim.losses.MultiSimilarityLoss()
# compiling and training
model.compile(
optimizer=keras.optimizers.Adam(learning_rate), loss=loss, steps_per_execution=10,
)
history = model.fit(
train_ds, epochs=epochs, validation_data=val_ds, validation_steps=val_steps
)
Distance metric automatically set to cosine use the distance arg to override. Epoch 1/3 4000/4000 [==============================] - ETA: 0s - loss: 2.2158Warmup complete 4000/4000 [==============================] - 1072s 268ms/step - loss: 2.2158 - val_loss: 0.8940 Warmup complete Epoch 2/3 4000/4000 [==============================] - 1052s 263ms/step - loss: 1.8965 - val_loss: 0.8814 Epoch 3/3 4000/4000 [==============================] - 1047s 262ms/step - loss: 1.6221 - val_loss: 0.8009
インデキシング
モデルを訓練したので、サンプルのインデックスを作成できます。ここでは、data パラメータ内に画像をストアするとともに x と y をインデックスに渡すことにより最初の 200 検証サンプルのインデックス化をバッチ処理します。x_index は埋め込まれてから検索可能にするためにインデックスに追加されます。y_index と data はオプションですが、ユーザがメタデータを埋め込まれたサンプルと関連付けることを可能にします。
x_index, y_index = val_ds.get_slice(begin=0, size=200)
model.reset_index()
model.index(x_index, y_index, data=x_index)
[Indexing 200 points] |-Computing embeddings |-Storing data points in key value store |-Adding embeddings to index. |-Building index. 0% 10 20 30 40 50 60 70 80 90 100% |----|----|----|----|----|----|----|----|----|----| ***************************************************
キャリブレーション
インデックスが構築されたら、マッチング・ストラテジーとキャリブレーション・メトリックを使用して距離の閾値をキャリブレートできます。
ここでは K=1 を分類器として使用しながら、最適な F1 スコアを求めています。キャリブレートされた閾値距離以下の総てのマッチングは (問い合わせサンプルとマッチング結果に関連するラベルの間の) Positive マッチとしてラベル付けられ、閾値距離より上の総てのマッチングは Negative マッチとしてラベル付けされます。
更に、追加のメトリクスを渡して計算します。出力の総ての値はキャリブレートされた閾値で計算されます。
最後に、model.calibrate() は以下を含む CalibrationResults オブジェクトを返します :
- “cutpoints”: cutpoint 名を特定の距離閾値に関連付けられた ClassificationMetric を含む辞書にマップする Python 辞書です、e.g., “optimal” : {“acc”: 0.90, “f1”: 0.92}。
- “thresholds”: ClassificationMetric 名を距離閾値の各々で計算されたメトリック値を含むリストにマップする Python 辞書です、e.g., {“f1”: [0.99, 0.80], “distance”: [0.0, 1.0]}。
x_train, y_train = train_ds.get_slice(begin=0, size=1000)
calibration = model.calibrate(
x_train,
y_train,
calibration_metric="f1",
matcher="match_nearest",
extra_metrics=["precision", "recall", "binary_accuracy"],
verbose=1,
)
Performing NN search Building NN list: 0%| | 0/1000 [00:00<?, ?it/s] Evaluating: 0%| | 0/4 [00:00<?, ?it/s] computing thresholds: 0%| | 0/975 [00:00<?, ?it/s] name value distance precision recall binary_accuracy f1 ------- ------- ---------- ----------- -------- ----------------- -------- optimal 0.94 0.0741751 0.892 1 0.892 0.942918
可視化
メトリクス単独からではモデル品質の感覚を得るのは難しいかもしれません。補足的なアプローチは、マッチ品質の感覚を得るために問い合わせ結果のセットを主導で調査することです。
ここでは 10 の検証サンプルを取り、そしてそれらを 5 つの近傍と問い合わせサンプルへの距離とともにプロットします。結果を見ると、それらが不完全である一方で依然として意味がある類似の画像を表示していて、そしてモデルはそれらのポーズや画像の照明とは無関係に類似の画像を見つけられることがわかります。
モデルはある画像については非常に確信があり、問い合わせと近傍の間が非常に小さい距離であるという結果になっていることがわかります。逆に、距離が大きくなるにつれてクラスラベルにおいて間違いが多くなることがわかります。これがマッチング・アプリケーションに対してキャリブレーションが重要である理由の一つです。
num_neighbors = 5
labels = [
"Airplane",
"Automobile",
"Bird",
"Cat",
"Deer",
"Dog",
"Frog",
"Horse",
"Ship",
"Truck",
"Unknown",
]
class_mapping = {c_id: c_lbl for c_id, c_lbl in zip(range(11), labels)}
x_display, y_display = val_ds.get_slice(begin=200, size=10)
# lookup nearest neighbors in the index
nns = model.lookup(x_display, k=num_neighbors)
# display
for idx in np.argsort(y_display):
tfsim.visualization.viz_neigbors_imgs(
x_display[idx],
y_display[idx],
nns[idx],
class_mapping=class_mapping,
fig_size=(16, 2),
)
Performing NN search Building NN list: 0%| | 0/10 [00:00<?, ?it/s]
メトリクス
距離の閾値が増えたときのマッチング性能の感覚を得るために CalibrationResults に含まれる追加のメトリクスをプロットすることもできます。
以下のプロットは Precision, Recall と F1 スコアを示しています。距離が増加するにつれてマッチング精度は低下しますが、ポジティブマッチ (recal) として受け入れた問い合わせのパーセンテージはキャリブレートされた距離の閾値まで高速に増加することが分かります。
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
x = calibration.thresholds["distance"]
ax1.plot(x, calibration.thresholds["precision"], label="precision")
ax1.plot(x, calibration.thresholds["recall"], label="recall")
ax1.plot(x, calibration.thresholds["f1"], label="f1 score")
ax1.legend()
ax1.set_title("Metric evolution as distance increase")
ax1.set_xlabel("Distance")
ax1.set_ylim((-0.05, 1.05))
ax2.plot(calibration.thresholds["recall"], calibration.thresholds["precision"])
ax2.set_title("Precision recall curve")
ax2.set_xlabel("Recall")
ax2.set_ylabel("Precision")
ax2.set_ylim((-0.05, 1.05))
plt.show()
各クラスの 100 サンプルを取り、各サンプルと最近傍マッチの混同行列をプロットすることもできます。キャリブレートされた距離の閾値以上のマッチを表わすために「追加の (= extra)」10 番目のクラスを追加することもできます。
殆どの誤りは動物クラス間にありますが、飛行機と鳥の間にも興味深い数の混乱があることが分かります。更に、各クラスの 100 サンプルの僅かだけが、キャリブレートされた距離の閾値の外側でマッチを返したことも分かります。
cutpoint = "optimal"
# This yields 100 examples for each class.
# We defined this when we created the val_ds sampler.
x_confusion, y_confusion = val_ds.get_slice(0, -1)
matches = model.match(x_confusion, cutpoint=cutpoint, no_match_label=10)
tfsim.visualization.confusion_matrix(
matches,
y_confusion,
labels=labels,
title="Confusion matrix for cutpoint:%s" % cutpoint,
normalize=False,
)
ノーマッチ
どの画像がインデックスされたどのサンプルにもマッチしないかを見るために、キャリブレートされた閾値の外側のサンプルをプロットすることもできます。
これはどのような他のサンプルがインデックスされる必要があるかへの洞察を与えたり、あるいはクラス内の異常なサンプルを表面化させるかかもしれません。
クラスターの可視化
モデルの品質を素早く把握し、欠点を理解するための最良の方法の1つは、埋め込みを2D空間に投影することです。
これにより、画像のクラスターを検査し、どのクラスが絡み合っているかを理解することができます。
idx_no_match = np.where(np.array(matches) == 10)
no_match_queries = x_confusion[idx_no_match]
if len(no_match_queries):
plt.imshow(no_match_queries[0])
else:
print("All queries have a match below the distance threshold.")
All queries have a match below the distance threshold.
クラスタの可視化
モデルがどのように動作するかの品質の感覚を素早く得てその欠点を理解する最善の方法の一つは埋め込みを 2D 空間に射影することです。
これは画像のクラスターを調査してどのクラスがもつれているかを理解することを可能にします。
# Each class in val_ds was restricted to 100 examples.
num_examples_to_clusters = 1000
thumb_size = 96
plot_size = 800
vx, vy = val_ds.get_slice(0, num_examples_to_clusters)
# Uncomment to run the interactive projector.
# tfsim.visualization.projector(
# model.predict(vx),
# labels=vy,
# images=vx,
# class_mapping=class_mapping,
# image_size=thumb_size,
# plot_size=plot_size,
# )
以上