ホーム » Keras » TensorFlow 2.0 Beta : Tutorials : 画像 :- Keras で TensorFlow Hub

TensorFlow 2.0 Beta : Tutorials : 画像 :- Keras で TensorFlow Hub

TensorFlow 2.0 Beta : Beginner Tutorials : 画像 :- Keras で TensorFlow Hub (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 06/30/2019

* 本ページは、TensorFlow の本家サイトの TF 2.0 Beta – Beginner Tutorials – Images の以下のページを翻訳した上で
適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

 

画像 :- Keras で TensorFlow Hub

TensorFlow Hub は事前訓練されたモデル・コンポーネントを共有するための方法です。事前訓練モデルの検索可能なリストのためには TensorFlow Module Hubを見てください。このチュートリアルは以下を実演します :

  1. tf.keras で TensorFlow Hub をどのように使用するか。
  2. TensorFlow Hub を使用してどのように画像分類を行なうか。
  3. どのように単純な転移学習を行なうか。

 

セットアップ

from __future__ import absolute_import, division, print_function, unicode_literals

import matplotlib.pylab as plt

!pip install -q tensorflow-gpu==2.0.0-beta1
import tensorflow as tf
!pip install -q tensorflow_hub
import tensorflow_hub as hub

from tensorflow.keras import layers

 

ImageNet 分類器

分類器をダウンロードする

mobilenet をロードするために hub.module を、それを keras 層としてラップするために tf.keras.layers.Lambda を使用します。tfhub.dev からのどのような TensorFlow 2 互換画像分類器 URL もここで動作します。

classifier_url ="https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2" #@param {type:"string"}
IMAGE_SHAPE = (224, 224)

classifier = tf.keras.Sequential([
    hub.KerasLayer(classifier_url, input_shape=IMAGE_SHAPE+(3,))
])

 

単一画像上でそれを実行する

モデルを試すために単一のイメージをダウンロードします。

import numpy as np
import PIL.Image as Image

grace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')
grace_hopper = Image.open(grace_hopper).resize(IMAGE_SHAPE)
grace_hopper
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg
65536/61306 [================================] - 0s 0us/step

grace_hopper = np.array(grace_hopper)/255.0
grace_hopper.shape
(224, 224, 3)

バッチ次元を追加して、画像をモデルに渡します。

result = classifier.predict(grace_hopper[np.newaxis, ...])
result.shape
(1, 1001)

結果はロジットの 1001 要素ベクトルで、画像のための各クラスの確率を見積もります。

そして top クラス ID は argmax で見つけられます :

predicted_class = np.argmax(result[0], axis=-1)
predicted_class
653

 

予測をデコードする

予測されたクラス ID を持ち、ImageNet ラベルを取得して、そして予測をデコードします。

labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
imagenet_labels = np.array(open(labels_path).read().splitlines())
plt.imshow(grace_hopper)
plt.axis('off')
predicted_class_name = imagenet_labels[predicted_class]
_ = plt.title("Prediction: " + predicted_class_name.title())

 

単純な転移学習

TF Hub を使用するとき私達のデータセットでクラスを認識するためにモデルのトップ層を再訓練することは単純です。

 

データセット

このサンプルのために TensorFlow flowers データセットを使用します :

data_root = tf.keras.utils.get_file(
  'flower_photos','https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
   untar=True)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 4s 0us/step

このデータをモデルにロードする最も単純な方法は tf.keras.preprocessing.image.ImageDataGenerator を使用することです。

TensorFlow Hub の画像モジュールの総ては [0, 1] 範囲のfloat 入力を想定しています。これを獲得するために ImageDataGenerator の rescale パラメータを使用します。

画像サイズは後で処理されます。

image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255)
image_data = image_generator.flow_from_directory(str(data_root), target_size=IMAGE_SHAPE)
Found 3670 images belonging to 5 classes.

結果としてのオブジェクトは image_batch, label_batch ペアを返す iterator です。

for image_batch, label_batch in image_data:
  print("Image batch shape: ", image_batch.shape)
  print("Label batch shape: ", label_batch.shape)
  break
Image batch shape:  (32, 224, 224, 3)
Label batch shape:  (32, 5)

 

画像のバッチ上で分類器を実行する

さて画像バッチ上で分類器を実行します。

