ホーム » CNN » TensorFlow 2.0 : 上級Tutorials : 画像 :- 事前訓練された ConvNet で転移学習

TensorFlow 2.0 : 上級Tutorials : 画像 :- 事前訓練された ConvNet で転移学習

TensorFlow 2.0 : 上級 Tutorials : 画像 :- 事前訓練された ConvNet で転移学習 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 11/07/2019

* 本ページは、TensorFlow org サイトの TF 2.0 – Advanced Tutorials – Images の以下のページを翻訳した上で
適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料セミナー開催中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/
Facebook: https://www.facebook.com/ClassCatJP/

 

画像 :- 事前訓練された ConvNet で転移学習

このチュートリアルでは事前訓練されたネットワークからの転移学習を使用して猫 vs 犬画像をどのように分類するかを学習します。

事前訓練されたモデル は、典型的には巨大スケールの画像分類タスク上、巨大なデータセット上で以前に訓練された (セーブされた) ネットワークです。事前訓練されたモデルをそのまま使用するか与えられたタスクにこのモデルをカスタマイズするために 転移学習 を使用します。

転移学習の背後にある直感は、このモデルが十分に巨大で一般的なデータセット上で訓練された場合、このモデルは視覚世界の一般的なモデルとして効果的に役立つであろうということです。それから巨大なデータセット上で巨大なモデルをスクラッチから訓練し始めることなくこれらの学習された特徴マップを活用できます。

このノートブックでは、事前訓練されたモデルをカスタマイズする 2 つの方法を試します :

  1. 特徴抽出 (= Feature Extraction) – 新しいサンプルから意味がある特徴を抽出するために以前のネットワークにより学習された表現を使用します。事前訓練されたモデルの上に (スクラッチから訓練される) 新しい分類器を単に追加します、その結果前に学習された特徴マップを私達のデータセットのために再目的化できます。

    モデル全体を (再) 訓練する必要はありません。ベース畳み込みネットワークは既に写真を分類するために一般的に有用な特徴を既に含んでいます。けれども、事前訓練されたモデルの最後の分類パートは元の分類タスクに特有で、結果的に (その上で) モデルが訓練されたクラスのセットに特有です。

  2. 再調整 (= Fine-Tuning) – 凍結されたモデルベースの 2, 3 のトップ層を解凍して、新たに追加された分類層とベースモデルの最後の層群の両者を一緒に訓練します。これはベースモデルの高次特徴表現を、特定のタスクのためにより関連付けるために「再調整」することを可能にします。

一般的な機械学習ワークフローに従います。

  1. データを調べて理解する。
  2. 入力パイプラインを構築します、このケースでは Keras ImageDataGenerator を使用します。
  3. モデルを構成する。
    • 事前訓練されたベースモデル (と事前訓練された重み) をロードする
    • トップに分類層をスタックする。
  4. モデルを訓練する。
  5. モデルを評価する。
from __future__ import absolute_import, division, print_function, unicode_literals

import os

import numpy as np

import matplotlib.pyplot as plt
import tensorflow as tf

keras = tf.keras

 

データ前処理

データ・ダウンロード

猫と犬のデータセットをロードするために TensorFlow Dataset を利用します。

この tfds パッケージは事前定義されたデータをロードする最も容易な方法です。もし貴方自身のデータを持ち、インポートしてそれを TensorFlow で使用することに興味があれば loading image data を見てください。

import tensorflow_datasets as tfds
tfds.disable_progress_bar()

tfds.load メソッドはデータをダウンロードしてキャッシュし、tf.data.Dataset オブジェクトを返します。これらのオブジェクトはデータを操作してそれをモデルにパイプするためのパワフルで、効率的なメソッドを提供します。

“cats_vs_dog” は標準的な分割を定義していないので、それをデータの 80%, 10%, 10% で (train, validation, test) にそれぞれ分割するために subsplit 機能を使用します。

SPLIT_WEIGHTS = (8, 1, 1)
splits = tfds.Split.TRAIN.subsplit(weighted=SPLIT_WEIGHTS)

(raw_train, raw_validation, raw_test), metadata = tfds.load(
    'cats_vs_dogs', split=list(splits),
    with_info=True, as_supervised=True)
