TensorFlow 2.0 Alpha : Beginner Tutorials : 画像 :- 事前訓練された ConvNet を使用する転移学習 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 04/01/2019
* 本ページは、TensorFlow の本家サイトの TF 2.0 Alpha – Beginner Tutorials – Images の以下のページを翻訳した上で
適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
画像 :- 事前訓練された ConvNet を使用する転移学習
このチュートリアルでは事前訓練されたネットワークからの転移学習を使用して猫 vs 犬画像をどのように分類するかを議論します。これはスクラッチからネットワークを訓練することによって見たものより高い精度を得ることを可能にします。
**事前訓練されたモデル** は、典型的には巨大スケール画像分類タスク上の、巨大なデータセット上で以前に訓練された (セーブされた) ネットワークです。事前訓練されたモデルをそのまま使用するか事前訓練された convnet を使用した学習を転移させることができます。**転移学習** の背後にある直感は、このモデルが十分に巨大で一般的なデータセット上で訓練された場合、このモデルは視覚世界の一般的なモデルとして効果的に役立つであろうということです。これらのモデルを私達のタスクに固有の私達自身のモデルの基盤として使用することで巨大なデータセット上で巨大なモデルを訓練しなくてもこれらの学習された特徴マップを活用できます。事前訓練されたモデルを使用する転移学習の 2 つのシナリオがあります :
- 特徴抽出 (= Feature Extraction) – 新しいサンプルから意味がある特徴を抽出するために以前のネットワークにより学習された表現を使用します。事前訓練されたモデルの上に (スクラッチから訓練される) 新しい分類器を単に追加します、その結果前に学習された特徴マップを私達のデータセットのために最目的化できます。事前訓練されたモデル全体か単に畳み込みベースを使用しますか? – これらの事前訓練された convnet (畳み込みベース) の特徴抽出部分を使用します、何故ならばそれらは一般的な特徴で写真に渡り学習された概念でありがちだからです。けれども、事前訓練されたモデルの分類パートはしばしば元の分類タスクに特有で、結果的にモデルが訓練された上のクラスのセットに固有です。
- 再調整 (= Fine-Tuning) – 特徴抽出のために使用される凍結されたモデルベースの 2, 3 の top 層を解凍して、凍結モデルの最後の層群に加えて新たに追加された分類層の両者を連帯して訓練します。これは最後の分類器に加えて高次特徴表現を、関係する特定のタスクのためにより関連させるために「再調整」することを可能にします。
一般的な機械学習ワークフローに従います : 1. データを調べて理解する。2. 入力パイプラインを構築する – 画像分類チュートリアルで行なったように Keras ImageDataGenerator を使用する。3. モデルを構成する * 事前訓練されたモデル (と事前訓練された重み) をロードする * top に分類層をスタックする。4. モデルを訓練する。5. モデルを評価する。
事前訓練された convnet を特徴抽出として使用する例を見てからベースモデルの最後の 2, 3 層を訓練するために再調整します。
from __future__ import absolute_import, division, print_function import os import numpy as np import matplotlib.pyplot as plt
!pip install -q tensorflow-gpu==2.0.0-alpha0 import tensorflow as tf keras = tf.keras
Collecting tensorflow-gpu==2.0.0-alpha0 Successfully installed google-pasta-0.1.4 tb-nightly-1.14.0a20190303 tensorflow-estimator-2.0-preview-1.14.0.dev2019030300 tensorflow-gpu==2.0.0-alpha0-2.0.0.dev20190303
データ前処理
データ・ダウンロード
猫と犬のデータセットをロードするために TensorFlow Dataset を利用します。
この tfds パッケージは事前定義されたデータをロードするための最も容易な方法です。もし貴方自身のデータを持ち、インポートしてそれを TensorFlow で使用することに興味がれば loading image data を見てください。
import tensorflow_datasets as tfds
tfds.load メソッドはデータをダウンロードしてキャッシュし、tf.data.Dataset オブジェクトを返します。これらのオブジェクトはデータを操作してそれをモデルにパイプするためのパワフルで、効率的なメソッドを提供します。
“cats_vs_dog” は標準的な分割を定義していないので、それをデータの 80%, 10%, 10% で (train, validation, test) にそれぞれ分割するために subsplit 機能を使用します。
SPLIT_WEIGHTS = (8, 1, 1) splits = tfds.Split.TRAIN.subsplit(weighted=SPLIT_WEIGHTS) (raw_train, raw_validation, raw_test), metadata = tfds.load( 'cats_vs_dogs', split=list(splits), with_info=True, as_supervised=True)
Downloading / extracting dataset cats_vs_dogs (786.68 MiB) to /root/tensorflow_datasets/cats_vs_dogs/2.0.0...
結果としての tf.data.Dataset オブジェクトは (image, label) ペアを含みます。そこでは画像は可変な shape と 3 チャネルを持ち、ラベルはスカラーです。
print(raw_train) print(raw_train)
<_OptionsDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)> <_OptionsDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>
訓練セットから最初の 2 つの画像とラベルを表示します。
get_label_name = metadata.features['label'].int2str for image, label in raw_train.take(2): plt.figure() plt.imshow(image) plt.title(get_label_name(label))
データをフォーマットする
タスクのために画像をフォーマットするために tf.image モジュールを使用します。
画像を固定入力サイズにリサイズし、入力チャネルを [-1, 1] の範囲にリスケールします。
IMG_SIZE = 160 # All images will be resized to 160x160 def format_example(image, label): image = tf.cast(image, tf.float32) image = (image/127.5) - 1 image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE)) return image, label
map メソッドを使用してこの関数をデータセットの各アイテムに適用します :
train = raw_train.map(format_example) validation = raw_validation.map(format_example) test = raw_test.map(format_example)
データをシャッフルしてバッチ化します。
BATCH_SIZE = 32 SHUFFLE_BUFFER_SIZE = 1000
train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE) validation_batches = validation.batch(BATCH_SIZE) test_batches = test.batch(BATCH_SIZE)
データのバッチを調べます :
for image_batch, label_batch in train_batches.take(1): pass image_batch.shape
TensorShape([32, 160, 160, 3])
ベースモデルを作成する
Google で開発されて、ImageNet データセット (web 画像の 1.4 M 画像の巨大データセットと 1000 クラス) で事前訓練された MobileNet V2 からベースモデルを作成します。これはパワフルなモデルです。それが学習した特徴が「猫 vs.犬」問題に何ができるかを見ましょう。
最初に、特徴抽出のために使用する MobileNet V2 の中間層を選択する必要があります。一般的な実践では flatten 演算の前のまさに最後の層の出力を使用します、いわゆる「ボトルネック層」です。ここでの論拠は次の完全結合層はニューラルネットワークがその上で訓練されたタスクに特化され過ぎていることで、これらの層により学習された特徴は新しいタスクのためにあまり有用ではありません。けれども、ボトルネック特徴は遥かに一般性を保持しています。
ImageNet 上で訓練された重みとともに事前ロードされた MobileNet V2 モデルをインスタンス化しましょう。include_top=False 引数を指定することにより、トップに分類層を含まないネットワークをロードします、これは特徴抽出に理想的です。
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3) # Create the base model from the pre-trained model MobileNet V2 base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
Downloading data from https://github.com/JonathanCMitchell/mobilenet_v2_keras/releases/download/v1.1/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5 9412608/9406464 [==============================] - 1s 0us/step
この特徴抽出器は各 160x160x3 画像を 5x5x1280 特徴ブロックに変換します。画像のサンプルバッチにそれが何をするかを見ます :
feature_batch = base_model(image_batch) print(feature_batch.shape)
(32, 5, 5, 1280)
特徴抽出
前のステップで作成された畳み込みベースを凍結してそれを特徴抽出器として使用し、その上に分類器を追加して top-level 分類器を訓練します。
畳み込みベースを凍結する
compile してモデルを訓練する前に畳み込みベースを凍結することは重要です。凍結する (あるいは layer.trainable = False を設定する) ことにより、この層の重みが訓練の間に更新されることを回避します。
base_model.trainable = False
# Let's take a look at the base model architecture base_model.summary()
Model: "mobilenetv2_1.00_160" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) [(None, 160, 160, 3) 0 __________________________________________________________________________________________________ Conv1_pad (ZeroPadding2D) (None, 161, 161, 3) 0 input_1[0][0] __________________________________________________________________________________________________ Conv1 (Conv2D) (None, 80, 80, 32) 864 Conv1_pad[0][0] __________________________________________________________________________________________________ bn_Conv1 (BatchNormalizationV1) (None, 80, 80, 32) 128 Conv1[0][0] __________________________________________________________________________________________________ Conv1_relu (ReLU) (None, 80, 80, 32) 0 bn_Conv1[0][0] __________________________________________________________________________________________________ expanded_conv_depthwise (Depthw (None, 80, 80, 32) 288 Conv1_relu[0][0] __________________________________________________________________________________________________ expanded_conv_depthwise_BN (Bat (None, 80, 80, 32) 128 expanded_conv_depthwise[0][0] __________________________________________________________________________________________________ expanded_conv_depthwise_relu (R (None, 80, 80, 32) 0 expanded_conv_depthwise_BN[0][0] __________________________________________________________________________________________________ expanded_conv_project (Conv2D) (None, 80, 80, 16) 512 expanded_conv_depthwise_relu[0][0 __________________________________________________________________________________________________ expanded_conv_project_BN (Batch (None, 80, 80, 16) 64 expanded_conv_project[0][0] __________________________________________________________________________________________________ block_1_expand (Conv2D) (None, 80, 80, 96) 1536 expanded_conv_project_BN[0][0] __________________________________________________________________________________________________ block_1_expand_BN (BatchNormali (None, 80, 80, 96) 384 block_1_expand[0][0] __________________________________________________________________________________________________ block_1_expand_relu (ReLU) (None, 80, 80, 96) 0 block_1_expand_BN[0][0] __________________________________________________________________________________________________ block_1_pad (ZeroPadding2D) (None, 81, 81, 96) 0 block_1_expand_relu[0][0] __________________________________________________________________________________________________ block_1_depthwise (DepthwiseCon (None, 40, 40, 96) 864 block_1_pad[0][0] __________________________________________________________________________________________________ block_1_depthwise_BN (BatchNorm (None, 40, 40, 96) 384 block_1_depthwise[0][0] __________________________________________________________________________________________________ block_1_depthwise_relu (ReLU) (None, 40, 40, 96) 0 block_1_depthwise_BN[0][0] __________________________________________________________________________________________________ block_1_project (Conv2D) (None, 40, 40, 24) 2304 block_1_depthwise_relu[0][0] __________________________________________________________________________________________________ block_1_project_BN (BatchNormal (None, 40, 40, 24) 96 block_1_project[0][0] __________________________________________________________________________________________________ block_2_expand (Conv2D) (None, 40, 40, 144) 3456 block_1_project_BN[0][0] __________________________________________________________________________________________________ block_2_expand_BN (BatchNormali (None, 40, 40, 144) 576 block_2_expand[0][0] __________________________________________________________________________________________________ block_2_expand_relu (ReLU) (None, 40, 40, 144) 0 block_2_expand_BN[0][0] __________________________________________________________________________________________________ block_2_depthwise (DepthwiseCon (None, 40, 40, 144) 1296 block_2_expand_relu[0][0] __________________________________________________________________________________________________ block_2_depthwise_BN (BatchNorm (None, 40, 40, 144) 576 block_2_depthwise[0][0] __________________________________________________________________________________________________ block_2_depthwise_relu (ReLU) (None, 40, 40, 144) 0 block_2_depthwise_BN[0][0] __________________________________________________________________________________________________ block_2_project (Conv2D) (None, 40, 40, 24) 3456 block_2_depthwise_relu[0][0] __________________________________________________________________________________________________ block_2_project_BN (BatchNormal (None, 40, 40, 24) 96 block_2_project[0][0] __________________________________________________________________________________________________ block_2_add (Add) (None, 40, 40, 24) 0 block_1_project_BN[0][0] block_2_project_BN[0][0] __________________________________________________________________________________________________ block_3_expand (Conv2D) (None, 40, 40, 144) 3456 block_2_add[0][0] __________________________________________________________________________________________________ block_3_expand_BN (BatchNormali (None, 40, 40, 144) 576 block_3_expand[0][0] __________________________________________________________________________________________________ block_3_expand_relu (ReLU) (None, 40, 40, 144) 0 block_3_expand_BN[0][0] __________________________________________________________________________________________________ block_3_pad (ZeroPadding2D) (None, 41, 41, 144) 0 block_3_expand_relu[0][0] __________________________________________________________________________________________________ block_3_depthwise (DepthwiseCon (None, 20, 20, 144) 1296 block_3_pad[0][0] __________________________________________________________________________________________________ block_3_depthwise_BN (BatchNorm (None, 20, 20, 144) 576 block_3_depthwise[0][0] __________________________________________________________________________________________________ block_3_depthwise_relu (ReLU) (None, 20, 20, 144) 0 block_3_depthwise_BN[0][0] __________________________________________________________________________________________________ block_3_project (Conv2D) (None, 20, 20, 32) 4608 block_3_depthwise_relu[0][0] __________________________________________________________________________________________________ block_3_project_BN (BatchNormal (None, 20, 20, 32) 128 block_3_project[0][0] __________________________________________________________________________________________________ block_4_expand (Conv2D) (None, 20, 20, 192) 6144 block_3_project_BN[0][0] __________________________________________________________________________________________________ block_4_expand_BN (BatchNormali (None, 20, 20, 192) 768 block_4_expand[0][0] __________________________________________________________________________________________________ block_4_expand_relu (ReLU) (None, 20, 20, 192) 0 block_4_expand_BN[0][0] __________________________________________________________________________________________________ block_4_depthwise (DepthwiseCon (None, 20, 20, 192) 1728 block_4_expand_relu[0][0] __________________________________________________________________________________________________ block_4_depthwise_BN (BatchNorm (None, 20, 20, 192) 768 block_4_depthwise[0][0] __________________________________________________________________________________________________ block_4_depthwise_relu (ReLU) (None, 20, 20, 192) 0 block_4_depthwise_BN[0][0] __________________________________________________________________________________________________ block_4_project (Conv2D) (None, 20, 20, 32) 6144 block_4_depthwise_relu[0][0] __________________________________________________________________________________________________ block_4_project_BN (BatchNormal (None, 20, 20, 32) 128 block_4_project[0][0] __________________________________________________________________________________________________ block_4_add (Add) (None, 20, 20, 32) 0 block_3_project_BN[0][0] block_4_project_BN[0][0] __________________________________________________________________________________________________ block_5_expand (Conv2D) (None, 20, 20, 192) 6144 block_4_add[0][0] __________________________________________________________________________________________________ block_5_expand_BN (BatchNormali (None, 20, 20, 192) 768 block_5_expand[0][0] __________________________________________________________________________________________________ block_5_expand_relu (ReLU) (None, 20, 20, 192) 0 block_5_expand_BN[0][0] __________________________________________________________________________________________________ block_5_depthwise (DepthwiseCon (None, 20, 20, 192) 1728 block_5_expand_relu[0][0] __________________________________________________________________________________________________ block_5_depthwise_BN (BatchNorm (None, 20, 20, 192) 768 block_5_depthwise[0][0] __________________________________________________________________________________________________ block_5_depthwise_relu (ReLU) (None, 20, 20, 192) 0 block_5_depthwise_BN[0][0] __________________________________________________________________________________________________ block_5_project (Conv2D) (None, 20, 20, 32) 6144 block_5_depthwise_relu[0][0] __________________________________________________________________________________________________ block_5_project_BN (BatchNormal (None, 20, 20, 32) 128 block_5_project[0][0] __________________________________________________________________________________________________ block_5_add (Add) (None, 20, 20, 32) 0 block_4_add[0][0] block_5_project_BN[0][0] __________________________________________________________________________________________________ block_6_expand (Conv2D) (None, 20, 20, 192) 6144 block_5_add[0][0] __________________________________________________________________________________________________ block_6_expand_BN (BatchNormali (None, 20, 20, 192) 768 block_6_expand[0][0] __________________________________________________________________________________________________ block_6_expand_relu (ReLU) (None, 20, 20, 192) 0 block_6_expand_BN[0][0] __________________________________________________________________________________________________ block_6_pad (ZeroPadding2D) (None, 21, 21, 192) 0 block_6_expand_relu[0][0] __________________________________________________________________________________________________ block_6_depthwise (DepthwiseCon (None, 10, 10, 192) 1728 block_6_pad[0][0] __________________________________________________________________________________________________ block_6_depthwise_BN (BatchNorm (None, 10, 10, 192) 768 block_6_depthwise[0][0] __________________________________________________________________________________________________ block_6_depthwise_relu (ReLU) (None, 10, 10, 192) 0 block_6_depthwise_BN[0][0] __________________________________________________________________________________________________ block_6_project (Conv2D) (None, 10, 10, 64) 12288 block_6_depthwise_relu[0][0] __________________________________________________________________________________________________ block_6_project_BN (BatchNormal (None, 10, 10, 64) 256 block_6_project[0][0] __________________________________________________________________________________________________ block_7_expand (Conv2D) (None, 10, 10, 384) 24576 block_6_project_BN[0][0] __________________________________________________________________________________________________ block_7_expand_BN (BatchNormali (None, 10, 10, 384) 1536 block_7_expand[0][0] __________________________________________________________________________________________________ block_7_expand_relu (ReLU) (None, 10, 10, 384) 0 block_7_expand_BN[0][0] __________________________________________________________________________________________________ block_7_depthwise (DepthwiseCon (None, 10, 10, 384) 3456 block_7_expand_relu[0][0] __________________________________________________________________________________________________ block_7_depthwise_BN (BatchNorm (None, 10, 10, 384) 1536 block_7_depthwise[0][0] __________________________________________________________________________________________________ block_7_depthwise_relu (ReLU) (None, 10, 10, 384) 0 block_7_depthwise_BN[0][0] __________________________________________________________________________________________________ block_7_project (Conv2D) (None, 10, 10, 64) 24576 block_7_depthwise_relu[0][0] __________________________________________________________________________________________________ block_7_project_BN (BatchNormal (None, 10, 10, 64) 256 block_7_project[0][0] __________________________________________________________________________________________________ block_7_add (Add) (None, 10, 10, 64) 0 block_6_project_BN[0][0] block_7_project_BN[0][0] __________________________________________________________________________________________________ block_8_expand (Conv2D) (None, 10, 10, 384) 24576 block_7_add[0][0] __________________________________________________________________________________________________ block_8_expand_BN (BatchNormali (None, 10, 10, 384) 1536 block_8_expand[0][0] __________________________________________________________________________________________________ block_8_expand_relu (ReLU) (None, 10, 10, 384) 0 block_8_expand_BN[0][0] __________________________________________________________________________________________________ block_8_depthwise (DepthwiseCon (None, 10, 10, 384) 3456 block_8_expand_relu[0][0] __________________________________________________________________________________________________ block_8_depthwise_BN (BatchNorm (None, 10, 10, 384) 1536 block_8_depthwise[0][0] __________________________________________________________________________________________________ block_8_depthwise_relu (ReLU) (None, 10, 10, 384) 0 block_8_depthwise_BN[0][0] __________________________________________________________________________________________________ block_8_project (Conv2D) (None, 10, 10, 64) 24576 block_8_depthwise_relu[0][0] __________________________________________________________________________________________________ block_8_project_BN (BatchNormal (None, 10, 10, 64) 256 block_8_project[0][0] __________________________________________________________________________________________________ block_8_add (Add) (None, 10, 10, 64) 0 block_7_add[0][0] block_8_project_BN[0][0] __________________________________________________________________________________________________ block_9_expand (Conv2D) (None, 10, 10, 384) 24576 block_8_add[0][0] __________________________________________________________________________________________________ block_9_expand_BN (BatchNormali (None, 10, 10, 384) 1536 block_9_expand[0][0] __________________________________________________________________________________________________ block_9_expand_relu (ReLU) (None, 10, 10, 384) 0 block_9_expand_BN[0][0] __________________________________________________________________________________________________ block_9_depthwise (DepthwiseCon (None, 10, 10, 384) 3456 block_9_expand_relu[0][0] __________________________________________________________________________________________________ block_9_depthwise_BN (BatchNorm (None, 10, 10, 384) 1536 block_9_depthwise[0][0] __________________________________________________________________________________________________ block_9_depthwise_relu (ReLU) (None, 10, 10, 384) 0 block_9_depthwise_BN[0][0] __________________________________________________________________________________________________ block_9_project (Conv2D) (None, 10, 10, 64) 24576 block_9_depthwise_relu[0][0] __________________________________________________________________________________________________ block_9_project_BN (BatchNormal (None, 10, 10, 64) 256 block_9_project[0][0] __________________________________________________________________________________________________ block_9_add (Add) (None, 10, 10, 64) 0 block_8_add[0][0] block_9_project_BN[0][0] __________________________________________________________________________________________________ block_10_expand (Conv2D) (None, 10, 10, 384) 24576 block_9_add[0][0] __________________________________________________________________________________________________ block_10_expand_BN (BatchNormal (None, 10, 10, 384) 1536 block_10_expand[0][0] __________________________________________________________________________________________________ block_10_expand_relu (ReLU) (None, 10, 10, 384) 0 block_10_expand_BN[0][0] __________________________________________________________________________________________________ block_10_depthwise (DepthwiseCo (None, 10, 10, 384) 3456 block_10_expand_relu[0][0] __________________________________________________________________________________________________ block_10_depthwise_BN (BatchNor (None, 10, 10, 384) 1536 block_10_depthwise[0][0] __________________________________________________________________________________________________ block_10_depthwise_relu (ReLU) (None, 10, 10, 384) 0 block_10_depthwise_BN[0][0] __________________________________________________________________________________________________ block_10_project (Conv2D) (None, 10, 10, 96) 36864 block_10_depthwise_relu[0][0] __________________________________________________________________________________________________ block_10_project_BN (BatchNorma (None, 10, 10, 96) 384 block_10_project[0][0] __________________________________________________________________________________________________ block_11_expand (Conv2D) (None, 10, 10, 576) 55296 block_10_project_BN[0][0] __________________________________________________________________________________________________ block_11_expand_BN (BatchNormal (None, 10, 10, 576) 2304 block_11_expand[0][0] __________________________________________________________________________________________________ block_11_expand_relu (ReLU) (None, 10, 10, 576) 0 block_11_expand_BN[0][0] __________________________________________________________________________________________________ block_11_depthwise (DepthwiseCo (None, 10, 10, 576) 5184 block_11_expand_relu[0][0] __________________________________________________________________________________________________ block_11_depthwise_BN (BatchNor (None, 10, 10, 576) 2304 block_11_depthwise[0][0] __________________________________________________________________________________________________ block_11_depthwise_relu (ReLU) (None, 10, 10, 576) 0 block_11_depthwise_BN[0][0] __________________________________________________________________________________________________ block_11_project (Conv2D) (None, 10, 10, 96) 55296 block_11_depthwise_relu[0][0] __________________________________________________________________________________________________ block_11_project_BN (BatchNorma (None, 10, 10, 96) 384 block_11_project[0][0] __________________________________________________________________________________________________ block_11_add (Add) (None, 10, 10, 96) 0 block_10_project_BN[0][0] block_11_project_BN[0][0] __________________________________________________________________________________________________ block_12_expand (Conv2D) (None, 10, 10, 576) 55296 block_11_add[0][0] __________________________________________________________________________________________________ block_12_expand_BN (BatchNormal (None, 10, 10, 576) 2304 block_12_expand[0][0] __________________________________________________________________________________________________ block_12_expand_relu (ReLU) (None, 10, 10, 576) 0 block_12_expand_BN[0][0] __________________________________________________________________________________________________ block_12_depthwise (DepthwiseCo (None, 10, 10, 576) 5184 block_12_expand_relu[0][0] __________________________________________________________________________________________________ block_12_depthwise_BN (BatchNor (None, 10, 10, 576) 2304 block_12_depthwise[0][0] __________________________________________________________________________________________________ block_12_depthwise_relu (ReLU) (None, 10, 10, 576) 0 block_12_depthwise_BN[0][0] __________________________________________________________________________________________________ block_12_project (Conv2D) (None, 10, 10, 96) 55296 block_12_depthwise_relu[0][0] __________________________________________________________________________________________________ block_12_project_BN (BatchNorma (None, 10, 10, 96) 384 block_12_project[0][0] __________________________________________________________________________________________________ block_12_add (Add) (None, 10, 10, 96) 0 block_11_add[0][0] block_12_project_BN[0][0] __________________________________________________________________________________________________ block_13_expand (Conv2D) (None, 10, 10, 576) 55296 block_12_add[0][0] __________________________________________________________________________________________________ block_13_expand_BN (BatchNormal (None, 10, 10, 576) 2304 block_13_expand[0][0] __________________________________________________________________________________________________ block_13_expand_relu (ReLU) (None, 10, 10, 576) 0 block_13_expand_BN[0][0] __________________________________________________________________________________________________ block_13_pad (ZeroPadding2D) (None, 11, 11, 576) 0 block_13_expand_relu[0][0] __________________________________________________________________________________________________ block_13_depthwise (DepthwiseCo (None, 5, 5, 576) 5184 block_13_pad[0][0] __________________________________________________________________________________________________ block_13_depthwise_BN (BatchNor (None, 5, 5, 576) 2304 block_13_depthwise[0][0] __________________________________________________________________________________________________ block_13_depthwise_relu (ReLU) (None, 5, 5, 576) 0 block_13_depthwise_BN[0][0] __________________________________________________________________________________________________ block_13_project (Conv2D) (None, 5, 5, 160) 92160 block_13_depthwise_relu[0][0] __________________________________________________________________________________________________ block_13_project_BN (BatchNorma (None, 5, 5, 160) 640 block_13_project[0][0] __________________________________________________________________________________________________ block_14_expand (Conv2D) (None, 5, 5, 960) 153600 block_13_project_BN[0][0] __________________________________________________________________________________________________ block_14_expand_BN (BatchNormal (None, 5, 5, 960) 3840 block_14_expand[0][0] __________________________________________________________________________________________________ block_14_expand_relu (ReLU) (None, 5, 5, 960) 0 block_14_expand_BN[0][0] __________________________________________________________________________________________________ block_14_depthwise (DepthwiseCo (None, 5, 5, 960) 8640 block_14_expand_relu[0][0] __________________________________________________________________________________________________ block_14_depthwise_BN (BatchNor (None, 5, 5, 960) 3840 block_14_depthwise[0][0] __________________________________________________________________________________________________ block_14_depthwise_relu (ReLU) (None, 5, 5, 960) 0 block_14_depthwise_BN[0][0] __________________________________________________________________________________________________ block_14_project (Conv2D) (None, 5, 5, 160) 153600 block_14_depthwise_relu[0][0] __________________________________________________________________________________________________ block_14_project_BN (BatchNorma (None, 5, 5, 160) 640 block_14_project[0][0] __________________________________________________________________________________________________ block_14_add (Add) (None, 5, 5, 160) 0 block_13_project_BN[0][0] block_14_project_BN[0][0] __________________________________________________________________________________________________ block_15_expand (Conv2D) (None, 5, 5, 960) 153600 block_14_add[0][0] __________________________________________________________________________________________________ block_15_expand_BN (BatchNormal (None, 5, 5, 960) 3840 block_15_expand[0][0] __________________________________________________________________________________________________ block_15_expand_relu (ReLU) (None, 5, 5, 960) 0 block_15_expand_BN[0][0] __________________________________________________________________________________________________ block_15_depthwise (DepthwiseCo (None, 5, 5, 960) 8640 block_15_expand_relu[0][0] __________________________________________________________________________________________________ block_15_depthwise_BN (BatchNor (None, 5, 5, 960) 3840 block_15_depthwise[0][0] __________________________________________________________________________________________________ block_15_depthwise_relu (ReLU) (None, 5, 5, 960) 0 block_15_depthwise_BN[0][0] __________________________________________________________________________________________________ block_15_project (Conv2D) (None, 5, 5, 160) 153600 block_15_depthwise_relu[0][0] __________________________________________________________________________________________________ block_15_project_BN (BatchNorma (None, 5, 5, 160) 640 block_15_project[0][0] __________________________________________________________________________________________________ block_15_add (Add) (None, 5, 5, 160) 0 block_14_add[0][0] block_15_project_BN[0][0] __________________________________________________________________________________________________ block_16_expand (Conv2D) (None, 5, 5, 960) 153600 block_15_add[0][0] __________________________________________________________________________________________________ block_16_expand_BN (BatchNormal (None, 5, 5, 960) 3840 block_16_expand[0][0] __________________________________________________________________________________________________ block_16_expand_relu (ReLU) (None, 5, 5, 960) 0 block_16_expand_BN[0][0] __________________________________________________________________________________________________ block_16_depthwise (DepthwiseCo (None, 5, 5, 960) 8640 block_16_expand_relu[0][0] __________________________________________________________________________________________________ block_16_depthwise_BN (BatchNor (None, 5, 5, 960) 3840 block_16_depthwise[0][0] __________________________________________________________________________________________________ block_16_depthwise_relu (ReLU) (None, 5, 5, 960) 0 block_16_depthwise_BN[0][0] __________________________________________________________________________________________________ block_16_project (Conv2D) (None, 5, 5, 320) 307200 block_16_depthwise_relu[0][0] __________________________________________________________________________________________________ block_16_project_BN (BatchNorma (None, 5, 5, 320) 1280 block_16_project[0][0] __________________________________________________________________________________________________ Conv_1 (Conv2D) (None, 5, 5, 1280) 409600 block_16_project_BN[0][0] __________________________________________________________________________________________________ Conv_1_bn (BatchNormalizationV1 (None, 5, 5, 1280) 5120 Conv_1[0][0] __________________________________________________________________________________________________ out_relu (ReLU) (None, 5, 5, 1280) 0 Conv_1_bn[0][0] ================================================================================================== Total params: 2,257,984 Trainable params: 0 Non-trainable params: 2,257,984 __________________________________________________________________________________________________
分類ヘッドを追加する
特徴ブロックから予測を生成するために、特徴を画像毎に単一 1280-要素ベクトルに変換するために tf.keras.layers.GlobalAveragePlloing2d を使用して 5×5 空間的位置に渡り平均します。
global_average_layer = tf.keras.layers.GlobalAveragePooling2D() feature_batch_average = global_average_layer(feature_batch) print(feature_batch_average.shape)
(32, 1280)
その上にこれらの特徴を画像毎に単一の予測に変換するために tf.keras.layers.Dense 層を適用します。ここでは活性化関数は使用しないでください、何故ならばこの予測はロジットとして扱われるからです : 正の数はクラス 1 を予測し、負の数はクラス 0 を予測します。
prediction_layer = keras.layers.Dense(1) prediction_batch = prediction_layer(feature_batch_average) print(prediction_batch.shape)
(32, 1)
さて特徴抽出器これらの 2 つの層を tf.keras.Sequential モデルを使用してスタックします :
model = tf.keras.Sequential([ base_model, global_average_layer, prediction_layer ])
モデルをコンパイルする
モデルを訓練する前にそれをコンパイルしなければなりません。
base_learning_rate = 0.0001 model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate), loss='binary_crossentropy', metrics=['accuracy'])
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= mobilenetv2_1.00_160 (Model) (None, 5, 5, 1280) 2257984 _________________________________________________________________ global_average_pooling2d (Gl (None, 1280) 0 _________________________________________________________________ dense (Dense) (None, 1) 1281 ================================================================= Total params: 2,259,265 Trainable params: 1,281 Non-trainable params: 2,257,984 _________________________________________________________________
これら 1.2K の訓練可能なパラメータは 2 つの tf.Variable オブジェクトの中で分割されます、2 つの dense 層の重みとバイアスです :
len(model.trainable_variables)
2
モデルを訓練する
10 エポックの間の訓練後、 ~96% 精度を得ることができます。
num_train, num_val, num_test = ( metadata.splits['train'].num_examples*weight/10 for weight in SPLIT_WEIGHTS )
initial_epochs = 10 steps_per_epoch = round(num_train)//BATCH_SIZE validation_steps = 20 loss0,accuracy0 = model.evaluate(validation_batches, steps = validation_steps)
20/20 [==============================] - 4s 219ms/step - loss: 3.1885 - accuracy: 0.6109
print("initial loss: {:.2f}".format(loss0)) print("initial accuracy: {:.2f}".format(accuracy0))
initial loss: 3.19 initial accuracy: 0.61
history = model.fit(train_batches.repeat(), epochs=initial_epochs, steps_per_epoch = steps_per_epoch, validation_data=validation_batches.repeat(), validation_steps=validation_steps)
Epoch 1/10 581/581 [==============================] - 102s 175ms/step - loss: 1.8917 - accuracy: 0.7606 - val_loss: 0.8860 - val_accuracy: 0.8828 Epoch 2/10 581/581 [==============================] - 94s 161ms/step - loss: 0.9989 - accuracy: 0.8697 - val_loss: 0.4330 - val_accuracy: 0.9281 Epoch 3/10 581/581 [==============================] - 99s 170ms/step - loss: 0.7417 - accuracy: 0.8995 - val_loss: 0.2629 - val_accuracy: 0.9469 Epoch 4/10 581/581 [==============================] - 94s 162ms/step - loss: 0.6681 - accuracy: 0.9133 - val_loss: 0.2753 - val_accuracy: 0.9563 Epoch 5/10 581/581 [==============================] - 98s 169ms/step - loss: 0.5944 - accuracy: 0.9220 - val_loss: 0.2410 - val_accuracy: 0.9641 Epoch 6/10 581/581 [==============================] - 98s 169ms/step - loss: 0.5779 - accuracy: 0.9252 - val_loss: 0.2003 - val_accuracy: 0.9656 Epoch 7/10 581/581 [==============================] - 94s 162ms/step - loss: 0.5151 - accuracy: 0.9328 - val_loss: 0.2115 - val_accuracy: 0.9688 Epoch 8/10 581/581 [==============================] - 95s 164ms/step - loss: 0.5076 - accuracy: 0.9346 - val_loss: 0.1678 - val_accuracy: 0.9703 Epoch 9/10 581/581 [==============================] - 96s 165ms/step - loss: 0.4679 - accuracy: 0.9368 - val_loss: 0.1668 - val_accuracy: 0.9703 Epoch 10/10 581/581 [==============================] - 96s 165ms/step - loss: 0.4921 - accuracy: 0.9381 - val_loss: 0.1847 - val_accuracy: 0.9719
学習カーブ
MobileNet V2 ベースモデルを固定特徴抽出器として使用するときの、訓練と検証精度 / 損失の学習カーブを見てみましょう。
acc = history.history['accuracy'] val_acc = history.history['val_accuracy'] loss = history.history['loss'] val_loss = history.history['val_loss'] plt.figure(figsize=(8, 8)) plt.subplot(2, 1, 1) plt.plot(acc, label='Training Accuracy') plt.plot(val_acc, label='Validation Accuracy') plt.legend(loc='lower right') plt.ylabel('Accuracy') plt.ylim([min(plt.ylim()),1]) plt.title('Training and Validation Accuracy') plt.subplot(2, 1, 2) plt.plot(loss, label='Training Loss') plt.plot(val_loss, label='Validation Loss') plt.legend(loc='upper right') plt.ylabel('Cross Entropy') plt.ylim([0,1.0]) plt.title('Training and Validation Loss') plt.xlabel('epoch') plt.show()
Note: もし貴方が何故検証メトリクスが訓練メトリクスよりも明らかに良いのか不思議に思うのであれば、主要因は訓練をより困難にするランダムネスを追加する tf.keras.layers.BatchNormalization と tf.keras.layers.Dropout のような層のためです。程度は少ないですが、それはまた検証メトリクスがエポック後に評価される一方で訓練メトリクスがエポックのための平均を報告するからです、そのため検証メトリクスは僅かばかり長く訓練されたモデルを見ます。
再調整
私達の特徴抽出実験では、MobileNet V2 ベースモデルの上の 2, 3 層だけを訓練していました。事前訓練されたネットワークの重みは訓練の間に更新されませんでした。より以上にパフォーマンスを増す一つの方法は top-level 分類器の訓練と一緒に事前訓練されたモデルの上部の層の重みを「再調整」することです。訓練過程は一般的な特徴マップから私達のデータセットに特に関連する特徴へと重みが調整されることを強制します。
Note: これは事前訓練されたモデルを非訓練可能に設定しながら top-level 分類器を訓練した後で試されるべきです。もし貴方が事前訓練されたモデルの上のランダムに初期化された分類器を追加して総ての層を一緒に訓練することを試みる場合、勾配更新の大きさが (分類器からのランダム重みゆえに) 大きすぎて貴方の事前訓練されたモデルはそれが学習したことを総て単に忘れるでしょう。
更に、事前訓練されたモデルの総ての層ではなく事前訓練されたモデルの top 層を再調整する裏の根拠は以下です : convnet では、層が高位になればなるほど、それはより特化されます。convnet の最初の 2, 3 の層は非常に単純で一般的な特徴を学習して、それは殆ど総てのタイプの画像に一般化されます。しかしより高く行くほどに、特徴は次第に (モデルがその上で訓練された) データセット特有になります。再調整の目標はこれらの専門的な特徴を新しいデータで動作するように適応させることです。
モデルのトップ層を解凍する
私達が行なう必要がある総てのことは base_model を解凍することです、そしてボトム層は非訓練可能に設定します。それから、モデルを再コンパイルして (これらの変更が効果を持つようにするために必要です)、訓練を再開します。
base_model.trainable = True
# Let's take a look to see how many layers are in the base model print("Number of layers in the base model: ", len(base_model.layers)) # Fine tune from this layer onwards fine_tune_at = 100 # Freeze all the layers before the `fine_tune_at` layer for layer in base_model.layers[:fine_tune_at]: layer.trainable = False
Number of layers in the base model: 155
モデルをコンパイルする
遥かに低い訓練率 (= training rate) を使用してモデルをコンパイルします。
model.compile(loss='binary_crossentropy', optimizer = tf.keras.optimizers.RMSprop(lr=base_learning_rate/10), metrics=['accuracy'])
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= mobilenetv2_1.00_160 (Model) (None, 5, 5, 1280) 2257984 _________________________________________________________________ global_average_pooling2d (Gl (None, 1280) 0 _________________________________________________________________ dense (Dense) (None, 1) 1281 ================================================================= Total params: 2,259,265 Trainable params: 1,863,873 Non-trainable params: 395,392 _________________________________________________________________
len(model.trainable_variables)
58
モデルの訓練を継続する
先に収束するために訓練したのであれば、これは数パーセントの更なる精度を貴方に得させるでしょう。
fine_tune_epochs = 10 total_epochs = initial_epochs + fine_tune_epochs history_fine = model.fit(train_batches.repeat(), steps_per_epoch = steps_per_epoch, epochs=total_epochs, initial_epoch = initial_epochs, validation_data=validation_batches.repeat(), validation_steps=validation_steps)
Epoch 11/20 581/581 [==============================] - 134s 231ms/step - loss: 0.4208 - accuracy: 0.9484 - val_loss: 0.1907 - val_accuracy: 0.9812 Epoch 12/20 581/581 [==============================] - 114s 197ms/step - loss: 0.3359 - accuracy: 0.9570 - val_loss: 0.1835 - val_accuracy: 0.9844 Epoch 13/20 581/581 [==============================] - 116s 200ms/step - loss: 0.2930 - accuracy: 0.9650 - val_loss: 0.1505 - val_accuracy: 0.9844 Epoch 14/20 581/581 [==============================] - 114s 196ms/step - loss: 0.2561 - accuracy: 0.9701 - val_loss: 0.1575 - val_accuracy: 0.9859 Epoch 15/20 581/581 [==============================] - 119s 206ms/step - loss: 0.2302 - accuracy: 0.9715 - val_loss: 0.1600 - val_accuracy: 0.9812 Epoch 16/20 581/581 [==============================] - 115s 197ms/step - loss: 0.2134 - accuracy: 0.9747 - val_loss: 0.1407 - val_accuracy: 0.9828 Epoch 17/20 581/581 [==============================] - 115s 197ms/step - loss: 0.1546 - accuracy: 0.9813 - val_loss: 0.0944 - val_accuracy: 0.9828 Epoch 18/20 581/581 [==============================] - 116s 200ms/step - loss: 0.1636 - accuracy: 0.9794 - val_loss: 0.0947 - val_accuracy: 0.9844 Epoch 19/20 581/581 [==============================] - 115s 198ms/step - loss: 0.1356 - accuracy: 0.9823 - val_loss: 0.1169 - val_accuracy: 0.9828 Epoch 20/20 581/581 [==============================] - 116s 199ms/step - loss: 0.1243 - accuracy: 0.9849 - val_loss: 0.1121 - val_accuracy: 0.9875
学習カーブ
再調整後、訓練と検証精度 / 損失の学習カーブを見てみましょう。
Note: 訓練データセットは非常に小さくて MobileNet V2 がその上で訓練された元のデータセットに類似していますので、再調整は overfitting の結果になるかもしれません。
再調整後にモデルは 98% 精度近くに到達します。
acc += history_fine.history['accuracy'] val_acc += history_fine.history['val_accuracy'] loss += history_fine.history['loss'] val_loss += history_fine.history['val_loss']
plt.figure(figsize=(8, 8)) plt.subplot(2, 1, 1) plt.plot(acc, label='Training Accuracy') plt.plot(val_acc, label='Validation Accuracy') plt.ylim([0.8, 1]) plt.plot([initial_epochs-1,initial_epochs-1], plt.ylim(), label='Start Fine Tuning') plt.legend(loc='lower right') plt.title('Training and Validation Accuracy') plt.subplot(2, 1, 2) plt.plot(loss, label='Training Loss') plt.plot(val_loss, label='Validation Loss') plt.ylim([0, 1.0]) plt.plot([initial_epochs-1,initial_epochs-1], plt.ylim(), label='Start Fine Tuning') plt.legend(loc='upper right') plt.title('Training and Validation Loss') plt.xlabel('epoch') plt.show()
Key takeaways
要約すれば、精度を改善するために事前訓練されたモデルを使用してどのように転移学習を行なうかについてこのチュートリアルでカバーしたものがここにあります : * 特徴抽出 のために事前訓練されたモデルを使用する – 小さいデータセットで作業するとき、同じドメインのより巨大なデータセット上で訓練されたモデルにより学習された特徴を活用することは一般的です。これは事前訓練されたモデルをインスタンス化して完全結合分類器を上に追加することにより成されます。事前訓練されたモデルは「凍結」されて訓練の間分類器の重みだけが更新されます。この場合、畳み込みベースが各画像に関係する総ての特徴を抽出してそして特徴のこれらのセットが与えられたときそれがどのクラスに属するか決定する分類器を訓練します。
* 事前訓練されたモデルを 再調整 する – パフォーマンスを更に改良するために、事前訓練されたモデルの top-level 層を再調整を通して新しいデータセットに再目的化することを望むかもしれません。この場合、データセットに固有の極度に特化されて高位な特徴を学習するように重みを調整します。これは、訓練データセットが巨大で (事前訓練されたモデルがその上で訓練された) 元のデータセットに非常に類似しているときに限り意味を持ちます。
以上