result_batch = classifier.predict(image_batch)
result_batch.shape
(32, 1001)
predicted_class_names = imagenet_labels[np.argmax(result_batch, axis=-1)]
predicted_class_names
array(['daisy', 'coral fungus', 'dining table', 'feather boa',
       'park bench', 'daisy', 'bakery', 'daisy', 'feather boa', 'pot',
       'ice cream', 'picket fence', 'daisy', 'vase', 'confectionery',
       'picket fence', 'daisy', 'daisy', 'picket fence', 'sunglasses',
       'porcupine', 'picket fence', "yellow lady's slipper", 'daisy',
       'mushroom', 'Lhasa', 'picket fence', 'daisy', 'daisy', 'rapeseed',
       'bee', 'picket fence'], dtype='<U30')

今はこれらの予測が画像とともにどのように並ぶかを確認します :

plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(predicted_class_names[n])
  plt.axis('off')
_ = plt.suptitle("ImageNet predictions")

画像属性については LICENSE.txt 参照。

結果はパーフェクトからは程遠いですが、これらはモデルがそのために訓練されたクラスではないというのが合理的な考えです (“daisy” を除いて)。

 

ヘッドレス・モデルをダウンロードする

TensorFlow Hub はまた top 分類層なしのモデルも配布しています。これらは容易に転移学習を行なうために使用できます。

tfhub.dev からのどのようなどのような TensorFlow 2 互換画像分類器 URL もここで動作します。

feature_extractor_url = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/2" #@param {type:"string"}

特徴抽出器を作成します。

feature_extractor_layer = hub.KerasLayer(feature_extractor_url,
                                         input_shape=(224,224,3))

それは各画像の 1280-長ベクトルを返します :

feature_batch = feature_extractor_layer(image_batch)
print(feature_batch.shape)
(32, 1280)

特徴抽出器層の変数を凍結します、その結果訓練は新しい分類層だけを変更します。

feature_extractor_layer.trainable = False

 

分類ヘッドを装着する

さて hub 層を tf.keras.Sequential モデルでラップし、新しい分類層を追加します。

model = tf.keras.Sequential([
  feature_extractor_layer,
  layers.Dense(image_data.num_classes, activation='softmax')
])

model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer_1 (KerasLayer)   (None, 1280)              2257984   
_________________________________________________________________
dense (Dense)                (None, 5)                 6405      
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________
predictions = model(image_batch)
predictions.shape
TensorShape([32, 5])

 

モデルを訓練する

訓練プロセスを configure するために compile を使用します :

model.compile(
  optimizer=tf.keras.optimizers.Adam(),
  loss='categorical_crossentropy',
  metrics=['acc'])

そしてモデルを訓練するために .fit メソッドを使用します。

このサンプルを短く保持するために 2 エポックだけ訓練します。訓練進捗を可視化するために、エポック平均の代わりに各バッチの個々の損失と精度を記録するためにカスタム callback を使用します。

class CollectBatchStats(tf.keras.callbacks.Callback):
  def __init__(self):
    self.batch_losses = []
    self.batch_acc = []

  def on_train_batch_end(self, batch, logs=None):
    self.batch_losses.append(logs['loss'])
    self.batch_acc.append(logs['acc'])
    self.model.reset_metrics()
steps_per_epoch = np.ceil(image_data.samples/image_data.batch_size)

batch_stats_callback = CollectBatchStats()

history = model.fit(image_data, epochs=2,
                    steps_per_epoch=steps_per_epoch,
                    callbacks = [batch_stats_callback])
Epoch 1/2

WARNING: Logging before flag parsing goes to stderr.
W0628 03:59:21.830983 139712134772480 deprecation.py:323] From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

115/115 [==============================] - 23s 196ms/step - loss: 0.7078 - acc: 0.8438
Epoch 2/2
115/115 [==============================] - 22s 194ms/step - loss: 0.3480 - acc: 0.7812

さて 2, 3 の訓練反復後でさえ、モデルがタスクにおいて進捗していることを既に見ることができます。

plt.figure()
plt.ylabel("Loss")
plt.xlabel("Training Steps")
plt.ylim([0,2])
plt.plot(batch_stats_callback.batch_losses)
[<matplotlib.lines.Line2D at 0x7f0be83547f0>]

plt.figure()
plt.ylabel("Accuracy")
plt.xlabel("Training Steps")
plt.ylim([0,1])
plt.plot(batch_stats_callback.batch_acc)
[<matplotlib.lines.Line2D at 0x7f0be17fe128>]

 

予測をチェックする

前からのプロットを再び行なうために、最初にクラス名の順序付けられたリストを得ます :

class_names = sorted(image_data.class_indices.items(), key=lambda pair:pair[1])
class_names = np.array([key.title() for key, value in class_names])
class_names
array(['Daisy', 'Dandelion', 'Roses', 'Sunflowers', 'Tulips'],
      dtype='<U10')

モデルを通して画像バッチを実行してインデックスをクラス名に変換します。