Downloading and preparing dataset cats_vs_dogs (786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/2.0.1...

/home/kbuilder/.local/lib/python3.5/site-packages/urllib3/connectionpool.py:1004: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings
  InsecureRequestWarning,
WARNING:absl:1738 images were corrupted and were skipped

WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.5/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`

WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.5/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`

Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/2.0.1. Subsequent calls will reuse this data.

結果としての tf.data.Dataset オブジェクトは (image, label) ペアを含みます。そこでは画像は可変な shape と 3 チャネルを持ち、そしてラベルはスカラーです。

print(raw_train)
print(raw_validation)
print(raw_test)
<_OptionsDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>
<_OptionsDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>
<_OptionsDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>

訓練セットから最初の 2 つの画像とラベルを表示します。

get_label_name = metadata.features['label'].int2str

for image, label in raw_train.take(2):
  plt.figure()
  plt.imshow(image)
  plt.title(get_label_name(label))

 

データをフォーマットする

タスクのために画像をフォーマットするために tf.image モジュールを使用します。

画像を固定入力サイズにリサイズして、入力チャネルを [-1, 1] の範囲にリスケールします。

IMG_SIZE = 160 # All images will be resized to 160x160

def format_example(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label

map メソッドを使用してこの関数をデータセットの各アイテムに適用します :

train = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example)

今はデータをシャッフルしてバッチ化します。

BATCH_SIZE = 32
SHUFFLE_BUFFER_SIZE = 1000
train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
validation_batches = validation.batch(BATCH_SIZE)
test_batches = test.batch(BATCH_SIZE)

データのバッチを調べます :

for image_batch, label_batch in train_batches.take(1):
   pass

image_batch.shape
TensorShape([32, 160, 160, 3])

 

事前訓練された convnet からベースモデルを作成する

Google で開発された MobileNet V2 モデルからベースモデルを作成します。これは ImageNet データセット、web 画像の 1.4 M 画像と 1000 クラスの巨大データセット上で事前訓練されています。ImageNet はパンノキ (= jackfruit) と注射器のようなカテゴリを持つ非常に恣意的な研究訓練データセットを持ちますが、この知識の土台は特定のデータセットから猫と犬を識別するのに役立ちます。

最初に、特徴抽出のために使用する MobileNet V2 の層を選択する必要があります。明らかに、最も最後の分類層 (「トップ」上、何故ならば機械学習モデルの殆どの図はボトムからトップに進みます) は全く役立ちません。代わりに、flatten 演算の前の最も最後の層に依拠する一般的な実践に従います。この層は「ボトルネック層」と呼ばれます。ボトルネック特徴は final/top 層に比較して遥かに汎用性を保持します。

最初に、ImageNet 上で訓練された重みとともに事前ロードされた MobileNet V2 モデルをインスタンス化します。include_top=False 引数を指定することにより、トップに分類層を含まないネットワークをロードします、これは特徴抽出のために理想的です。

IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)

# Create the base model from the pre-trained model MobileNet V2
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
Downloading data from https://github.com/JonathanCMitchell/mobilenet_v2_keras/releases/download/v1.1/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 1s 0us/step

この特徴抽出器は各 160x160x3 画像を 5x5x1280 特徴ブロックに変換します。画像のサンプルバッチにそれが何をするかを見ます :

feature_batch = base_model(image_batch)
print(feature_batch.shape)
(32, 5, 5, 1280)

 

特徴抽出

前のステップで作成された畳み込みベースを凍結してそれを特徴抽出器として使用し、その上に分類器を追加して top-level 分類器を訓練します。

 

畳み込みベースを凍結する

compile してモデルを訓練する前に畳み込みベースを凍結することは重要です。凍結する (あるいは layer.trainable = False を設定する) ことにより、与えられた層の重みが訓練の間に更新されることを回避します。MobileNet V2 は多くの層を持ちますが、全体のモデルの trainable フラグを False に設定すれば総ての層を凍結します。

