TensorFlow 2.0 : 上級 Tutorials : 画像 :- 事前訓練された ConvNet で転移学習 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 11/07/2019
* 本ページは、TensorFlow org サイトの TF 2.0 – Advanced Tutorials – Images の以下のページを翻訳した上で
適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
Facebook: https://www.facebook.com/ClassCatJP/ |
画像 :- 事前訓練された ConvNet で転移学習
このチュートリアルでは事前訓練されたネットワークからの転移学習を使用して猫 vs 犬画像をどのように分類するかを学習します。
事前訓練されたモデル は、典型的には巨大スケールの画像分類タスク上、巨大なデータセット上で以前に訓練された (セーブされた) ネットワークです。事前訓練されたモデルをそのまま使用するか与えられたタスクにこのモデルをカスタマイズするために 転移学習 を使用します。
転移学習の背後にある直感は、このモデルが十分に巨大で一般的なデータセット上で訓練された場合、このモデルは視覚世界の一般的なモデルとして効果的に役立つであろうということです。それから巨大なデータセット上で巨大なモデルをスクラッチから訓練し始めることなくこれらの学習された特徴マップを活用できます。
このノートブックでは、事前訓練されたモデルをカスタマイズする 2 つの方法を試します :
- 特徴抽出 (= Feature Extraction) – 新しいサンプルから意味がある特徴を抽出するために以前のネットワークにより学習された表現を使用します。事前訓練されたモデルの上に (スクラッチから訓練される) 新しい分類器を単に追加します、その結果前に学習された特徴マップを私達のデータセットのために再目的化できます。
モデル全体を (再) 訓練する必要はありません。ベース畳み込みネットワークは既に写真を分類するために一般的に有用な特徴を既に含んでいます。けれども、事前訓練されたモデルの最後の分類パートは元の分類タスクに特有で、結果的に (その上で) モデルが訓練されたクラスのセットに特有です。
- 再調整 (= Fine-Tuning) – 凍結されたモデルベースの 2, 3 のトップ層を解凍して、新たに追加された分類層とベースモデルの最後の層群の両者を一緒に訓練します。これはベースモデルの高次特徴表現を、特定のタスクのためにより関連付けるために「再調整」することを可能にします。
一般的な機械学習ワークフローに従います。
- データを調べて理解する。
- 入力パイプラインを構築します、このケースでは Keras ImageDataGenerator を使用します。
- モデルを構成する。
- 事前訓練されたベースモデル (と事前訓練された重み) をロードする
- トップに分類層をスタックする。
- モデルを訓練する。
- モデルを評価する。
from __future__ import absolute_import, division, print_function, unicode_literals import os import numpy as np import matplotlib.pyplot as plt
import tensorflow as tf keras = tf.keras
データ前処理
データ・ダウンロード
猫と犬のデータセットをロードするために TensorFlow Dataset を利用します。
この tfds パッケージは事前定義されたデータをロードする最も容易な方法です。もし貴方自身のデータを持ち、インポートしてそれを TensorFlow で使用することに興味があれば loading image data を見てください。
import tensorflow_datasets as tfds tfds.disable_progress_bar()
tfds.load メソッドはデータをダウンロードしてキャッシュし、tf.data.Dataset オブジェクトを返します。これらのオブジェクトはデータを操作してそれをモデルにパイプするためのパワフルで、効率的なメソッドを提供します。
“cats_vs_dog” は標準的な分割を定義していないので、それをデータの 80%, 10%, 10% で (train, validation, test) にそれぞれ分割するために subsplit 機能を使用します。
SPLIT_WEIGHTS = (8, 1, 1) splits = tfds.Split.TRAIN.subsplit(weighted=SPLIT_WEIGHTS) (raw_train, raw_validation, raw_test), metadata = tfds.load( 'cats_vs_dogs', split=list(splits), with_info=True, as_supervised=True)
Downloading and preparing dataset cats_vs_dogs (786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/2.0.1... /home/kbuilder/.local/lib/python3.5/site-packages/urllib3/connectionpool.py:1004: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings InsecureRequestWarning, WARNING:absl:1738 images were corrupted and were skipped WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.5/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version. Instructions for updating: Use eager execution and: `tf.data.TFRecordDataset(path)` WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.5/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version. Instructions for updating: Use eager execution and: `tf.data.TFRecordDataset(path)` Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/2.0.1. Subsequent calls will reuse this data.
結果としての tf.data.Dataset オブジェクトは (image, label) ペアを含みます。そこでは画像は可変な shape と 3 チャネルを持ち、そしてラベルはスカラーです。
print(raw_train) print(raw_validation) print(raw_test)
<_OptionsDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)> <_OptionsDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)> <_OptionsDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>
訓練セットから最初の 2 つの画像とラベルを表示します。
get_label_name = metadata.features['label'].int2str for image, label in raw_train.take(2): plt.figure() plt.imshow(image) plt.title(get_label_name(label))
データをフォーマットする
タスクのために画像をフォーマットするために tf.image モジュールを使用します。
画像を固定入力サイズにリサイズして、入力チャネルを [-1, 1] の範囲にリスケールします。
IMG_SIZE = 160 # All images will be resized to 160x160 def format_example(image, label): image = tf.cast(image, tf.float32) image = (image/127.5) - 1 image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE)) return image, label
map メソッドを使用してこの関数をデータセットの各アイテムに適用します :
train = raw_train.map(format_example) validation = raw_validation.map(format_example) test = raw_test.map(format_example)
今はデータをシャッフルしてバッチ化します。
BATCH_SIZE = 32 SHUFFLE_BUFFER_SIZE = 1000
train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE) validation_batches = validation.batch(BATCH_SIZE) test_batches = test.batch(BATCH_SIZE)
データのバッチを調べます :
for image_batch, label_batch in train_batches.take(1): pass image_batch.shape
TensorShape([32, 160, 160, 3])
事前訓練された convnet からベースモデルを作成する
Google で開発された MobileNet V2 モデルからベースモデルを作成します。これは ImageNet データセット、web 画像の 1.4 M 画像と 1000 クラスの巨大データセット上で事前訓練されています。ImageNet はパンノキ (= jackfruit) と注射器のようなカテゴリを持つ非常に恣意的な研究訓練データセットを持ちますが、この知識の土台は特定のデータセットから猫と犬を識別するのに役立ちます。
最初に、特徴抽出のために使用する MobileNet V2 の層を選択する必要があります。明らかに、最も最後の分類層 (「トップ」上、何故ならば機械学習モデルの殆どの図はボトムからトップに進みます) は全く役立ちません。代わりに、flatten 演算の前の最も最後の層に依拠する一般的な実践に従います。この層は「ボトルネック層」と呼ばれます。ボトルネック特徴は final/top 層に比較して遥かに汎用性を保持します。
最初に、ImageNet 上で訓練された重みとともに事前ロードされた MobileNet V2 モデルをインスタンス化します。include_top=False 引数を指定することにより、トップに分類層を含まないネットワークをロードします、これは特徴抽出のために理想的です。
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3) # Create the base model from the pre-trained model MobileNet V2 base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
Downloading data from https://github.com/JonathanCMitchell/mobilenet_v2_keras/releases/download/v1.1/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5 9412608/9406464 [==============================] - 1s 0us/step
この特徴抽出器は各 160x160x3 画像を 5x5x1280 特徴ブロックに変換します。画像のサンプルバッチにそれが何をするかを見ます :
feature_batch = base_model(image_batch) print(feature_batch.shape)
(32, 5, 5, 1280)
特徴抽出
前のステップで作成された畳み込みベースを凍結してそれを特徴抽出器として使用し、その上に分類器を追加して top-level 分類器を訓練します。
畳み込みベースを凍結する
compile してモデルを訓練する前に畳み込みベースを凍結することは重要です。凍結する (あるいは layer.trainable = False を設定する) ことにより、与えられた層の重みが訓練の間に更新されることを回避します。MobileNet V2 は多くの層を持ちますが、全体のモデルの trainable フラグを False に設定すれば総ての層を凍結します。
base_model.trainable = False
# Let's take a look at the base model architecture base_model.summary()
Model: "mobilenetv2_1.00_160" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) [(None, 160, 160, 3) 0 __________________________________________________________________________________________________ Conv1_pad (ZeroPadding2D) (None, 161, 161, 3) 0 input_1[0][0] __________________________________________________________________________________________________ Conv1 (Conv2D) (None, 80, 80, 32) 864 Conv1_pad[0][0] __________________________________________________________________________________________________ bn_Conv1 (BatchNormalization) (None, 80, 80, 32) 128 Conv1[0][0] __________________________________________________________________________________________________ Conv1_relu (ReLU) (None, 80, 80, 32) 0 bn_Conv1[0][0] __________________________________________________________________________________________________ expanded_conv_depthwise (Depthw (None, 80, 80, 32) 288 Conv1_relu[0][0] __________________________________________________________________________________________________ expanded_conv_depthwise_BN (Bat (None, 80, 80, 32) 128 expanded_conv_depthwise[0][0] __________________________________________________________________________________________________ expanded_conv_depthwise_relu (R (None, 80, 80, 32) 0 expanded_conv_depthwise_BN[0][0] __________________________________________________________________________________________________ expanded_conv_project (Conv2D) (None, 80, 80, 16) 512 expanded_conv_depthwise_relu[0][0 __________________________________________________________________________________________________ expanded_conv_project_BN (Batch (None, 80, 80, 16) 64 expanded_conv_project[0][0] __________________________________________________________________________________________________ block_1_expand (Conv2D) (None, 80, 80, 96) 1536 expanded_conv_project_BN[0][0] __________________________________________________________________________________________________ block_1_expand_BN (BatchNormali (None, 80, 80, 96) 384 block_1_expand[0][0] __________________________________________________________________________________________________ block_1_expand_relu (ReLU) (None, 80, 80, 96) 0 block_1_expand_BN[0][0] __________________________________________________________________________________________________ block_1_pad (ZeroPadding2D) (None, 81, 81, 96) 0 block_1_expand_relu[0][0] __________________________________________________________________________________________________ block_1_depthwise (DepthwiseCon (None, 40, 40, 96) 864 block_1_pad[0][0] __________________________________________________________________________________________________ block_1_depthwise_BN (BatchNorm (None, 40, 40, 96) 384 block_1_depthwise[0][0] __________________________________________________________________________________________________ block_1_depthwise_relu (ReLU) (None, 40, 40, 96) 0 block_1_depthwise_BN[0][0] __________________________________________________________________________________________________ block_1_project (Conv2D) (None, 40, 40, 24) 2304 block_1_depthwise_relu[0][0] __________________________________________________________________________________________________ block_1_project_BN (BatchNormal (None, 40, 40, 24) 96 block_1_project[0][0] __________________________________________________________________________________________________ block_2_expand (Conv2D) (None, 40, 40, 144) 3456 block_1_project_BN[0][0] __________________________________________________________________________________________________ block_2_expand_BN (BatchNormali (None, 40, 40, 144) 576 block_2_expand[0][0] __________________________________________________________________________________________________ block_2_expand_relu (ReLU) (None, 40, 40, 144) 0 block_2_expand_BN[0][0] __________________________________________________________________________________________________ block_2_depthwise (DepthwiseCon (None, 40, 40, 144) 1296 block_2_expand_relu[0][0] __________________________________________________________________________________________________ block_2_depthwise_BN (BatchNorm (None, 40, 40, 144) 576 block_2_depthwise[0][0] __________________________________________________________________________________________________ block_2_depthwise_relu (ReLU) (None, 40, 40, 144) 0 block_2_depthwise_BN[0][0] __________________________________________________________________________________________________ block_2_project (Conv2D) (None, 40, 40, 24) 3456 block_2_depthwise_relu[0][0] __________________________________________________________________________________________________ block_2_project_BN (BatchNormal (None, 40, 40, 24) 96 block_2_project[0][0] __________________________________________________________________________________________________ block_2_add (Add) (None, 40, 40, 24) 0 block_1_project_BN[0][0] block_2_project_BN[0][0] __________________________________________________________________________________________________ block_3_expand (Conv2D) (None, 40, 40, 144) 3456 block_2_add[0][0] __________________________________________________________________________________________________ block_3_expand_BN (BatchNormali (None, 40, 40, 144) 576 block_3_expand[0][0] __________________________________________________________________________________________________ block_3_expand_relu (ReLU) (None, 40, 40, 144) 0 block_3_expand_BN[0][0] __________________________________________________________________________________________________ block_3_pad (ZeroPadding2D) (None, 41, 41, 144) 0 block_3_expand_relu[0][0] __________________________________________________________________________________________________ block_3_depthwise (DepthwiseCon (None, 20, 20, 144) 1296 block_3_pad[0][0] __________________________________________________________________________________________________ block_3_depthwise_BN (BatchNorm (None, 20, 20, 144) 576 block_3_depthwise[0][0] __________________________________________________________________________________________________ block_3_depthwise_relu (ReLU) (None, 20, 20, 144) 0 block_3_depthwise_BN[0][0] __________________________________________________________________________________________________ block_3_project (Conv2D) (None, 20, 20, 32) 4608 block_3_depthwise_relu[0][0] __________________________________________________________________________________________________ block_3_project_BN (BatchNormal (None, 20, 20, 32) 128 block_3_project[0][0] __________________________________________________________________________________________________ block_4_expand (Conv2D) (None, 20, 20, 192) 6144 block_3_project_BN[0][0] __________________________________________________________________________________________________ block_4_expand_BN (BatchNormali (None, 20, 20, 192) 768 block_4_expand[0][0] __________________________________________________________________________________________________ block_4_expand_relu (ReLU) (None, 20, 20, 192) 0 block_4_expand_BN[0][0] __________________________________________________________________________________________________ block_4_depthwise (DepthwiseCon (None, 20, 20, 192) 1728 block_4_expand_relu[0][0] __________________________________________________________________________________________________ block_4_depthwise_BN (BatchNorm (None, 20, 20, 192) 768 block_4_depthwise[0][0] __________________________________________________________________________________________________ block_4_depthwise_relu (ReLU) (None, 20, 20, 192) 0 block_4_depthwise_BN[0][0] __________________________________________________________________________________________________ block_4_project (Conv2D) (None, 20, 20, 32) 6144 block_4_depthwise_relu[0][0] __________________________________________________________________________________________________ block_4_project_BN (BatchNormal (None, 20, 20, 32) 128 block_4_project[0][0] __________________________________________________________________________________________________ block_4_add (Add) (None, 20, 20, 32) 0 block_3_project_BN[0][0] block_4_project_BN[0][0] __________________________________________________________________________________________________ block_5_expand (Conv2D) (None, 20, 20, 192) 6144 block_4_add[0][0] __________________________________________________________________________________________________ block_5_expand_BN (BatchNormali (None, 20, 20, 192) 768 block_5_expand[0][0] __________________________________________________________________________________________________ block_5_expand_relu (ReLU) (None, 20, 20, 192) 0 block_5_expand_BN[0][0] __________________________________________________________________________________________________ block_5_depthwise (DepthwiseCon (None, 20, 20, 192) 1728 block_5_expand_relu[0][0] __________________________________________________________________________________________________ block_5_depthwise_BN (BatchNorm (None, 20, 20, 192) 768 block_5_depthwise[0][0] __________________________________________________________________________________________________ block_5_depthwise_relu (ReLU) (None, 20, 20, 192) 0 block_5_depthwise_BN[0][0] __________________________________________________________________________________________________ block_5_project (Conv2D) (None, 20, 20, 32) 6144 block_5_depthwise_relu[0][0] __________________________________________________________________________________________________ block_5_project_BN (BatchNormal (None, 20, 20, 32) 128 block_5_project[0][0] __________________________________________________________________________________________________ block_5_add (Add) (None, 20, 20, 32) 0 block_4_add[0][0] block_5_project_BN[0][0] __________________________________________________________________________________________________ block_6_expand (Conv2D) (None, 20, 20, 192) 6144 block_5_add[0][0] __________________________________________________________________________________________________ block_6_expand_BN (BatchNormali (None, 20, 20, 192) 768 block_6_expand[0][0] __________________________________________________________________________________________________ block_6_expand_relu (ReLU) (None, 20, 20, 192) 0 block_6_expand_BN[0][0] __________________________________________________________________________________________________ block_6_pad (ZeroPadding2D) (None, 21, 21, 192) 0 block_6_expand_relu[0][0] __________________________________________________________________________________________________ block_6_depthwise (DepthwiseCon (None, 10, 10, 192) 1728 block_6_pad[0][0] __________________________________________________________________________________________________ block_6_depthwise_BN (BatchNorm (None, 10, 10, 192) 768 block_6_depthwise[0][0] __________________________________________________________________________________________________ block_6_depthwise_relu (ReLU) (None, 10, 10, 192) 0 block_6_depthwise_BN[0][0] __________________________________________________________________________________________________ block_6_project (Conv2D) (None, 10, 10, 64) 12288 block_6_depthwise_relu[0][0] __________________________________________________________________________________________________ block_6_project_BN (BatchNormal (None, 10, 10, 64) 256 block_6_project[0][0] __________________________________________________________________________________________________ block_7_expand (Conv2D) (None, 10, 10, 384) 24576 block_6_project_BN[0][0] __________________________________________________________________________________________________ block_7_expand_BN (BatchNormali (None, 10, 10, 384) 1536 block_7_expand[0][0] __________________________________________________________________________________________________ block_7_expand_relu (ReLU) (None, 10, 10, 384) 0 block_7_expand_BN[0][0] __________________________________________________________________________________________________ block_7_depthwise (DepthwiseCon (None, 10, 10, 384) 3456 block_7_expand_relu[0][0] __________________________________________________________________________________________________ block_7_depthwise_BN (BatchNorm (None, 10, 10, 384) 1536 block_7_depthwise[0][0] __________________________________________________________________________________________________ block_7_depthwise_relu (ReLU) (None, 10, 10, 384) 0 block_7_depthwise_BN[0][0] __________________________________________________________________________________________________ block_7_project (Conv2D) (None, 10, 10, 64) 24576 block_7_depthwise_relu[0][0] __________________________________________________________________________________________________ block_7_project_BN (BatchNormal (None, 10, 10, 64) 256 block_7_project[0][0] __________________________________________________________________________________________________ block_7_add (Add) (None, 10, 10, 64) 0 block_6_project_BN[0][0] block_7_project_BN[0][0] __________________________________________________________________________________________________ block_8_expand (Conv2D) (None, 10, 10, 384) 24576 block_7_add[0][0] __________________________________________________________________________________________________ block_8_expand_BN (BatchNormali (None, 10, 10, 384) 1536 block_8_expand[0][0] __________________________________________________________________________________________________ block_8_expand_relu (ReLU) (None, 10, 10, 384) 0 block_8_expand_BN[0][0] __________________________________________________________________________________________________ block_8_depthwise (DepthwiseCon (None, 10, 10, 384) 3456 block_8_expand_relu[0][0] __________________________________________________________________________________________________ block_8_depthwise_BN (BatchNorm (None, 10, 10, 384) 1536 block_8_depthwise[0][0] __________________________________________________________________________________________________ block_8_depthwise_relu (ReLU) (None, 10, 10, 384) 0 block_8_depthwise_BN[0][0] __________________________________________________________________________________________________ block_8_project (Conv2D) (None, 10, 10, 64) 24576 block_8_depthwise_relu[0][0] __________________________________________________________________________________________________ block_8_project_BN (BatchNormal (None, 10, 10, 64) 256 block_8_project[0][0] __________________________________________________________________________________________________ block_8_add (Add) (None, 10, 10, 64) 0 block_7_add[0][0] block_8_project_BN[0][0] __________________________________________________________________________________________________ block_9_expand (Conv2D) (None, 10, 10, 384) 24576 block_8_add[0][0] __________________________________________________________________________________________________ block_9_expand_BN (BatchNormali (None, 10, 10, 384) 1536 block_9_expand[0][0] __________________________________________________________________________________________________ block_9_expand_relu (ReLU) (None, 10, 10, 384) 0 block_9_expand_BN[0][0] __________________________________________________________________________________________________ block_9_depthwise (DepthwiseCon (None, 10, 10, 384) 3456 block_9_expand_relu[0][0] __________________________________________________________________________________________________ block_9_depthwise_BN (BatchNorm (None, 10, 10, 384) 1536 block_9_depthwise[0][0] __________________________________________________________________________________________________ block_9_depthwise_relu (ReLU) (None, 10, 10, 384) 0 block_9_depthwise_BN[0][0] __________________________________________________________________________________________________ block_9_project (Conv2D) (None, 10, 10, 64) 24576 block_9_depthwise_relu[0][0] __________________________________________________________________________________________________ block_9_project_BN (BatchNormal (None, 10, 10, 64) 256 block_9_project[0][0] __________________________________________________________________________________________________ block_9_add (Add) (None, 10, 10, 64) 0 block_8_add[0][0] block_9_project_BN[0][0] __________________________________________________________________________________________________ block_10_expand (Conv2D) (None, 10, 10, 384) 24576 block_9_add[0][0] __________________________________________________________________________________________________ block_10_expand_BN (BatchNormal (None, 10, 10, 384) 1536 block_10_expand[0][0] __________________________________________________________________________________________________ block_10_expand_relu (ReLU) (None, 10, 10, 384) 0 block_10_expand_BN[0][0] __________________________________________________________________________________________________ block_10_depthwise (DepthwiseCo (None, 10, 10, 384) 3456 block_10_expand_relu[0][0] __________________________________________________________________________________________________ block_10_depthwise_BN (BatchNor (None, 10, 10, 384) 1536 block_10_depthwise[0][0] __________________________________________________________________________________________________ block_10_depthwise_relu (ReLU) (None, 10, 10, 384) 0 block_10_depthwise_BN[0][0] __________________________________________________________________________________________________ block_10_project (Conv2D) (None, 10, 10, 96) 36864 block_10_depthwise_relu[0][0] __________________________________________________________________________________________________ block_10_project_BN (BatchNorma (None, 10, 10, 96) 384 block_10_project[0][0] __________________________________________________________________________________________________ block_11_expand (Conv2D) (None, 10, 10, 576) 55296 block_10_project_BN[0][0] __________________________________________________________________________________________________ block_11_expand_BN (BatchNormal (None, 10, 10, 576) 2304 block_11_expand[0][0] __________________________________________________________________________________________________ block_11_expand_relu (ReLU) (None, 10, 10, 576) 0 block_11_expand_BN[0][0] __________________________________________________________________________________________________ block_11_depthwise (DepthwiseCo (None, 10, 10, 576) 5184 block_11_expand_relu[0][0] __________________________________________________________________________________________________ block_11_depthwise_BN (BatchNor (None, 10, 10, 576) 2304 block_11_depthwise[0][0] __________________________________________________________________________________________________ block_11_depthwise_relu (ReLU) (None, 10, 10, 576) 0 block_11_depthwise_BN[0][0] __________________________________________________________________________________________________ block_11_project (Conv2D) (None, 10, 10, 96) 55296 block_11_depthwise_relu[0][0] __________________________________________________________________________________________________ block_11_project_BN (BatchNorma (None, 10, 10, 96) 384 block_11_project[0][0] __________________________________________________________________________________________________ block_11_add (Add) (None, 10, 10, 96) 0 block_10_project_BN[0][0] block_11_project_BN[0][0] __________________________________________________________________________________________________ block_12_expand (Conv2D) (None, 10, 10, 576) 55296 block_11_add[0][0] __________________________________________________________________________________________________ block_12_expand_BN (BatchNormal (None, 10, 10, 576) 2304 block_12_expand[0][0] __________________________________________________________________________________________________ block_12_expand_relu (ReLU) (None, 10, 10, 576) 0 block_12_expand_BN[0][0] __________________________________________________________________________________________________ block_12_depthwise (DepthwiseCo (None, 10, 10, 576) 5184 block_12_expand_relu[0][0] __________________________________________________________________________________________________ block_12_depthwise_BN (BatchNor (None, 10, 10, 576) 2304 block_12_depthwise[0][0] __________________________________________________________________________________________________ block_12_depthwise_relu (ReLU) (None, 10, 10, 576) 0 block_12_depthwise_BN[0][0] __________________________________________________________________________________________________ block_12_project (Conv2D) (None, 10, 10, 96) 55296 block_12_depthwise_relu[0][0] __________________________________________________________________________________________________ block_12_project_BN (BatchNorma (None, 10, 10, 96) 384 block_12_project[0][0] __________________________________________________________________________________________________ block_12_add (Add) (None, 10, 10, 96) 0 block_11_add[0][0] block_12_project_BN[0][0] __________________________________________________________________________________________________ block_13_expand (Conv2D) (None, 10, 10, 576) 55296 block_12_add[0][0] __________________________________________________________________________________________________ block_13_expand_BN (BatchNormal (None, 10, 10, 576) 2304 block_13_expand[0][0] __________________________________________________________________________________________________ block_13_expand_relu (ReLU) (None, 10, 10, 576) 0 block_13_expand_BN[0][0] __________________________________________________________________________________________________ block_13_pad (ZeroPadding2D) (None, 11, 11, 576) 0 block_13_expand_relu[0][0] __________________________________________________________________________________________________ block_13_depthwise (DepthwiseCo (None, 5, 5, 576) 5184 block_13_pad[0][0] __________________________________________________________________________________________________ block_13_depthwise_BN (BatchNor (None, 5, 5, 576) 2304 block_13_depthwise[0][0] __________________________________________________________________________________________________ block_13_depthwise_relu (ReLU) (None, 5, 5, 576) 0 block_13_depthwise_BN[0][0] __________________________________________________________________________________________________ block_13_project (Conv2D) (None, 5, 5, 160) 92160 block_13_depthwise_relu[0][0] __________________________________________________________________________________________________ block_13_project_BN (BatchNorma (None, 5, 5, 160) 640 block_13_project[0][0] __________________________________________________________________________________________________ block_14_expand (Conv2D) (None, 5, 5, 960) 153600 block_13_project_BN[0][0] __________________________________________________________________________________________________ block_14_expand_BN (BatchNormal (None, 5, 5, 960) 3840 block_14_expand[0][0] __________________________________________________________________________________________________ block_14_expand_relu (ReLU) (None, 5, 5, 960) 0 block_14_expand_BN[0][0] __________________________________________________________________________________________________ block_14_depthwise (DepthwiseCo (None, 5, 5, 960) 8640 block_14_expand_relu[0][0] __________________________________________________________________________________________________ block_14_depthwise_BN (BatchNor (None, 5, 5, 960) 3840 block_14_depthwise[0][0] __________________________________________________________________________________________________ block_14_depthwise_relu (ReLU) (None, 5, 5, 960) 0 block_14_depthwise_BN[0][0] __________________________________________________________________________________________________ block_14_project (Conv2D) (None, 5, 5, 160) 153600 block_14_depthwise_relu[0][0] __________________________________________________________________________________________________ block_14_project_BN (BatchNorma (None, 5, 5, 160) 640 block_14_project[0][0] __________________________________________________________________________________________________ block_14_add (Add) (None, 5, 5, 160) 0 block_13_project_BN[0][0] block_14_project_BN[0][0] __________________________________________________________________________________________________ block_15_expand (Conv2D) (None, 5, 5, 960) 153600 block_14_add[0][0] __________________________________________________________________________________________________ block_15_expand_BN (BatchNormal (None, 5, 5, 960) 3840 block_15_expand[0][0] __________________________________________________________________________________________________ block_15_expand_relu (ReLU) (None, 5, 5, 960) 0 block_15_expand_BN[0][0] __________________________________________________________________________________________________ block_15_depthwise (DepthwiseCo (None, 5, 5, 960) 8640 block_15_expand_relu[0][0] __________________________________________________________________________________________________ block_15_depthwise_BN (BatchNor (None, 5, 5, 960) 3840 block_15_depthwise[0][0] __________________________________________________________________________________________________ block_15_depthwise_relu (ReLU) (None, 5, 5, 960) 0 block_15_depthwise_BN[0][0] __________________________________________________________________________________________________ block_15_project (Conv2D) (None, 5, 5, 160) 153600 block_15_depthwise_relu[0][0] __________________________________________________________________________________________________ block_15_project_BN (BatchNorma (None, 5, 5, 160) 640 block_15_project[0][0] __________________________________________________________________________________________________ block_15_add (Add) (None, 5, 5, 160) 0 block_14_add[0][0] block_15_project_BN[0][0] __________________________________________________________________________________________________ block_16_expand (Conv2D) (None, 5, 5, 960) 153600 block_15_add[0][0] __________________________________________________________________________________________________ block_16_expand_BN (BatchNormal (None, 5, 5, 960) 3840 block_16_expand[0][0] __________________________________________________________________________________________________ block_16_expand_relu (ReLU) (None, 5, 5, 960) 0 block_16_expand_BN[0][0] __________________________________________________________________________________________________ block_16_depthwise (DepthwiseCo (None, 5, 5, 960) 8640 block_16_expand_relu[0][0] __________________________________________________________________________________________________ block_16_depthwise_BN (BatchNor (None, 5, 5, 960) 3840 block_16_depthwise[0][0] __________________________________________________________________________________________________ block_16_depthwise_relu (ReLU) (None, 5, 5, 960) 0 block_16_depthwise_BN[0][0] __________________________________________________________________________________________________ block_16_project (Conv2D) (None, 5, 5, 320) 307200 block_16_depthwise_relu[0][0] __________________________________________________________________________________________________ block_16_project_BN (BatchNorma (None, 5, 5, 320) 1280 block_16_project[0][0] __________________________________________________________________________________________________ Conv_1 (Conv2D) (None, 5, 5, 1280) 409600 block_16_project_BN[0][0] __________________________________________________________________________________________________ Conv_1_bn (BatchNormalization) (None, 5, 5, 1280) 5120 Conv_1[0][0] __________________________________________________________________________________________________ out_relu (ReLU) (None, 5, 5, 1280) 0 Conv_1_bn[0][0] ================================================================================================== Total params: 2,257,984 Trainable params: 0 Non-trainable params: 2,257,984 __________________________________________________________________________________________________
分類ヘッドを追加する
特徴ブロックから予測を生成するために、特徴を画像毎に単一 1280-要素ベクトルに変換するために tf.keras.layers.GlobalAveragePlloing2d 層を使用して 5×5 空間的位置に渡り平均します。
global_average_layer = tf.keras.layers.GlobalAveragePooling2D() feature_batch_average = global_average_layer(feature_batch) print(feature_batch_average.shape)
(32, 1280)
これらの特徴を画像毎に単一の予測に変換するために tf.keras.layers.Dense 層を適用します。ここでは活性化関数は必要ありません、何故ならばこの予測はロジット、あるいは生の予測値として扱われるからです。正数はクラス 1 を予測し、負数はクラス 0 を予測します。
prediction_layer = keras.layers.Dense(1) prediction_batch = prediction_layer(feature_batch_average) print(prediction_batch.shape)
(32, 1)
さて特徴抽出器、そしてこれらの 2 つの層を tf.keras.Sequential モデルを使用してスタックします :
model = tf.keras.Sequential([ base_model, global_average_layer, prediction_layer ])
モデルをコンパイルする
モデルを訓練する前にそれをコンパイルしなければなりません。2 クラスありますので、二値交差エントロピー損失を使用します。
base_learning_rate = 0.0001 model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate), loss='binary_crossentropy', metrics=['accuracy'])
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= mobilenetv2_1.00_160 (Model) (None, 5, 5, 1280) 2257984 _________________________________________________________________ global_average_pooling2d (Gl (None, 1280) 0 _________________________________________________________________ dense (Dense) (None, 1) 1281 ================================================================= Total params: 2,259,265 Trainable params: 1,281 Non-trainable params: 2,257,984 _________________________________________________________________
MobileNet の 2.5M パラメータは凍結されますが、Dense 層に 1.2K の訓練可能なパラメータがあります。これらは 2 つの tf.Variable オブジェクト、重みとバイアスに分けられます。
len(model.trainable_variables)
2
モデルを訓練する
10 エポックの間の訓練後、 ~96% 精度を見るはずです。
num_train, num_val, num_test = ( metadata.splits['train'].num_examples*weight/10 for weight in SPLIT_WEIGHTS )
initial_epochs = 10 steps_per_epoch = round(num_train)//BATCH_SIZE validation_steps = 20 loss0,accuracy0 = model.evaluate(validation_batches, steps = validation_steps)
20/20 [==============================] - 2s 92ms/step - loss: 5.1612 - accuracy: 0.5141
print("initial loss: {:.2f}".format(loss0)) print("initial accuracy: {:.2f}".format(accuracy0))
initial loss: 5.16 initial accuracy: 0.51
history = model.fit(train_batches, epochs=initial_epochs, validation_data=validation_batches)
Epoch 1/10 582/582 [==============================] - 36s 61ms/step - loss: 2.3922 - accuracy: 0.7156 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00 Epoch 2/10 582/582 [==============================] - 31s 53ms/step - loss: 1.0886 - accuracy: 0.8540 - val_loss: 0.5596 - val_accuracy: 0.9181 Epoch 3/10 582/582 [==============================] - 31s 53ms/step - loss: 0.7876 - accuracy: 0.8961 - val_loss: 0.4342 - val_accuracy: 0.9453 Epoch 4/10 582/582 [==============================] - 31s 53ms/step - loss: 0.6501 - accuracy: 0.9128 - val_loss: 0.4245 - val_accuracy: 0.9474 Epoch 5/10 582/582 [==============================] - 31s 53ms/step - loss: 0.6074 - accuracy: 0.9217 - val_loss: 0.3910 - val_accuracy: 0.9556 Epoch 6/10 582/582 [==============================] - 31s 53ms/step - loss: 0.5526 - accuracy: 0.9255 - val_loss: 0.5181 - val_accuracy: 0.9427 Epoch 7/10 582/582 [==============================] - 31s 53ms/step - loss: 0.4930 - accuracy: 0.9341 - val_loss: 0.4041 - val_accuracy: 0.9513 Epoch 8/10 582/582 [==============================] - 31s 53ms/step - loss: 0.5091 - accuracy: 0.9348 - val_loss: 0.4530 - val_accuracy: 0.9526 Epoch 9/10 582/582 [==============================] - 31s 54ms/step - loss: 0.4681 - accuracy: 0.9366 - val_loss: 0.4172 - val_accuracy: 0.9539 Epoch 10/10 582/582 [==============================] - 31s 54ms/step - loss: 0.4634 - accuracy: 0.9386 - val_loss: 0.4221 - val_accuracy: 0.9530
学習カーブ
MobileNet V2 ベースモデルを固定された特徴抽出器として使用するときの訓練と検証の精度 / 損失の学習カーブを見てみましょう。
acc = history.history['accuracy'] val_acc = history.history['val_accuracy'] loss = history.history['loss'] val_loss = history.history['val_loss'] plt.figure(figsize=(8, 8)) plt.subplot(2, 1, 1) plt.plot(acc, label='Training Accuracy') plt.plot(val_acc, label='Validation Accuracy') plt.legend(loc='lower right') plt.ylabel('Accuracy') plt.ylim([min(plt.ylim()),1]) plt.title('Training and Validation Accuracy') plt.subplot(2, 1, 2) plt.plot(loss, label='Training Loss') plt.plot(val_loss, label='Validation Loss') plt.legend(loc='upper right') plt.ylabel('Cross Entropy') plt.ylim([0,1.0]) plt.title('Training and Validation Loss') plt.xlabel('epoch') plt.show()
Note: もし貴方が何故検証メトリクスが訓練メトリクスよりも明らかに良いのか疑問に思うのであれば、主要因は tf.keras.layers.BatchNormalization と tf.keras.layers.Dropout のような層が訓練の間の精度に影響するからです。それらは検証損失を計算するときには無効にされます。
程度は少ないですが、それはまた検証メトリクスがエポック後に評価される一方で訓練メトリクスはエポックのための平均を報告するからでもあります、そのため検証メトリクスは僅かばかり長く訓練されたモデルを見ることになります。
再調整
私達の特徴抽出実験では、MobileNet V2 ベースモデルの上の 2, 3 層だけを訓練していました。事前訓練されたネットワークの重みは訓練の間に更新されませんでした。
より以上にパフォーマンスを増す一つの方法は 、貴方が追加した分類器の訓練と一緒に事前訓練されたモデルの上部の層の重みを訓練 (or「再調整」) することです。訓練過程は一般的な特徴マップから私達のデータセットに特に関連する特徴へと重みが調整されることを強制します。
Note: これは事前訓練されたモデルを非訓練可能に設定しながら top-level 分類器を訓練した後でのみ試されるべきです。もし貴方が事前訓練されたモデルの上にランダムに初期化された分類器を追加して総ての層を一緒に訓練することを試みる場合、勾配更新の大きさが (分類器からのランダム重みのために) 大き過ぎて貴方の事前訓練されたモデルはそれが学習したことを忘れるでしょう。
また、MobileNet モデル全体よりも小さい数の top 層を再調整することを試みるべきです。殆どの畳み込みネットワークでは、層が高位になればなるほど、それはより特化されます。最初の 2, 3 の層は非常に単純で一般的な特徴を学習して、それは殆ど総てのタイプの画像に一般化されます。より高く行くほどに、特徴は段々と (モデルがその上で訓練された) データセット特有になります。再調整のゴールは一般的な学習を上書きすることではなく、これらの固有の特徴を新しいデータで動作するように適応させることです。
モデルのトップ層を解凍する
貴方が行なう必要がある総てのことは base_model を解凍してボトム層を非訓練可能に (i.e. 訓練できないように) 設定することです。それから、モデルを再コンパイルするべきです (これらの変更が有効になるために必要です)、そして訓練を再開します。
base_model.trainable = True
# Let's take a look to see how many layers are in the base model print("Number of layers in the base model: ", len(base_model.layers)) # Fine tune from this layer onwards fine_tune_at = 100 # Freeze all the layers before the `fine_tune_at` layer for layer in base_model.layers[:fine_tune_at]: layer.trainable = False
Number of layers in the base model: 155
モデルをコンパイルする
遥かに低い訓練率 (= training rate) を使用してモデルをコンパイルします。
model.compile(loss='binary_crossentropy', optimizer = tf.keras.optimizers.RMSprop(lr=base_learning_rate/10), metrics=['accuracy'])
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= mobilenetv2_1.00_160 (Model) (None, 5, 5, 1280) 2257984 _________________________________________________________________ global_average_pooling2d (Gl (None, 1280) 0 _________________________________________________________________ dense (Dense) (None, 1) 1281 ================================================================= Total params: 2,259,265 Trainable params: 1,863,873 Non-trainable params: 395,392 _________________________________________________________________
len(model.trainable_variables)
58
モデルの訓練を継続する
先に収束するために訓練したのであれば、これは数パーセントの更なる精度を貴方に得させるでしょう。
fine_tune_epochs = 10 total_epochs = initial_epochs + fine_tune_epochs history_fine = model.fit(train_batches, epochs=total_epochs, initial_epoch = initial_epochs, validation_data=validation_batches)
Epoch 11/20 582/582 [==============================] - 42s 72ms/step - loss: 0.4341 - accuracy: 0.9476 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00 Epoch 12/20 582/582 [==============================] - 35s 60ms/step - loss: 0.3408 - accuracy: 0.9592 - val_loss: 0.2207 - val_accuracy: 0.9703 Epoch 13/20 582/582 [==============================] - 35s 60ms/step - loss: 0.2940 - accuracy: 0.9645 - val_loss: 0.2162 - val_accuracy: 0.9737 Epoch 14/20 582/582 [==============================] - 35s 60ms/step - loss: 0.2327 - accuracy: 0.9712 - val_loss: 0.2374 - val_accuracy: 0.9728 Epoch 15/20 582/582 [==============================] - 35s 60ms/step - loss: 0.2110 - accuracy: 0.9756 - val_loss: 0.2342 - val_accuracy: 0.9733 Epoch 16/20 582/582 [==============================] - 35s 60ms/step - loss: 0.1862 - accuracy: 0.9783 - val_loss: 0.2156 - val_accuracy: 0.9741 Epoch 17/20 582/582 [==============================] - 35s 60ms/step - loss: 0.1655 - accuracy: 0.9794 - val_loss: 0.2218 - val_accuracy: 0.9746 Epoch 18/20 582/582 [==============================] - 35s 61ms/step - loss: 0.1372 - accuracy: 0.9823 - val_loss: 0.2023 - val_accuracy: 0.9750 Epoch 19/20 582/582 [==============================] - 35s 60ms/step - loss: 0.1186 - accuracy: 0.9846 - val_loss: 0.2137 - val_accuracy: 0.9750 Epoch 20/20 582/582 [==============================] - 35s 60ms/step - loss: 0.1135 - accuracy: 0.9847 - val_loss: 0.2207 - val_accuracy: 0.9759
MobileNet V2 ベースモデルの最後の 2, 3 層を再調整してその上の分類器を訓練するとき、訓練と検証の精度 / 損失の学習カーブを見てみましょう。検証損失は訓練損失よりも遥かに高いので、何某かの overfitting を得るかもしれません。
新しい訓練セットは比較的小さくて元の MobileNet V2 データセットに類似しているので、某かの overfitting もまた得ているかもしれません。
再調整後にモデルは 98% 精度近くに到達します。
acc += history_fine.history['accuracy'] val_acc += history_fine.history['val_accuracy'] loss += history_fine.history['loss'] val_loss += history_fine.history['val_loss']
plt.figure(figsize=(8, 8)) plt.subplot(2, 1, 1) plt.plot(acc, label='Training Accuracy') plt.plot(val_acc, label='Validation Accuracy') plt.ylim([0.8, 1]) plt.plot([initial_epochs-1,initial_epochs-1], plt.ylim(), label='Start Fine Tuning') plt.legend(loc='lower right') plt.title('Training and Validation Accuracy') plt.subplot(2, 1, 2) plt.plot(loss, label='Training Loss') plt.plot(val_loss, label='Validation Loss') plt.ylim([0, 1.0]) plt.plot([initial_epochs-1,initial_epochs-1], plt.ylim(), label='Start Fine Tuning') plt.legend(loc='upper right') plt.title('Training and Validation Loss') plt.xlabel('epoch') plt.show()
要約 :
- 特徴抽出のために事前訓練されたモデルを使用する: 小さいデータセットで作業するとき、同じドメインのより巨大なデータセット上で訓練されたモデルにより学習された特徴を活用することは一般的です。これは事前訓練されたモデルをインスタンス化して上に完全結合分類器を追加することにより成されます。事前訓練されたモデルは「凍結」されて訓練の間分類器の重みだけが更新されます。この場合、畳み込みベースが各画像に関連する総ての特徴を抽出してそして抽出された特徴のこれらのセットが与えられたとき画像クラスを決定する分類器を単に訓練します。
- 事前訓練されたモデルを再調整する: パフォーマンスを更に改良するために、事前訓練されたモデルの top-level 層を再調整を通して新しいデータセットに再目的化することを望むかもしれません。この場合、貴方のモデルがデータセットに固有の高位な特徴を学習するように重みを調整します。このテクニックは通常は、訓練データセットが巨大で (事前訓練されたモデルがその上で訓練された) 元のデータセットに非常に類似しているときに限り推奨されます。
以上