predicted_batch = model.predict(image_batch)
predicted_id = np.argmax(predicted_batch, axis=-1)
predicted_label_batch = class_names[predicted_id]

結果をプロットします。

label_id = np.argmax(label_batch, axis=-1)
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  color = "green" if predicted_id[n] == label_id[n] else "red"
  plt.title(predicted_label_batch[n].title(), color=color)
  plt.axis('off')
_ = plt.suptitle("Model predictions (green: correct, red: incorrect)")

 

貴方のモデルをエクスポートする

モデルを訓練した今、それを saved model としてエクスポートします :

import time
t = time.time()

export_path = "/tmp/saved_models/{}".format(int(t))
tf.keras.experimental.export_saved_model(model, export_path)

export_path
W0628 04:00:17.005649 139712134772480 deprecation.py:323] From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:253: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
W0628 04:00:17.007575 139712134772480 export_utils.py:182] Export includes no default signature!
W0628 04:00:17.746208 139712134772480 meta_graph.py:450] Issue encountered when serializing variables.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'list' object has no attribute 'name'
W0628 04:00:24.953796 139712134772480 export_utils.py:182] Export includes no default signature!
W0628 04:00:25.634641 139712134772480 meta_graph.py:450] Issue encountered when serializing variables.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'list' object has no attribute 'name'
Exception ignored in: <bound method _CheckpointRestoreCoordinator.__del__ of <tensorflow.python.training.tracking.util._CheckpointRestoreCoordinator object at 0x7f0b50286eb8>>
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow/python/training/tracking/util.py", line 244, in __del__
    .format(pretty_printer.node_names[node_id]))
  File "/tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow/python/training/tracking/util.py", line 93, in node_names
    path_to_root[node_id] + (child.local_name,))
  File "/tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow/python/training/tracking/object_identity.py", line 76, in __getitem__
    return self._storage[self._wrap_key(key)]
KeyError: (<tensorflow.python.training.tracking.object_identity._ObjectIdentityWrapper object at 0x7f0b43e18e48>,)
W0628 04:00:34.214757 139712134772480 meta_graph.py:450] Issue encountered when serializing variables.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'list' object has no attribute 'name'

'/tmp/saved_models/1561694408'

今はそれを再ロードできて、それが依然として同じ結果を与えることを確認しましょう :

reloaded = tf.keras.experimental.load_from_saved_model(export_path, custom_objects={'KerasLayer':hub.KerasLayer})
Exception ignored in: <bound method _CheckpointRestoreCoordinator.__del__ of <tensorflow.python.training.tracking.util._CheckpointRestoreCoordinator object at 0x7f0b40ed2a20>>
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow/python/training/tracking/util.py", line 244, in __del__
    .format(pretty_printer.node_names[node_id]))
  File "/tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow/python/training/tracking/util.py", line 93, in node_names
    path_to_root[node_id] + (child.local_name,))
  File "/tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow/python/training/tracking/object_identity.py", line 76, in __getitem__
    return self._storage[self._wrap_key(key)]
KeyError: (<tensorflow.python.training.tracking.object_identity._ObjectIdentityWrapper object at 0x7f0b3b9485f8>,)
result_batch = model.predict(image_batch)
reloaded_result_batch = reloaded.predict(image_batch)
abs(reloaded_result_batch - result_batch).max()

0.0

savd model は後で推論のためにロードするか、TFLiteTFjs のために変換できます。

 

以上



AI導入支援 #2 ウェビナー

スモールスタートを可能としたAI導入支援   Vol.2
[無料 WEB セミナー] [詳細]
「画像認識 AI PoC スターターパック」の紹介
既に AI 技術を実ビジネスで活用し、成果を上げている日本企業も多く存在しており、競争優位なビジネスを展開しております。
しかしながら AI を導入したくとも PoC (概念実証) だけでも高額な費用がかかり取組めていない企業も少なくないようです。A I導入時には欠かせない PoC を手軽にしかも短期間で認知度を確認可能とするサービの紹介と共に、AI 技術の特性と具体的な導入プロセスに加え運用時のポイントについても解説いたします。
日時:2021年10月13日(水)
会場:WEBセミナー
共催:クラスキャット、日本FLOW(株)
後援:働き方改革推進コンソーシアム
参加費: 無料 (事前登録制)
人工知能開発支援
◆ クラスキャットは 人工知能研究開発支援 サービスを提供しています :
  • テクニカルコンサルティングサービス
  • 実証実験 (プロトタイプ構築)
  • アプリケーションへの実装
  • 人工知能研修サービス
◆ お問合せ先 ◆
(株)クラスキャット
セールス・インフォメーション
E-Mail:sales-info@classcat.com