base_model.trainable = False
# Let's take a look at the base model architecture
base_model.summary()
Model: "mobilenetv2_1.00_160"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 160, 160, 3) 0                                            
__________________________________________________________________________________________________
Conv1_pad (ZeroPadding2D)       (None, 161, 161, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 80, 80, 32)   864         Conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 80, 80, 32)   128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 80, 80, 32)   0           bn_Conv1[0][0]                   
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 80, 80, 32)   288         Conv1_relu[0][0]                 
__________________________________________________________________________________________________
expanded_conv_depthwise_BN (Bat (None, 80, 80, 32)   128         expanded_conv_depthwise[0][0]    
__________________________________________________________________________________________________
expanded_conv_depthwise_relu (R (None, 80, 80, 32)   0           expanded_conv_depthwise_BN[0][0] 
__________________________________________________________________________________________________
expanded_conv_project (Conv2D)  (None, 80, 80, 16)   512         expanded_conv_depthwise_relu[0][0
__________________________________________________________________________________________________
expanded_conv_project_BN (Batch (None, 80, 80, 16)   64          expanded_conv_project[0][0]      
__________________________________________________________________________________________________
block_1_expand (Conv2D)         (None, 80, 80, 96)   1536        expanded_conv_project_BN[0][0]   
__________________________________________________________________________________________________
block_1_expand_BN (BatchNormali (None, 80, 80, 96)   384         block_1_expand[0][0]             
__________________________________________________________________________________________________
block_1_expand_relu (ReLU)      (None, 80, 80, 96)   0           block_1_expand_BN[0][0]          
__________________________________________________________________________________________________
block_1_pad (ZeroPadding2D)     (None, 81, 81, 96)   0           block_1_expand_relu[0][0]        
__________________________________________________________________________________________________
block_1_depthwise (DepthwiseCon (None, 40, 40, 96)   864         block_1_pad[0][0]                
__________________________________________________________________________________________________
block_1_depthwise_BN (BatchNorm (None, 40, 40, 96)   384         block_1_depthwise[0][0]          
__________________________________________________________________________________________________
block_1_depthwise_relu (ReLU)   (None, 40, 40, 96)   0           block_1_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_1_project (Conv2D)        (None, 40, 40, 24)   2304        block_1_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_1_project_BN (BatchNormal (None, 40, 40, 24)   96          block_1_project[0][0]            
__________________________________________________________________________________________________
block_2_expand (Conv2D)         (None, 40, 40, 144)  3456        block_1_project_BN[0][0]         
__________________________________________________________________________________________________
block_2_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_2_expand[0][0]             
__________________________________________________________________________________________________
block_2_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_2_expand_BN[0][0]          
__________________________________________________________________________________________________
block_2_depthwise (DepthwiseCon (None, 40, 40, 144)  1296        block_2_expand_relu[0][0]        
__________________________________________________________________________________________________
block_2_depthwise_BN (BatchNorm (None, 40, 40, 144)  576         block_2_depthwise[0][0]          
__________________________________________________________________________________________________
block_2_depthwise_relu (ReLU)   (None, 40, 40, 144)  0           block_2_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_2_project (Conv2D)        (None, 40, 40, 24)   3456        block_2_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_2_project_BN (BatchNormal (None, 40, 40, 24)   96          block_2_project[0][0]            
__________________________________________________________________________________________________
block_2_add (Add)               (None, 40, 40, 24)   0           block_1_project_BN[0][0]         
                                                                 block_2_project_BN[0][0]         
__________________________________________________________________________________________________
block_3_expand (Conv2D)         (None, 40, 40, 144)  3456        block_2_add[0][0]                
__________________________________________________________________________________________________
block_3_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_3_expand[0][0]             
__________________________________________________________________________________________________
block_3_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_3_expand_BN[0][0]          
__________________________________________________________________________________________________
block_3_pad (ZeroPadding2D)     (None, 41, 41, 144)  0           block_3_expand_relu[0][0]        
__________________________________________________________________________________________________
block_3_depthwise (DepthwiseCon (None, 20, 20, 144)  1296        block_3_pad[0][0]                
__________________________________________________________________________________________________
block_3_depthwise_BN (BatchNorm (None, 20, 20, 144)  576         block_3_depthwise[0][0]          
__________________________________________________________________________________________________
block_3_depthwise_relu (ReLU)   (None, 20, 20, 144)  0           block_3_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_3_project (Conv2D)        (None, 20, 20, 32)   4608        block_3_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_3_project_BN (BatchNormal (None, 20, 20, 32)   128         block_3_project[0][0]            
__________________________________________________________________________________________________
block_4_expand (Conv2D)         (None, 20, 20, 192)  6144        block_3_project_BN[0][0]         
__________________________________________________________________________________________________
block_4_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_4_expand[0][0]             
__________________________________________________________________________________________________
block_4_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_4_expand_BN[0][0]          
__________________________________________________________________________________________________
block_4_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_4_expand_relu[0][0]        
__________________________________________________________________________________________________
block_4_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_4_depthwise[0][0]          
__________________________________________________________________________________________________
block_4_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_4_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_4_project (Conv2D)        (None, 20, 20, 32)   6144        block_4_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_4_project_BN (BatchNormal (None, 20, 20, 32)   128         block_4_project[0][0]            
__________________________________________________________________________________________________
block_4_add (Add)               (None, 20, 20, 32)   0           block_3_project_BN[0][0]         
                                                                 block_4_project_BN[0][0]         
__________________________________________________________________________________________________
block_5_expand (Conv2D)         (None, 20, 20, 192)  6144        block_4_add[0][0]                
__________________________________________________________________________________________________
block_5_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_5_expand[0][0]             
__________________________________________________________________________________________________
block_5_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_5_expand_BN[0][0]          
__________________________________________________________________________________________________
block_5_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_5_expand_relu[0][0]        
__________________________________________________________________________________________________
block_5_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_5_depthwise[0][0]          
__________________________________________________________________________________________________
block_5_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_5_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_5_project (Conv2D)        (None, 20, 20, 32)   6144        block_5_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_5_project_BN (BatchNormal (None, 20, 20, 32)   128         block_5_project[0][0]            
__________________________________________________________________________________________________
block_5_add (Add)               (None, 20, 20, 32)   0           block_4_add[0][0]                
                                                                 block_5_project_BN[0][0]         
__________________________________________________________________________________________________
block_6_expand (Conv2D)         (None, 20, 20, 192)  6144        block_5_add[0][0]                
__________________________________________________________________________________________________
block_6_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_6_expand[0][0]             
__________________________________________________________________________________________________
block_6_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_6_expand_BN[0][0]          
__________________________________________________________________________________________________
block_6_pad (ZeroPadding2D)     (None, 21, 21, 192)  0           block_6_expand_relu[0][0]        
__________________________________________________________________________________________________
block_6_depthwise (DepthwiseCon (None, 10, 10, 192)  1728        block_6_pad[0][0]                
__________________________________________________________________________________________________
block_6_depthwise_BN (BatchNorm (None, 10, 10, 192)  768         block_6_depthwise[0][0]          
__________________________________________________________________________________________________
block_6_depthwise_relu (ReLU)   (None, 10, 10, 192)  0           block_6_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_6_project (Conv2D)        (None, 10, 10, 64)   12288       block_6_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_6_project_BN (BatchNormal (None, 10, 10, 64)   256         block_6_project[0][0]            
__________________________________________________________________________________________________
block_7_expand (Conv2D)         (None, 10, 10, 384)  24576       block_6_project_BN[0][0]         
__________________________________________________________________________________________________
block_7_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_7_expand[0][0]             
__________________________________________________________________________________________________
block_7_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_7_expand_BN[0][0]          
__________________________________________________________________________________________________
block_7_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_7_expand_relu[0][0]        
__________________________________________________________________________________________________
block_7_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_7_depthwise[0][0]          
__________________________________________________________________________________________________
block_7_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_7_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_7_project (Conv2D)        (None, 10, 10, 64)   24576       block_7_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_7_project_BN (BatchNormal (None, 10, 10, 64)   256         block_7_project[0][0]            
__________________________________________________________________________________________________
block_7_add (Add)               (None, 10, 10, 64)   0           block_6_project_BN[0][0]         
                                                                 block_7_project_BN[0][0]         
__________________________________________________________________________________________________
block_8_expand (Conv2D)         (None, 10, 10, 384)  24576       block_7_add[0][0]                
__________________________________________________________________________________________________
block_8_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_8_expand[0][0]             
__________________________________________________________________________________________________
block_8_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_8_expand_BN[0][0]          
__________________________________________________________________________________________________
block_8_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_8_expand_relu[0][0]        
__________________________________________________________________________________________________
block_8_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_8_depthwise[0][0]          
__________________________________________________________________________________________________
block_8_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_8_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_8_project (Conv2D)        (None, 10, 10, 64)   24576       block_8_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_8_project_BN (BatchNormal (None, 10, 10, 64)   256         block_8_project[0][0]            
__________________________________________________________________________________________________
block_8_add (Add)               (None, 10, 10, 64)   0           block_7_add[0][0]                
                                                                 block_8_project_BN[0][0]         
__________________________________________________________________________________________________
block_9_expand (Conv2D)         (None, 10, 10, 384)  24576       block_8_add[0][0]                
__________________________________________________________________________________________________
block_9_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_9_expand[0][0]             
__________________________________________________________________________________________________
block_9_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_9_expand_BN[0][0]          
__________________________________________________________________________________________________
block_9_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_9_expand_relu[0][0]        
__________________________________________________________________________________________________
block_9_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_9_depthwise[0][0]          
__________________________________________________________________________________________________
block_9_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_9_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_9_project (Conv2D)        (None, 10, 10, 64)   24576       block_9_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_9_project_BN (BatchNormal (None, 10, 10, 64)   256         block_9_project[0][0]            
__________________________________________________________________________________________________
block_9_add (Add)               (None, 10, 10, 64)   0           block_8_add[0][0]                
                                                                 block_9_project_BN[0][0]         
__________________________________________________________________________________________________
block_10_expand (Conv2D)        (None, 10, 10, 384)  24576       block_9_add[0][0]                
__________________________________________________________________________________________________
block_10_expand_BN (BatchNormal (None, 10, 10, 384)  1536        block_10_expand[0][0]            
__________________________________________________________________________________________________
block_10_expand_relu (ReLU)     (None, 10, 10, 384)  0           block_10_expand_BN[0][0]         
__________________________________________________________________________________________________
block_10_depthwise (DepthwiseCo (None, 10, 10, 384)  3456        block_10_expand_relu[0][0]       
__________________________________________________________________________________________________
block_10_depthwise_BN (BatchNor (None, 10, 10, 384)  1536        block_10_depthwise[0][0]         
__________________________________________________________________________________________________
block_10_depthwise_relu (ReLU)  (None, 10, 10, 384)  0           block_10_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_10_project (Conv2D)       (None, 10, 10, 96)   36864       block_10_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_10_project_BN (BatchNorma (None, 10, 10, 96)   384         block_10_project[0][0]           
__________________________________________________________________________________________________
block_11_expand (Conv2D)        (None, 10, 10, 576)  55296       block_10_project_BN[0][0]        
__________________________________________________________________________________________________
block_11_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_11_expand[0][0]            
__________________________________________________________________________________________________
block_11_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_11_expand_BN[0][0]         
__________________________________________________________________________________________________
block_11_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_11_expand_relu[0][0]       
__________________________________________________________________________________________________
block_11_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_11_depthwise[0][0]         
__________________________________________________________________________________________________
block_11_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_11_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_11_project (Conv2D)       (None, 10, 10, 96)   55296       block_11_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_11_project_BN (BatchNorma (None, 10, 10, 96)   384         block_11_project[0][0]           
__________________________________________________________________________________________________
block_11_add (Add)              (None, 10, 10, 96)   0           block_10_project_BN[0][0]        
                                                                 block_11_project_BN[0][0]        
__________________________________________________________________________________________________
block_12_expand (Conv2D)        (None, 10, 10, 576)  55296       block_11_add[0][0]               
__________________________________________________________________________________________________
block_12_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_12_expand[0][0]            
__________________________________________________________________________________________________
block_12_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_12_expand_BN[0][0]         
__________________________________________________________________________________________________
block_12_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_12_expand_relu[0][0]       
__________________________________________________________________________________________________
block_12_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_12_depthwise[0][0]         
__________________________________________________________________________________________________
block_12_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_12_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_12_project (Conv2D)       (None, 10, 10, 96)   55296       block_12_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_12_project_BN (BatchNorma (None, 10, 10, 96)   384         block_12_project[0][0]           
__________________________________________________________________________________________________
block_12_add (Add)              (None, 10, 10, 96)   0           block_11_add[0][0]               
                                                                 block_12_project_BN[0][0]        
__________________________________________________________________________________________________
block_13_expand (Conv2D)        (None, 10, 10, 576)  55296       block_12_add[0][0]               
__________________________________________________________________________________________________
block_13_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_13_expand[0][0]            
__________________________________________________________________________________________________
block_13_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_13_expand_BN[0][0]         
__________________________________________________________________________________________________
block_13_pad (ZeroPadding2D)    (None, 11, 11, 576)  0           block_13_expand_relu[0][0]       
__________________________________________________________________________________________________
block_13_depthwise (DepthwiseCo (None, 5, 5, 576)    5184        block_13_pad[0][0]               
__________________________________________________________________________________________________
block_13_depthwise_BN (BatchNor (None, 5, 5, 576)    2304        block_13_depthwise[0][0]         
__________________________________________________________________________________________________
block_13_depthwise_relu (ReLU)  (None, 5, 5, 576)    0           block_13_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_13_project (Conv2D)       (None, 5, 5, 160)    92160       block_13_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_13_project_BN (BatchNorma (None, 5, 5, 160)    640         block_13_project[0][0]           
__________________________________________________________________________________________________
block_14_expand (Conv2D)        (None, 5, 5, 960)    153600      block_13_project_BN[0][0]        
__________________________________________________________________________________________________
block_14_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_14_expand[0][0]            
__________________________________________________________________________________________________
block_14_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_14_expand_BN[0][0]         
__________________________________________________________________________________________________
block_14_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_14_expand_relu[0][0]       
__________________________________________________________________________________________________
block_14_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_14_depthwise[0][0]         
__________________________________________________________________________________________________
block_14_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_14_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_14_project (Conv2D)       (None, 5, 5, 160)    153600      block_14_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_14_project_BN (BatchNorma (None, 5, 5, 160)    640         block_14_project[0][0]           
__________________________________________________________________________________________________
block_14_add (Add)              (None, 5, 5, 160)    0           block_13_project_BN[0][0]        
                                                                 block_14_project_BN[0][0]        
__________________________________________________________________________________________________
block_15_expand (Conv2D)        (None, 5, 5, 960)    153600      block_14_add[0][0]               
__________________________________________________________________________________________________
block_15_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_15_expand[0][0]            
__________________________________________________________________________________________________
block_15_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_15_expand_BN[0][0]         
__________________________________________________________________________________________________
block_15_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_15_expand_relu[0][0]       
__________________________________________________________________________________________________
block_15_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_15_depthwise[0][0]         
__________________________________________________________________________________________________
block_15_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_15_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_15_project (Conv2D)       (None, 5, 5, 160)    153600      block_15_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_15_project_BN (BatchNorma (None, 5, 5, 160)    640         block_15_project[0][0]           
__________________________________________________________________________________________________
block_15_add (Add)              (None, 5, 5, 160)    0           block_14_add[0][0]               
                                                                 block_15_project_BN[0][0]        
__________________________________________________________________________________________________
block_16_expand (Conv2D)        (None, 5, 5, 960)    153600      block_15_add[0][0]               
__________________________________________________________________________________________________
block_16_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_16_expand[0][0]            
__________________________________________________________________________________________________
block_16_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_16_expand_BN[0][0]         
__________________________________________________________________________________________________
block_16_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_16_expand_relu[0][0]       
__________________________________________________________________________________________________
block_16_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_16_depthwise[0][0]         
__________________________________________________________________________________________________
block_16_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_16_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_16_project (Conv2D)       (None, 5, 5, 320)    307200      block_16_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_16_project_BN (BatchNorma (None, 5, 5, 320)    1280        block_16_project[0][0]           
__________________________________________________________________________________________________
Conv_1 (Conv2D)                 (None, 5, 5, 1280)   409600      block_16_project_BN[0][0]        
__________________________________________________________________________________________________
Conv_1_bn (BatchNormalization)  (None, 5, 5, 1280)   5120        Conv_1[0][0]                     
__________________________________________________________________________________________________
out_relu (ReLU)                 (None, 5, 5, 1280)   0           Conv_1_bn[0][0]                  
==================================================================================================
Total params: 2,257,984
Trainable params: 0
Non-trainable params: 2,257,984
__________________________________________________________________________________________________

 

分類ヘッドを追加する

特徴ブロックから予測を生成するために、特徴を画像毎に単一 1280-要素ベクトルに変換するために tf.keras.layers.GlobalAveragePlloing2d 層を使用して 5×5 空間的位置に渡り平均します。

global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)
(32, 1280)

これらの特徴を画像毎に単一の予測に変換するために tf.keras.layers.Dense 層を適用します。ここでは活性化関数は必要ありません、何故ならばこの予測はロジット、あるいは生の予測値として扱われるからです。正数はクラス 1 を予測し、負数はクラス 0 を予測します。

prediction_layer = keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)
(32, 1)

さて特徴抽出器、そしてこれらの 2 つの層を tf.keras.Sequential モデルを使用してスタックします :

model = tf.keras.Sequential([
  base_model,
  global_average_layer,
  prediction_layer
])

 

モデルをコンパイルする

モデルを訓練する前にそれをコンパイルしなければなりません。2 クラスありますので、二値交差エントロピー損失を使用します。

base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate),
              loss='binary_crossentropy',
              metrics=['accuracy'])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
mobilenetv2_1.00_160 (Model) (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________

MobileNet の 2.5M パラメータは凍結されますが、Dense 層に 1.2K の訓練可能なパラメータがあります。これらは 2 つの tf.Variable オブジェクト、重みとバイアスに分けられます。

len(model.trainable_variables)
2

 

モデルを訓練する

10 エポックの間の訓練後、 ~96% 精度を見るはずです。

num_train, num_val, num_test = (
  metadata.splits['train'].num_examples*weight/10
  for weight in SPLIT_WEIGHTS
)
initial_epochs = 10
steps_per_epoch = round(num_train)//BATCH_SIZE
validation_steps = 20

loss0,accuracy0 = model.evaluate(validation_batches, steps = validation_steps)
20/20 [==============================] - 2s 92ms/step - loss: 5.1612 - accuracy: 0.5141
print("initial loss: {:.2f}".format(loss0))
print("initial accuracy: {:.2f}".format(accuracy0))
initial loss: 5.16
initial accuracy: 0.51
history = model.fit(train_batches,
                    epochs=initial_epochs,
                    validation_data=validation_batches)
Epoch 1/10
582/582 [==============================] - 36s 61ms/step - loss: 2.3922 - accuracy: 0.7156 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00
Epoch 2/10
582/582 [==============================] - 31s 53ms/step - loss: 1.0886 - accuracy: 0.8540 - val_loss: 0.5596 - val_accuracy: 0.9181
Epoch 3/10
582/582 [==============================] - 31s 53ms/step - loss: 0.7876 - accuracy: 0.8961 - val_loss: 0.4342 - val_accuracy: 0.9453
Epoch 4/10
582/582 [==============================] - 31s 53ms/step - loss: 0.6501 - accuracy: 0.9128 - val_loss: 0.4245 - val_accuracy: 0.9474
Epoch 5/10
582/582 [==============================] - 31s 53ms/step - loss: 0.6074 - accuracy: 0.9217 - val_loss: 0.3910 - val_accuracy: 0.9556
Epoch 6/10
582/582 [==============================] - 31s 53ms/step - loss: 0.5526 - accuracy: 0.9255 - val_loss: 0.5181 - val_accuracy: 0.9427
Epoch 7/10
582/582 [==============================] - 31s 53ms/step - loss: 0.4930 - accuracy: 0.9341 - val_loss: 0.4041 - val_accuracy: 0.9513
Epoch 8/10
582/582 [==============================] - 31s 53ms/step - loss: 0.5091 - accuracy: 0.9348 - val_loss: 0.4530 - val_accuracy: 0.9526
Epoch 9/10
582/582 [==============================] - 31s 54ms/step - loss: 0.4681 - accuracy: 0.9366 - val_loss: 0.4172 - val_accuracy: 0.9539
Epoch 10/10
582/582 [==============================] - 31s 54ms/step - loss: 0.4634 - accuracy: 0.9386 - val_loss: 0.4221 - val_accuracy: 0.9530

 

学習カーブ

MobileNet V2 ベースモデルを固定された特徴抽出器として使用するときの訓練と検証の精度 / 損失の学習カーブを見てみましょう。

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

Note: もし貴方が何故検証メトリクスが訓練メトリクスよりも明らかに良いのか疑問に思うのであれば、主要因は tf.keras.layers.BatchNormalizationtf.keras.layers.Dropout のような層が訓練の間の精度に影響するからです。それらは検証損失を計算するときには無効にされます。

程度は少ないですが、それはまた検証メトリクスがエポック後に評価される一方で訓練メトリクスはエポックのための平均を報告するからでもあります、そのため検証メトリクスは僅かばかり長く訓練されたモデルを見ることになります。

 

再調整

私達の特徴抽出実験では、MobileNet V2 ベースモデルの上の 2, 3 層だけを訓練していました。事前訓練されたネットワークの重みは訓練の間に更新されませんでした。

より以上にパフォーマンスを増す一つの方法は 、貴方が追加した分類器の訓練と一緒に事前訓練されたモデルの上部の層の重みを訓練 (or「再調整」) することです。訓練過程は一般的な特徴マップから私達のデータセットに特に関連する特徴へと重みが調整されることを強制します。

Note: これは事前訓練されたモデルを非訓練可能に設定しながら top-level 分類器を訓練した後でのみ試されるべきです。もし貴方が事前訓練されたモデルの上にランダムに初期化された分類器を追加して総ての層を一緒に訓練することを試みる場合、勾配更新の大きさが (分類器からのランダム重みのために) 大き過ぎて貴方の事前訓練されたモデルはそれが学習したことを忘れるでしょう。

また、MobileNet モデル全体よりも小さい数の top 層を再調整することを試みるべきです。殆どの畳み込みネットワークでは、層が高位になればなるほど、それはより特化されます。最初の 2, 3 の層は非常に単純で一般的な特徴を学習して、それは殆ど総てのタイプの画像に一般化されます。より高く行くほどに、特徴は段々と (モデルがその上で訓練された) データセット特有になります。再調整のゴールは一般的な学習を上書きすることではなく、これらの固有の特徴を新しいデータで動作するように適応させることです。

 

モデルのトップ層を解凍する

貴方が行なう必要がある総てのことは base_model を解凍してボトム層を非訓練可能に (i.e. 訓練できないように) 設定することです。それから、モデルを再コンパイルするべきです (これらの変更が有効になるために必要です)、そして訓練を再開します。

base_model.trainable = True
# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))

# Fine tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
  layer.trainable =  False
Number of layers in the base model:  155

 

モデルをコンパイルする

遥かに低い訓練率 (= training rate) を使用してモデルをコンパイルします。

model.compile(loss='binary_crossentropy',
              optimizer = tf.keras.optimizers.RMSprop(lr=base_learning_rate/10),
              metrics=['accuracy'])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
mobilenetv2_1.00_160 (Model) (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,863,873
Non-trainable params: 395,392
_________________________________________________________________
len(model.trainable_variables)
58

 

モデルの訓練を継続する

先に収束するために訓練したのであれば、これは数パーセントの更なる精度を貴方に得させるでしょう。

fine_tune_epochs = 10
total_epochs =  initial_epochs + fine_tune_epochs

history_fine = model.fit(train_batches,
                         epochs=total_epochs,
                         initial_epoch = initial_epochs,
                         validation_data=validation_batches)
Epoch 11/20
582/582 [==============================] - 42s 72ms/step - loss: 0.4341 - accuracy: 0.9476 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00
Epoch 12/20
582/582 [==============================] - 35s 60ms/step - loss: 0.3408 - accuracy: 0.9592 - val_loss: 0.2207 - val_accuracy: 0.9703
Epoch 13/20
582/582 [==============================] - 35s 60ms/step - loss: 0.2940 - accuracy: 0.9645 - val_loss: 0.2162 - val_accuracy: 0.9737
Epoch 14/20
582/582 [==============================] - 35s 60ms/step - loss: 0.2327 - accuracy: 0.9712 - val_loss: 0.2374 - val_accuracy: 0.9728
Epoch 15/20
582/582 [==============================] - 35s 60ms/step - loss: 0.2110 - accuracy: 0.9756 - val_loss: 0.2342 - val_accuracy: 0.9733
Epoch 16/20
582/582 [==============================] - 35s 60ms/step - loss: 0.1862 - accuracy: 0.9783 - val_loss: 0.2156 - val_accuracy: 0.9741
Epoch 17/20
582/582 [==============================] - 35s 60ms/step - loss: 0.1655 - accuracy: 0.9794 - val_loss: 0.2218 - val_accuracy: 0.9746
Epoch 18/20
582/582 [==============================] - 35s 61ms/step - loss: 0.1372 - accuracy: 0.9823 - val_loss: 0.2023 - val_accuracy: 0.9750
Epoch 19/20
582/582 [==============================] - 35s 60ms/step - loss: 0.1186 - accuracy: 0.9846 - val_loss: 0.2137 - val_accuracy: 0.9750
Epoch 20/20
582/582 [==============================] - 35s 60ms/step - loss: 0.1135 - accuracy: 0.9847 - val_loss: 0.2207 - val_accuracy: 0.9759

MobileNet V2 ベースモデルの最後の 2, 3 層を再調整してその上の分類器を訓練するとき、訓練と検証の精度 / 損失の学習カーブを見てみましょう。検証損失は訓練損失よりも遥かに高いので、何某かの overfitting を得るかもしれません。

新しい訓練セットは比較的小さくて元の MobileNet V2 データセットに類似しているので、某かの overfitting もまた得ているかもしれません。

再調整後にモデルは 98% 精度近くに到達します。

acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']

loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([initial_epochs-1,initial_epochs-1],
          plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs-1,initial_epochs-1],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

 

要約 :

  • 特徴抽出のために事前訓練されたモデルを使用する: 小さいデータセットで作業するとき、同じドメインのより巨大なデータセット上で訓練されたモデルにより学習された特徴を活用することは一般的です。これは事前訓練されたモデルをインスタンス化して上に完全結合分類器を追加することにより成されます。事前訓練されたモデルは「凍結」されて訓練の間分類器の重みだけが更新されます。この場合、畳み込みベースが各画像に関連する総ての特徴を抽出してそして抽出された特徴のこれらのセットが与えられたとき画像クラスを決定する分類器を単に訓練します。
  • 事前訓練されたモデルを再調整する: パフォーマンスを更に改良するために、事前訓練されたモデルの top-level 層を再調整を通して新しいデータセットに再目的化することを望むかもしれません。この場合、貴方のモデルがデータセットに固有の高位な特徴を学習するように重みを調整します。このテクニックは通常は、訓練データセットが巨大で (事前訓練されたモデルがその上で訓練された) 元のデータセットに非常に類似しているときに限り推奨されます。
 

以上



AI導入支援 #2 ウェビナー

スモールスタートを可能としたAI導入支援   Vol.2
[無料 WEB セミナー] [詳細]
「画像認識 AI PoC スターターパック」の紹介
既に AI 技術を実ビジネスで活用し、成果を上げている日本企業も多く存在しており、競争優位なビジネスを展開しております。
しかしながら AI を導入したくとも PoC (概念実証) だけでも高額な費用がかかり取組めていない企業も少なくないようです。A I導入時には欠かせない PoC を手軽にしかも短期間で認知度を確認可能とするサービの紹介と共に、AI 技術の特性と具体的な導入プロセスに加え運用時のポイントについても解説いたします。
日時:2021年10月13日(水)
会場:WEBセミナー
共催:クラスキャット、日本FLOW(株)
後援:働き方改革推進コンソーシアム
参加費: 無料 (事前登録制)
人工知能開発支援
◆ クラスキャットは 人工知能研究開発支援 サービスを提供しています :
  • テクニカルコンサルティングサービス
  • 実証実験 (プロトタイプ構築)
  • アプリケーションへの実装
  • 人工知能研修サービス
◆ お問合せ先 ◆
(株)クラスキャット
セールス・インフォメーション
E-Mail:sales-info@classcat.com