Keras 2 : examples : Near-duplicate 画像検索 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 12/02/2021 (keras 2.7.0)
* 本ページは、Keras の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
- Code examples : Computer Vision : Near-duplicate image search (Author: Sayak Paul)
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- テレワーク & オンライン授業を支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- E-Mail:sales-info@classcat.com ; WebSite: www.classcat.com ; Facebook
Keras 2 : examples : Near-duplicate 画像検索
Description: 深層学習と局所性鋭敏型ハッシュ (= locality-sensitive hashing) を使用した Near-duplicate 画像検索ユティリティを構築します。
イントロダクション
類似画像の (ほぼ) リアルタイムでの取得は情報検索システムの重要なユースケースです。それを活用した幾つかのポピュラーな製品は Pinterest, Google Image Search 等を含みます。このサンプルでは、局所性鋭敏型ハッシュ (LSH, Locality Sensitive Hashing) と、事前訓練済みの画像分類器により計算された画像表現の上の ランダム投影 を使用して、類似画像検索ユティリティを構築します。この種類の検索エンジンは near-duplicate (or near-dup) 画像検出器としても知られています。TensorRT を使用して GPU 上での検索ユティリティの推論性能を最適化することも調べます。
これに関係して確認するに値する keras.io/examples/vision 下に別のサンプルがあります :
最後に、このサンプルは次のリソースをリファレンスとして使用し、そのためコードの一部を再利用しています : Locality Sensitive Hashing for Similar Item Search.
インポート
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import time
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
データセットをロードして 1,000 画像の訓練セットを作成する
このサンプルの実行時間を短く保つために、語彙を構築するために (TensorFlow Datasets 経由で利用可能な) tf_flowers データセットからの 1,000 画像のサブセットを使用していきます。
train_ds, validation_ds = tfds.load(
"tf_flowers", split=["train[:85%]", "train[85%:]"], as_supervised=True
)
IMAGE_SIZE = 224
NUM_IMAGES = 1000
images = []
labels = []
for (image, label) in train_ds.take(NUM_IMAGES):
image = tf.image.resize(image, (IMAGE_SIZE, IMAGE_SIZE))
images.append(image.numpy())
labels.append(label.numpy())
images = np.array(images)
labels = np.array(labels)
事前訓練済みモデルのロード
このセクションでは、tf_flowers データセットで訓練された画像分類モデルをロードします。訓練セットを構築するためにトータル画像の 85% をが使用されました。訓練の詳細は、このノートブック を参照してください。
基礎となるモデルは (Big Transfer (BiT): General Visual Representation Learning で提案された) BiT-ResNet です。BiT-ResNet ファミリーのモデルは広範囲の様々な下流タスクに渡り優秀な転移性能を提供するとして知られています。
!wget -q https://git.io/JuMq0 -O flower_model_bit_0.96875.zip
!unzip -qq flower_model_bit_0.96875.zip
bit_model = tf.keras.models.load_model("flower_model_bit_0.96875")
bit_model.count_params()
23510597
埋め込みモデルの作成
問合せ (= query) 画像が与えられたとき類似の画像を取得するためには、最初に関係する総ての画像のベクトル表現を生成する必要があります。これを埋め込みモデル経由で行ないます、これは事前訓練済みの分類器から出力特徴を抽出して結果としての特徴ベクトルを正規化します。
embedding_model = tf.keras.Sequential(
[
tf.keras.layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3)),
tf.keras.layers.Rescaling(scale=1.0 / 255),
bit_model.layers[1],
tf.keras.layers.Normalization(mean=0, variance=1),
],
name="embedding_model",
)
embedding_model.summary()
Model: "embedding_model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= rescaling (Rescaling) (None, 224, 224, 3) 0 _________________________________________________________________ keras_layer (KerasLayer) (None, 2048) 23500352 _________________________________________________________________ normalization (Normalization (None, 2048) 0 ================================================================= Total params: 23,500,352 Trainable params: 23,500,352 Non-trainable params: 0 _________________________________________________________________
モデル内の正規化層に気付いてください。それは表現ベクトルを単位球の空間に投影するために使用されます。
ハッシュ化ユティリティ
def hash_func(embedding, random_vectors):
embedding = np.array(embedding)
# Random projection.
bools = np.dot(embedding, random_vectors) > 0
return [bool2int(bool_vec) for bool_vec in bools]
def bool2int(x):
y = 0
for i, j in enumerate(x):
if j:
y += 1 << i
return y
embedding_model から出力されるベクトルの shape は (2048,) で、実用面 (ストレージ, 検索性能, etc.) を考慮するとそれは非常に大きいです。そのため、情報内容を減らすことなく埋め込みベクトルの次元性を削減する必要が発生します。これがランダム投影が登場する場所です。それは与えられた平面上の点群間の距離が近似的に保全される場合、その平面の次元性は更に削減できるという原理に基づいています。
hash_func() 内では、最初に埋め込みベクトルの次元を削減します。次にハッシュバケットを決めるために画像の bitwise なハッシュ値を計算します。同じハッシュ値を持つ画像は同じハッシュバケットに入る可能性が高いです。配備の観点からは、bitwise なハッシュ値はストアして演算するのに安価です。
Query ユティリティ
Table クラスは単一のハッシュテーブルを作成する役割を果たしています。ハッシュテーブルの各エントリはデータセットからの画像の削減された埋め込みと一意の識別子の間のマップです。次元削減テクニックはランダム性を伴いますので、処理が実行されるたびに類似の画像が同じハッシュバケットにマップされないことが起こりえます。この結果を減らすために、複数のテーブルからの結果を考慮します -- テーブルの数と削減次元はここではキーとなるハイパーパラメータです。
重要なことは、現実世界のアプリケーションに取り組むときは局所性鋭敏型ハッシュを貴方自身で再実装はしないでしょう。代わりに、以下のポピュラーなライブラリの一つを使用する可能性が高いです :
class Table:
def __init__(self, hash_size, dim):
self.table = {}
self.hash_size = hash_size
self.random_vectors = np.random.randn(hash_size, dim).T
def add(self, id, vectors, label):
# Create a unique indentifier.
entry = {"id_label": str(id) + "_" + str(label)}
# Compute the hash values.
hashes = hash_func(vectors, self.random_vectors)
# Add the hash values to the current table.
for h in hashes:
if h in self.table:
self.table[h].append(entry)
else:
self.table[h] = [entry]
def query(self, vectors):
# Compute hash value for the query vector.
hashes = hash_func(vectors, self.random_vectors)
results = []
# Loop over the query hashes and determine if they exist in
# the current table.
for h in hashes:
if h in self.table:
results.extend(self.table[h])
return results
次の LSH クラスではユティリティが複数のハッシュテーブルを持つようにパックします。
class LSH:
def __init__(self, hash_size, dim, num_tables):
self.num_tables = num_tables
self.tables = []
for i in range(self.num_tables):
self.tables.append(Table(hash_size, dim))
def add(self, id, vectors, label):
for table in self.tables:
table.add(id, vectors, label)
def query(self, vectors):
results = []
for table in self.tables:
results.extend(table.query(vectors))
return results
これでクラス内でマスター LSH テーブル (多くのテーブルのコレクション) を構築しておすさするためのロジックをカプセル化できます。それは 2 つのメソッドを持ちます :
- train(): 最終的な LSH テーブルの構築を担当します。
- query(): query 画像が与えられたとき一致の数を計算して類似スコアを定量化します。
class BuildLSHTable:
def __init__(
self,
prediction_model,
concrete_function=False,
hash_size=8,
dim=2048,
num_tables=10,
):
self.hash_size = hash_size
self.dim = dim
self.num_tables = num_tables
self.lsh = LSH(self.hash_size, self.dim, self.num_tables)
self.prediction_model = prediction_model
self.concrete_function = concrete_function
def train(self, training_files):
for id, training_file in enumerate(training_files):
# Unpack the data.
image, label = training_file
if len(image.shape) < 4:
image = image[None, ...]
# Compute embeddings and update the LSH tables.
# More on `self.concrete_function()` later.
if self.concrete_function:
features = self.prediction_model(tf.constant(image))[
"normalization"
].numpy()
else:
features = self.prediction_model.predict(image)
self.lsh.add(id, features, label)
def query(self, image, verbose=True):
# Compute the embeddings of the query image and fetch the results.
if len(image.shape) < 4:
image = image[None, ...]
if self.concrete_function:
features = self.prediction_model(tf.constant(image))[
"normalization"
].numpy()
else:
features = self.prediction_model.predict(image)
results = self.lsh.query(features)
if verbose:
print("Matches:", len(results))
# Calculate Jaccard index to quantify the similarity.
counts = {}
for r in results:
if r["id_label"] in counts:
counts[r["id_label"]] += 1
else:
counts[r["id_label"]] = 1
for k in counts:
counts[k] = float(counts[k]) / self.dim
return counts
LSH テーブルの作成
実装されたヘルパーユティリティとクラスで、次に LSH テーブルを構築できます。最適化された埋め込みモデルと最適化されていないものの間でパフォーマンスをベンチマークしていくので、公正でない比較を避けるために GPU のウォームアップも行ないます。
# Utility to warm up the GPU.
def warmup():
dummy_sample = tf.ones((1, IMAGE_SIZE, IMAGE_SIZE, 3))
for _ in range(100):
_ = embedding_model.predict(dummy_sample)
そして最初に GPU ウォームアップを行なって embedding_model でマスター LSH テーブルの構築に進むことができます。
warmup()
training_files = zip(images, labels)
lsh_builder = BuildLSHTable(embedding_model)
lsh_builder.train(training_files)
執筆時には、wall 時間は Tesla T4 GPU 上で 54.1 秒でした。この計時は使用している GPU に基づいて様々であるかもしれません。
TensorRT によるモデルの最適化
NVIDIA ベースの GPU については、 刈り取り (= pruning), 定数畳み込み (= constant folding), 層融合 (= layer fusion), 等のような様々な最適化テクニックを使用することで TensorRT フレームワーク は推論の遅延時間を劇的に良くするために使用できます。ここでは埋め込みモデルを最適化するために tf.experimental.tensorrt モジュールを使用します。
# First serialize the embedding model as a SavedModel.
embedding_model.save("embedding_model")
# Initialize the conversion parameters.
params = tf.experimental.tensorrt.ConversionParams(
precision_mode="FP16", maximum_cached_engines=16
)
# Run the conversion.
converter = tf.experimental.tensorrt.Converter(
input_saved_model_dir="embedding_model", conversion_params=params
)
converter.convert()
converter.save("tensorrt_embedding_model")
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model. WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model. INFO:tensorflow:Assets written to: embedding_model/assets INFO:tensorflow:Assets written to: embedding_model/assets INFO:tensorflow:Linked TensorRT version: (0, 0, 0) INFO:tensorflow:Linked TensorRT version: (0, 0, 0) INFO:tensorflow:Loaded TensorRT version: (0, 0, 0) INFO:tensorflow:Loaded TensorRT version: (0, 0, 0) INFO:tensorflow:Assets written to: tensorrt_embedding_model/assets INFO:tensorflow:Assets written to: tensorrt_embedding_model/assets
tf.experimental.tensorrt.ConversionParams() の内部のパラメータについての注意 :
- precision_mode は変換されるモデルでの演算の数値精度を定義します。
- maximum_cached_engines は TRT エンジンの最大数を指定します、これは動的演算 (未知の shape による演算) を処理するためにキャッシュされます。
他のオプションについて更に学習するには、公式ドキュメント を参照してください。tf.experimental.tensorrt モジュールにより提供される様々な量子化オプションも調べることもできます。
# Load the converted model.
root = tf.saved_model.load("tensorrt_embedding_model")
trt_model_function = root.signatures["serving_default"]
最適化されたモデルで LSH テーブルを構築する
warmup()
training_files = zip(images, labels)
lsh_builder_trt = BuildLSHTable(trt_model_function, concrete_function=True)
lsh_builder_trt.train(training_files)
13.1 秒である wall 時間の差に注目してください。前は、非最適化モデルで 54.1 秒でした。
ハッシュ・テーブルの一つを詳しく調べてそれらがどのように表現されているかの考えを得ることができます。
idx = 0
for hash, entry in lsh_builder_trt.lsh.tables[0].table.items():
if idx == 5:
break
if len(entry) < 5:
print(hash, entry)
idx += 1
145 [{'id_label': '3_4'}, {'id_label': '727_3'}] 5 [{'id_label': '12_4'}] 128 [{'id_label': '30_2'}, {'id_label': '480_2'}] 208 [{'id_label': '34_2'}, {'id_label': '132_2'}, {'id_label': '984_2'}] 188 [{'id_label': '42_0'}, {'id_label': '135_3'}, {'id_label': '436_3'}, {'id_label': '670_3'}]
検証画像で結果を可視化
このセクションでは最初に類似画像の解析プロセスを可視化するユティリティ関数の幾つかを書きます。そして最適化ありとなしでモデルの query 性能をベンチマークします。
最初に、テスト目的で検証セットから 100 画像を取ります。
validation_images = []
validation_labels = []
for image, label in validation_ds.take(100):
image = tf.image.resize(image, (224, 224))
validation_images.append(image.numpy())
validation_labels.append(label.numpy())
validation_images = np.array(validation_images)
validation_labels = np.array(validation_labels)
validation_images.shape, validation_labels.shape
((100, 224, 224, 3), (100,))
次に可視化ユティリティを書きます。
def plot_images(images, labels):
plt.figure(figsize=(20, 10))
columns = 5
for (i, image) in enumerate(images):
ax = plt.subplot(len(images) / columns + 1, columns, i + 1)
if i == 0:
ax.set_title("Query Image\n" + "Label: {}".format(labels[i]))
else:
ax.set_title("Similar Image # " + str(i) + "\nLabel: {}".format(labels[i]))
plt.imshow(image.astype("int"))
plt.axis("off")
def visualize_lsh(lsh_class):
idx = np.random.choice(len(validation_images))
image = validation_images[idx]
label = validation_labels[idx]
results = lsh_class.query(image)
candidates = []
labels = []
overlaps = []
for idx, r in enumerate(sorted(results, key=results.get, reverse=True)):
if idx == 4:
break
image_id, label = r.split("_")[0], r.split("_")[1]
candidates.append(images[int(image_id)])
labels.append(label)
overlaps.append(results[r])
candidates.insert(0, image)
labels.insert(0, label)
plot_images(candidates, labels)
Non-TRT モデル
for _ in range(5):
visualize_lsh(lsh_builder)
visualize_lsh(lsh_builder)
Matches: 507 Matches: 554 Matches: 438 Matches: 370 Matches: 407 Matches: 306
TRT モデル
for _ in range(5):
visualize_lsh(lsh_builder_trt)
Matches: 458 Matches: 181 Matches: 280 Matches: 280 Matches: 503
気付いたかもしれませんが、幾つかの誤った結果があります。これは幾つかの方法で軽減できます :
- 特にノイズのあるサンプルに対して初期埋め込みを生成するためのより良いモデル。ArcFace, 教師あり対照学習, 等のようなテクニックを使用できます、これらは検索目的ための表現のより良い学習を暗黙的に促進します。
- テーブルの数と削減次元の間のトレードオフは重要でアプリケーションに必要な正しい recall を設定するのに役立ちます。
Query 性能のベンチマーク
def benchmark(lsh_class):
warmup()
start_time = time.time()
for _ in range(1000):
image = np.ones((1, 224, 224, 3)).astype("float32")
_ = lsh_class.query(image, verbose=False)
end_time = time.time() - start_time
print(f"Time taken: {end_time:.3f}")
benchmark(lsh_builder)
benchmark(lsh_builder_trt)
Time taken: 54.359 Time taken: 13.963
2 つのモデルの query 性能の間の大きな違いに直ちに気付くことができます。
Final remarks
このサンプルでは、モデルを最適化するために NVIDIA の TensorRT フレームワークを調べました。GPU ベースの推論サーバのためには最善に適しています。異なるハードウェア・プラットフォームに提供するフレームワークの他の選択肢があります :
- モバイルとエッジデバイスのための TensorFlow Lite。
- コモディティな CPU ベースのサーバのための ONNX。
- Apache TVM、様々なプラットフォームをカバーする機械学習モデルのためのコンパイラ。
一般のベクトル類似性検索に基づいたアプリケーションについて更に学習するため確認したいかもしれない幾つかのリソースがここにあります :
- ANN ベンチマーク
- 異方性ベクトル量子化 (ScaNN, Anisotropic Vector Quantization) による大規模スケールの推論の高速化
- 類似性検索のためのベクトルの拡散 (= Spreading)
- リアルタイム埋め込み類似性マッチングシステムの構築
以上