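Keras 2 : examples : Pneumonia Classification on TPU (translation / commentary)
Translation: ClassCat Co., Ltd. Sales Information
Created: 12/28/2021 (keras 2.7.0)
* This page is a translation, with supplementary explanations where appropriate, of the following Keras document:
- Code examples : Computer Vision : Pneumonia Classification on TPU (Author: Amy MiHyun Jang)
* The sample code has been verified to run, with additions and modifications made where necessary.
* Feel free to link to this page; we would appreciate a note at sales-info@classcat.com.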
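Keras 2 : examples : Pneumonia Classification on TPU
Description: Medical image classification on TPU.
Introduction + Setup
This tutorial explains how to build an X-ray image classification model that predicts whether an X-ray scan shows the presence of pneumonia.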
import re
import os
import random
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

try:
    # Connect to the TPU cluster if one is available.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    print("Device:", tpu.master())
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    # No TPU was found: fall back to the default strategy (CPU or GPU).
    strategy = tf.distribute.get_strategy()
print("Number of replicas:", strategy.num_replicas_in_sync)
Device: grpc://10.0.27.122:8470
INFO:tensorflow:Initializing the TPU system: grpc://10.0.27.122:8470
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Finished initializing TPU system.
WARNING:absl:`tf.distribute.experimental.TPUStrategy` is deprecated, please use the non experimental symbol `tf.distribute.TPUStrategy` instead.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
Number of replicas: 8
INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.
INFO:tensorflow:Initializing the TPU system: grpc://10.10.81.42:8470
INFO:tensorflow:Finished initializing TPU system.
Device: grpc://10.10.81.42:8470
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
Number of replicas: 8
To load data when using a TPU, we need a Google Cloud link to the data. Below, we define the key configuration parameters used in this example. To run on TPU, this example must be run on Colab with the TPU runtime selected.
AUTOTUNE = tf.data.AUTOTUNE
BATCH_SIZE = 25 * strategy.num_replicas_in_sync
IMAGE_SIZE = [180, 180]
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
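The batch size above scales with the number of replicas, so it is a global batch size. As a rough sketch of the arithmetic (assuming the 8 replicas reported in the log and the 4,200 training examples taken later in this example), the number of steps per epoch can be worked out in plain Python. The variable names here are ours, not part of the original example.

```python
import math

# Values taken from this example; the names are illustrative only.
PER_REPLICA_BATCH = 25   # the factor in BATCH_SIZE = 25 * num_replicas_in_sync
NUM_REPLICAS = 8         # "Number of replicas: 8" in the setup log
TRAIN_EXAMPLES = 4200    # size of the training split taken later on

# Each training step consumes one global batch spread across all replicas.
global_batch = PER_REPLICA_BATCH * NUM_REPLICAS
steps_per_epoch = math.ceil(TRAIN_EXAMPLES / global_batch)

print("Global batch size:", global_batch)
print("Steps per epoch:", steps_per_epoch)
```

With these numbers the result agrees with the `21/21` step counter shown in the training log of this example.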
Load the data
The Chest X-ray data we are using from Cell divides the data into training and test files. Let's first load in the training TFRecords.
train_images = tf.data.TFRecordDataset(
    "gs://download.tensorflow.org/data/ChestXRay2017/train/images.tfrec"
)
train_paths = tf.data.TFRecordDataset(
    "gs://download.tensorflow.org/data/ChestXRay2017/train/paths.tfrec"
)

ds = tf.data.Dataset.zip((train_images, train_paths))
Let's count how many healthy/normal chest X-rays and how many pneumonia chest X-rays we have:
COUNT_NORMAL = len(
    [
        filename
        for filename in train_paths
        if "NORMAL" in filename.numpy().decode("utf-8")
    ]
)
print("Normal images count in training set: " + str(COUNT_NORMAL))

COUNT_PNEUMONIA = len(
    [
        filename
        for filename in train_paths
        if "PNEUMONIA" in filename.numpy().decode("utf-8")
    ]
)
print("Pneumonia images count in training set: " + str(COUNT_PNEUMONIA))
Normal images count in training set: 1349
Pneumonia images count in training set: 3883
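Notice that far more images are classified as pneumonia than as normal. This shows that our data is imbalanced. We will correct for this imbalance later in the notebook.
We want to map each filename to the corresponding (image, label) pair. The following methods help us do that.
As we only have two labels, we encode the label so that 1 or True indicates pneumonia and 0 or False indicates normal.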
def get_label(file_path):
    # convert the path to a list of path components
    parts = tf.strings.split(file_path, "/")
    # The second to last is the class-directory
    return parts[-2] == "PNEUMONIA"


def decode_img(img):
    # convert the compressed string to a 3D uint8 tensor
    img = tf.image.decode_jpeg(img, channels=3)
    # resize the image to the desired size.
    return tf.image.resize(img, IMAGE_SIZE)


def process_path(image, path):
    label = get_label(path)
    # decode the raw JPEG bytes into a resized image tensor
    img = decode_img(image)
    return img, label


ds = ds.map(process_path, num_parallel_calls=AUTOTUNE)
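To make the labelling rule concrete, here is a minimal plain-Python analogue of `get_label`. It assumes, as the function above does, that the second-to-last path component is the class directory. The function name `get_label_py` and the example paths are hypothetical and only serve as illustration.

```python
def get_label_py(file_path: str) -> bool:
    """Plain-Python analogue of get_label: True means pneumonia."""
    parts = file_path.split("/")
    # The second-to-last component is assumed to be the class directory.
    return parts[-2] == "PNEUMONIA"


# Hypothetical paths, made up for illustration only.
example_paths = [
    "data/train/NORMAL/img_0001.jpeg",
    "data/train/PNEUMONIA/img_0002.jpeg",
]
for path in example_paths:
    print(path, "->", int(get_label_py(path)))
```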
Let's split the data into a training dataset and a validation dataset.
ds = ds.shuffle(10000)
train_ds = ds.take(4200)
val_ds = ds.skip(4200)
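The split sizes follow from the counts printed earlier. The sketch below mimics shuffle, take and skip on a plain list of indices, shuffling once so that the two parts cannot overlap. One caveat, stated as our reading of `tf.data` rather than something the original example discusses: `Dataset.shuffle` reshuffles on every iteration by default, so `take` and `skip` are only guaranteed to be disjoint when the order is fixed, for example with `reshuffle_each_iteration=False`.

```python
import random

TOTAL = 1349 + 3883   # normal + pneumonia counts printed earlier
TRAIN_SIZE = 4200

indices = list(range(TOTAL))
random.Random(0).shuffle(indices)   # shuffle once; the seed 0 is arbitrary

train_idx = indices[:TRAIN_SIZE]    # analogue of ds.take(4200)
val_idx = indices[TRAIN_SIZE:]      # analogue of ds.skip(4200)

print("train:", len(train_idx), "val:", len(val_idx))
print("overlap:", len(set(train_idx) & set(val_idx)))
```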
Let's print the shape of an (image, label) pair.
for image, label in train_ds.take(1):
    print("Image shape: ", image.numpy().shape)
    print("Label: ", label.numpy())
Image shape: (180, 180, 3)
Label: False
Load and format the test data as well.
test_images = tf.data.TFRecordDataset(
    "gs://download.tensorflow.org/data/ChestXRay2017/test/images.tfrec"
)
test_paths = tf.data.TFRecordDataset(
    "gs://download.tensorflow.org/data/ChestXRay2017/test/paths.tfrec"
)
test_ds = tf.data.Dataset.zip((test_images, test_paths))

test_ds = test_ds.map(process_path, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.batch(BATCH_SIZE)
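Visualize the dataset
First, let's use buffered prefetching so that we can yield data from disk without I/O becoming blocking.
Note that large image datasets should not be cached in memory. We do it here because the dataset is not very large and we want to train on TPU.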
def prepare_for_training(ds, cache=True):
    # This is a small dataset, only load it once, and keep it in memory.
    # use `.cache(filename)` to cache preprocessing work for datasets that don't
    # fit in memory.
    if cache:
        if isinstance(cache, str):
            ds = ds.cache(cache)
        else:
            ds = ds.cache()

    ds = ds.batch(BATCH_SIZE)

    # `prefetch` lets the dataset fetch batches in the background while the model
    # is training.
    ds = ds.prefetch(buffer_size=AUTOTUNE)

    return ds
Prepare both datasets and fetch the next batch of the training data.
train_ds = prepare_for_training(train_ds)
val_ds = prepare_for_training(val_ds)
image_batch, label_batch = next(iter(train_ds))
Define the method that shows the images in a batch.
def show_batch(image_batch, label_batch):
    plt.figure(figsize=(10, 10))
    for n in range(25):
        ax = plt.subplot(5, 5, n + 1)
        plt.imshow(image_batch[n] / 255)
        if label_batch[n]:
            plt.title("PNEUMONIA")
        else:
            plt.title("NORMAL")
        plt.axis("off")
As the method takes NumPy arrays as its parameters, call the numpy function on the batches to return the tensors in NumPy array form.
show_batch(image_batch.numpy(), label_batch.numpy())
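Build the CNN
To make our model more modular and easier to understand, let's define some blocks. As we are building a convolutional neural network, we create a convolution block and a dense layer block.
The architecture for this CNN was inspired by this article.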
from tensorflow import keras
from tensorflow.keras import layers


def conv_block(filters, inputs):
    x = layers.SeparableConv2D(filters, 3, activation="relu", padding="same")(inputs)
    x = layers.SeparableConv2D(filters, 3, activation="relu", padding="same")(x)
    x = layers.BatchNormalization()(x)
    outputs = layers.MaxPool2D()(x)

    return outputs


def dense_block(units, dropout_rate, inputs):
    x = layers.Dense(units, activation="relu")(inputs)
    x = layers.BatchNormalization()(x)
    outputs = layers.Dropout(dropout_rate)(x)

    return outputs
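The following method defines the function that builds our model.
The images originally have values in the range [0, 255]. CNNs work better with smaller numbers, so we scale the input down.
The Dropout layers are important, as they reduce the likelihood of the model overfitting. We want to end the model with a Dense layer with one node, as this is the binary output that determines whether an X-ray shows the presence of pneumonia.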
def build_model():
    inputs = keras.Input(shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3))
    x = layers.Rescaling(1.0 / 255)(inputs)
    x = layers.Conv2D(16, 3, activation="relu", padding="same")(x)
    x = layers.Conv2D(16, 3, activation="relu", padding="same")(x)
    x = layers.MaxPool2D()(x)

    x = conv_block(32, x)
    x = conv_block(64, x)

    x = conv_block(128, x)
    x = layers.Dropout(0.2)(x)

    x = conv_block(256, x)
    x = layers.Dropout(0.2)(x)

    x = layers.Flatten()(x)
    x = dense_block(512, 0.7, x)
    x = dense_block(128, 0.5, x)
    x = dense_block(64, 0.3, x)

    outputs = layers.Dense(1, activation="sigmoid")(x)

    model = keras.Model(inputs=inputs, outputs=outputs)
    return model
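As a sanity check on the architecture, the spatial size of the feature maps can be traced by hand. The sketch below assumes the default `MaxPool2D` settings (2x2 window, stride 2, no padding), which halve each side and round down, and relies on every convolution using `padding="same"`, which leaves the spatial size unchanged. The helper name `after_pool` is ours.

```python
def after_pool(size: int) -> int:
    # A 2x2 max pool with stride 2 and no padding halves the side, rounding down.
    return size // 2


size = 180        # IMAGE_SIZE is [180, 180]
sizes = [size]
# One pool after the first Conv2D pair, then one pool in each of four conv_blocks.
for _ in range(5):
    size = after_pool(size)
    sizes.append(size)

last_filters = 256                    # filters in the final conv_block
flat_features = size * size * last_filters

print("Side length after each pool:", sizes)
print("Features entering the first dense block:", flat_features)
```

Under these assumptions the side length goes 180, 90, 45, 22, 11, 5, so `Flatten` would hand 5 * 5 * 256 = 6400 features to the first dense block.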
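Correct for data imbalance
We saw earlier in this example that the data is imbalanced, with more images classified as pneumonia than as normal. We correct for that by using class weighting: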
initial_bias = np.log([COUNT_PNEUMONIA / COUNT_NORMAL])
print("Initial bias: {:.5f}".format(initial_bias[0]))
TRAIN_IMG_COUNT = COUNT_NORMAL + COUNT_PNEUMONIA
weight_for_0 = (1 / COUNT_NORMAL) * (TRAIN_IMG_COUNT) / 2.0
weight_for_1 = (1 / COUNT_PNEUMONIA) * (TRAIN_IMG_COUNT) / 2.0
class_weight = {0: weight_for_0, 1: weight_for_1}
print("Weight for class 0: {:.2f}".format(weight_for_0))
print("Weight for class 1: {:.2f}".format(weight_for_1))
Initial bias: 1.05724
Weight for class 0: 1.94
Weight for class 1: 0.67
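The weight for class 0 (Normal) is a lot higher than the weight for class 1 (Pneumonia). Because there are fewer normal images, each normal image is weighted more heavily to balance the data, as the CNN works best when the training data is balanced.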
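A quick way to see what these weights do is to multiply each weight by its class count: both classes then contribute the same weighted total, namely half of the dataset size. The sketch below redoes the arithmetic with the standard library only. It also shows, as our own observation rather than a step in the original example, that the sigmoid of the printed initial bias is simply the fraction of pneumonia images.

```python
import math

COUNT_NORMAL = 1349
COUNT_PNEUMONIA = 3883
total = COUNT_NORMAL + COUNT_PNEUMONIA

weight_for_0 = total / (2.0 * COUNT_NORMAL)
weight_for_1 = total / (2.0 * COUNT_PNEUMONIA)

# Weighted contribution of each class: both should be close to total / 2.
weighted_normal = weight_for_0 * COUNT_NORMAL
weighted_pneumonia = weight_for_1 * COUNT_PNEUMONIA
print("Weighted totals:", weighted_normal, weighted_pneumonia)

# The log-odds bias and the base rate it corresponds to under a sigmoid.
initial_bias = math.log(COUNT_PNEUMONIA / COUNT_NORMAL)
base_rate = 1.0 / (1.0 + math.exp(-initial_bias))
print("Initial bias: {:.5f}".format(initial_bias))
print("Pneumonia base rate: {:.4f}".format(base_rate))
```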
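Train the model
Defining callbacks
The checkpoint callback saves the best weights of the model, so the next time we want to use the model we do not have to spend time training it. The early stopping callback stops the training process when the model starts to stagnate or, worse, starts to overfit.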
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint("xray_model.h5", save_best_only=True)

early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    patience=10, restore_best_weights=True
)
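The effect of `patience=10` can be illustrated with a simplified patience rule in plain Python. This is a sketch of the idea, not the Keras implementation: it stops once the validation loss has failed to improve for `patience` consecutive epochs. The `val_loss` values are copied from the second training log shown later in this example; the function name is ours.

```python
def early_stop_epoch(val_losses, patience):
    """Return (stop_epoch, best_epoch), both 1-indexed, for a simple patience rule."""
    best = float("inf")
    best_epoch = 0
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            # New best value: remember it and reset the counter.
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_losses), best_epoch


# val_loss per epoch, copied from the second training log in this example.
val_losses = [
    0.8867, 33.7673, 7.7650, 1.0647, 0.8275, 1.7284,
    0.1681, 0.1273, 0.2206, 0.6490, 3.0831, 6.0620,
    1.0773, 0.7277, 1.0204, 1.1532, 1.8251, 6.3741,
]
stop_epoch, best_epoch = early_stop_epoch(val_losses, patience=10)
print("Best epoch:", best_epoch, "stopped after epoch:", stop_epoch)
```

Under this simplified rule the best epoch is 8 and training stops after epoch 18, which matches the length of that log. With `restore_best_weights=True`, the weights from the best epoch are the ones kept.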
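We also want to tune our learning rate. A learning rate that is too high causes the model to diverge. A learning rate that is too small makes training too slow. Below we implement the exponential learning rate scheduling method.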
initial_learning_rate = 0.015
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True
)
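Fit the model
For our metrics, we want to include precision and recall, as they give us a more informed picture of how good our model is. Accuracy tells us what fraction of the labels is correct. Since our data is imbalanced, accuracy might give a skewed sense of a good model (i.e. a model that always predicts PNEUMONIA will be 74% accurate but is not a good model).
Precision is the number of true positives (TP) over the sum of TP and false positives (FP). It shows what fraction of labeled positives are actually correct.
Recall is the number of TP over the sum of TP and false negatives (FN). It shows what fraction of actual positives are correctly identified.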
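These definitions can be checked on the degenerate case mentioned above, a model that always predicts PNEUMONIA. The sketch below uses the training-set counts printed earlier (1349 normal, 3883 pneumonia); the helper names are ours.

```python
def precision(tp: int, fp: int) -> float:
    return tp / (tp + fp)


def recall(tp: int, fn: int) -> float:
    return tp / (tp + fn)


def accuracy(tp: int, tn: int, fp: int, fn: int) -> float:
    return (tp + tn) / (tp + tn + fp + fn)


# A classifier that labels every image PNEUMONIA:
tp = 3883   # every pneumonia image is caught
fp = 1349   # every normal image is a false alarm
tn = 0
fn = 0

print("accuracy:  %.4f" % accuracy(tp, tn, fp, fn))
print("precision: %.4f" % precision(tp, fp))
print("recall:    %.4f" % recall(tp, fn))
```

Accuracy and precision both come out at about 0.74 while recall is 1.0, which is why accuracy alone is not enough to judge this model.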
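Since there are only two possible labels for an image, we use the binary crossentropy loss. When we fit the model, remember to specify the class weights defined earlier. Because we are using a TPU, training is quick: it takes less than 2 minutes.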
with strategy.scope():
    model = build_model()

    METRICS = [
        tf.keras.metrics.BinaryAccuracy(),
        tf.keras.metrics.Precision(name="precision"),
        tf.keras.metrics.Recall(name="recall"),
    ]
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
        loss="binary_crossentropy",
        metrics=METRICS,
    )

history = model.fit(
    train_ds,
    epochs=100,
    validation_data=val_ds,
    class_weight=class_weight,
    callbacks=[checkpoint_cb, early_stopping_cb],
)
Epoch 1/100
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
21/21 [==============================] - 12s 568ms/step - loss: 0.5857 - binary_accuracy: 0.6960 - precision: 0.8887 - recall: 0.6733 - val_loss: 34.0149 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000
Epoch 2/100
21/21 [==============================] - 3s 128ms/step - loss: 0.2916 - binary_accuracy: 0.8755 - precision: 0.9540 - recall: 0.8738 - val_loss: 97.5194 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000
Epoch 3/100
21/21 [==============================] - 4s 167ms/step - loss: 0.2384 - binary_accuracy: 0.9002 - precision: 0.9663 - recall: 0.8964 - val_loss: 27.7902 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000
Epoch 4/100
21/21 [==============================] - 4s 173ms/step - loss: 0.2046 - binary_accuracy: 0.9145 - precision: 0.9725 - recall: 0.9102 - val_loss: 10.8302 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000
Epoch 5/100
21/21 [==============================] - 4s 174ms/step - loss: 0.1841 - binary_accuracy: 0.9279 - precision: 0.9733 - recall: 0.9279 - val_loss: 3.5860 - val_binary_accuracy: 0.7103 - val_precision: 0.7162 - val_recall: 0.9879
Epoch 6/100
21/21 [==============================] - 4s 185ms/step - loss: 0.1600 - binary_accuracy: 0.9362 - precision: 0.9791 - recall: 0.9337 - val_loss: 0.3014 - val_binary_accuracy: 0.8895 - val_precision: 0.8973 - val_recall: 0.9555
Epoch 7/100
21/21 [==============================] - 3s 130ms/step - loss: 0.1567 - binary_accuracy: 0.9393 - precision: 0.9798 - recall: 0.9372 - val_loss: 0.6763 - val_binary_accuracy: 0.7810 - val_precision: 0.7760 - val_recall: 0.9771
Epoch 8/100
21/21 [==============================] - 3s 131ms/step - loss: 0.1532 - binary_accuracy: 0.9421 - precision: 0.9825 - recall: 0.9385 - val_loss: 0.3169 - val_binary_accuracy: 0.8895 - val_precision: 0.8684 - val_recall: 0.9973
Epoch 9/100
21/21 [==============================] - 4s 184ms/step - loss: 0.1457 - binary_accuracy: 0.9431 - precision: 0.9822 - recall: 0.9401 - val_loss: 0.2064 - val_binary_accuracy: 0.9273 - val_precision: 0.9840 - val_recall: 0.9136
Epoch 10/100
21/21 [==============================] - 3s 132ms/step - loss: 0.1201 - binary_accuracy: 0.9521 - precision: 0.9869 - recall: 0.9479 - val_loss: 0.4364 - val_binary_accuracy: 0.8605 - val_precision: 0.8443 - val_recall: 0.9879
Epoch 11/100
21/21 [==============================] - 3s 127ms/step - loss: 0.1200 - binary_accuracy: 0.9510 - precision: 0.9863 - recall: 0.9469 - val_loss: 0.5197 - val_binary_accuracy: 0.8508 - val_precision: 1.0000 - val_recall: 0.7922
Epoch 12/100
21/21 [==============================] - 4s 186ms/step - loss: 0.1077 - binary_accuracy: 0.9581 - precision: 0.9870 - recall: 0.9559 - val_loss: 0.1349 - val_binary_accuracy: 0.9486 - val_precision: 0.9587 - val_recall: 0.9703
Epoch 13/100
21/21 [==============================] - 4s 173ms/step - loss: 0.0918 - binary_accuracy: 0.9650 - precision: 0.9914 - recall: 0.9611 - val_loss: 0.0926 - val_binary_accuracy: 0.9700 - val_precision: 0.9837 - val_recall: 0.9744
Epoch 14/100
21/21 [==============================] - 3s 130ms/step - loss: 0.0996 - binary_accuracy: 0.9612 - precision: 0.9913 - recall: 0.9559 - val_loss: 0.1811 - val_binary_accuracy: 0.9419 - val_precision: 0.9956 - val_recall: 0.9231
Epoch 15/100
21/21 [==============================] - 3s 129ms/step - loss: 0.0898 - binary_accuracy: 0.9643 - precision: 0.9901 - recall: 0.9614 - val_loss: 0.1525 - val_binary_accuracy: 0.9486 - val_precision: 0.9986 - val_recall: 0.9298
Epoch 16/100
21/21 [==============================] - 3s 128ms/step - loss: 0.0941 - binary_accuracy: 0.9621 - precision: 0.9904 - recall: 0.9582 - val_loss: 0.5101 - val_binary_accuracy: 0.8527 - val_precision: 1.0000 - val_recall: 0.7949
Epoch 17/100
21/21 [==============================] - 3s 125ms/step - loss: 0.0798 - binary_accuracy: 0.9636 - precision: 0.9897 - recall: 0.9607 - val_loss: 0.1239 - val_binary_accuracy: 0.9622 - val_precision: 0.9875 - val_recall: 0.9595
Epoch 18/100
21/21 [==============================] - 3s 126ms/step - loss: 0.0821 - binary_accuracy: 0.9657 - precision: 0.9911 - recall: 0.9623 - val_loss: 0.1597 - val_binary_accuracy: 0.9322 - val_precision: 0.9956 - val_recall: 0.9096
Epoch 19/100
21/21 [==============================] - 3s 143ms/step - loss: 0.0800 - binary_accuracy: 0.9657 - precision: 0.9917 - recall: 0.9617 - val_loss: 0.2538 - val_binary_accuracy: 0.9109 - val_precision: 1.0000 - val_recall: 0.8758
Epoch 20/100
21/21 [==============================] - 3s 127ms/step - loss: 0.0605 - binary_accuracy: 0.9738 - precision: 0.9950 - recall: 0.9694 - val_loss: 0.6594 - val_binary_accuracy: 0.8566 - val_precision: 1.0000 - val_recall: 0.8003
Epoch 21/100
21/21 [==============================] - 4s 167ms/step - loss: 0.0726 - binary_accuracy: 0.9733 - precision: 0.9937 - recall: 0.9701 - val_loss: 0.0593 - val_binary_accuracy: 0.9816 - val_precision: 0.9945 - val_recall: 0.9798
Epoch 22/100
21/21 [==============================] - 3s 126ms/step - loss: 0.0577 - binary_accuracy: 0.9783 - precision: 0.9951 - recall: 0.9755 - val_loss: 0.1087 - val_binary_accuracy: 0.9729 - val_precision: 0.9931 - val_recall: 0.9690
Epoch 23/100
21/21 [==============================] - 3s 125ms/step - loss: 0.0652 - binary_accuracy: 0.9729 - precision: 0.9924 - recall: 0.9707 - val_loss: 1.8465 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000
Epoch 24/100
21/21 [==============================] - 3s 124ms/step - loss: 0.0538 - binary_accuracy: 0.9783 - precision: 0.9951 - recall: 0.9755 - val_loss: 1.5769 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000
Epoch 25/100
21/21 [==============================] - 4s 167ms/step - loss: 0.0549 - binary_accuracy: 0.9776 - precision: 0.9954 - recall: 0.9743 - val_loss: 0.0590 - val_binary_accuracy: 0.9777 - val_precision: 0.9904 - val_recall: 0.9784
Epoch 26/100
21/21 [==============================] - 3s 131ms/step - loss: 0.0677 - binary_accuracy: 0.9719 - precision: 0.9924 - recall: 0.9694 - val_loss: 2.6008 - val_binary_accuracy: 0.6928 - val_precision: 0.9977 - val_recall: 0.5735
Epoch 27/100
21/21 [==============================] - 3s 127ms/step - loss: 0.0469 - binary_accuracy: 0.9833 - precision: 0.9971 - recall: 0.9804 - val_loss: 1.0184 - val_binary_accuracy: 0.8605 - val_precision: 0.9983 - val_recall: 0.8070
Epoch 28/100
21/21 [==============================] - 3s 126ms/step - loss: 0.0501 - binary_accuracy: 0.9790 - precision: 0.9961 - recall: 0.9755 - val_loss: 0.3737 - val_binary_accuracy: 0.9089 - val_precision: 0.9954 - val_recall: 0.8772
Epoch 29/100
21/21 [==============================] - 3s 128ms/step - loss: 0.0548 - binary_accuracy: 0.9798 - precision: 0.9941 - recall: 0.9784 - val_loss: 1.2928 - val_binary_accuracy: 0.7907 - val_precision: 1.0000 - val_recall: 0.7085
Epoch 30/100
21/21 [==============================] - 3s 129ms/step - loss: 0.0370 - binary_accuracy: 0.9860 - precision: 0.9980 - recall: 0.9829 - val_loss: 0.1370 - val_binary_accuracy: 0.9612 - val_precision: 0.9972 - val_recall: 0.9487
Epoch 31/100
21/21 [==============================] - 3s 125ms/step - loss: 0.0585 - binary_accuracy: 0.9819 - precision: 0.9951 - recall: 0.9804 - val_loss: 1.1955 - val_binary_accuracy: 0.6870 - val_precision: 0.9976 - val_recall: 0.5655
Epoch 32/100
21/21 [==============================] - 3s 140ms/step - loss: 0.0813 - binary_accuracy: 0.9695 - precision: 0.9934 - recall: 0.9652 - val_loss: 1.0394 - val_binary_accuracy: 0.8576 - val_precision: 0.9853 - val_recall: 0.8138
Epoch 33/100
21/21 [==============================] - 3s 128ms/step - loss: 0.1111 - binary_accuracy: 0.9555 - precision: 0.9870 - recall: 0.9524 - val_loss: 4.9438 - val_binary_accuracy: 0.5911 - val_precision: 1.0000 - val_recall: 0.4305
Epoch 34/100
21/21 [==============================] - 3s 130ms/step - loss: 0.0680 - binary_accuracy: 0.9726 - precision: 0.9921 - recall: 0.9707 - val_loss: 2.8822 - val_binary_accuracy: 0.7267 - val_precision: 0.9978 - val_recall: 0.6208
Epoch 35/100
21/21 [==============================] - 4s 187ms/step - loss: 0.0784 - binary_accuracy: 0.9712 - precision: 0.9892 - recall: 0.9717 - val_loss: 0.3940 - val_binary_accuracy: 0.9390 - val_precision: 0.9942 - val_recall: 0.9204
Epoch 1/100
WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/keras/engine/training.py:2970: StrategyBase.unwrap (from tensorflow.python.distribute.distribute_lib) is deprecated and will be removed in a future version.
Instructions for updating:
use `experimental_local_results` instead.
21/21 [==============================] - 48s 595ms/step - loss: 0.5426 - binary_accuracy: 0.7360 - precision: 0.8997 - recall: 0.7251 - val_loss: 0.8867 - val_binary_accuracy: 0.7297 - val_precision: 0.7302 - val_recall: 0.9960
Epoch 2/100
21/21 [==============================] - 3s 167ms/step - loss: 0.2512 - binary_accuracy: 0.9014 - precision: 0.9652 - recall: 0.8996 - val_loss: 33.7673 - val_binary_accuracy: 0.2733 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 3/100
21/21 [==============================] - 4s 174ms/step - loss: 0.1961 - binary_accuracy: 0.9240 - precision: 0.9739 - recall: 0.9224 - val_loss: 7.7650 - val_binary_accuracy: 0.2771 - val_precision: 1.0000 - val_recall: 0.0053
Epoch 4/100
21/21 [==============================] - 4s 179ms/step - loss: 0.1843 - binary_accuracy: 0.9269 - precision: 0.9776 - recall: 0.9227 - val_loss: 1.0647 - val_binary_accuracy: 0.6570 - val_precision: 0.9975 - val_recall: 0.5293
Epoch 5/100
21/21 [==============================] - 5s 225ms/step - loss: 0.1775 - binary_accuracy: 0.9276 - precision: 0.9779 - recall: 0.9233 - val_loss: 0.8275 - val_binary_accuracy: 0.6860 - val_precision: 1.0000 - val_recall: 0.5680
Epoch 6/100
21/21 [==============================] - 4s 195ms/step - loss: 0.1551 - binary_accuracy: 0.9390 - precision: 0.9799 - recall: 0.9371 - val_loss: 1.7284 - val_binary_accuracy: 0.5407 - val_precision: 0.9964 - val_recall: 0.3693
Epoch 7/100
21/21 [==============================] - 5s 226ms/step - loss: 0.1509 - binary_accuracy: 0.9398 - precision: 0.9831 - recall: 0.9349 - val_loss: 0.1681 - val_binary_accuracy: 0.9535 - val_precision: 0.9466 - val_recall: 0.9920
Epoch 8/100
21/21 [==============================] - 5s 225ms/step - loss: 0.1188 - binary_accuracy: 0.9517 - precision: 0.9873 - recall: 0.9471 - val_loss: 0.1273 - val_binary_accuracy: 0.9603 - val_precision: 0.9823 - val_recall: 0.9627
Epoch 9/100
21/21 [==============================] - 4s 174ms/step - loss: 0.1129 - binary_accuracy: 0.9550 - precision: 0.9874 - recall: 0.9516 - val_loss: 0.2206 - val_binary_accuracy: 0.9215 - val_precision: 0.9941 - val_recall: 0.8973
Epoch 10/100
21/21 [==============================] - 3s 167ms/step - loss: 0.1207 - binary_accuracy: 0.9517 - precision: 0.9873 - recall: 0.9471 - val_loss: 0.6490 - val_binary_accuracy: 0.8333 - val_precision: 0.9983 - val_recall: 0.7720
Epoch 11/100
21/21 [==============================] - 3s 167ms/step - loss: 0.1111 - binary_accuracy: 0.9536 - precision: 0.9870 - recall: 0.9500 - val_loss: 3.0831 - val_binary_accuracy: 0.5136 - val_precision: 1.0000 - val_recall: 0.3307
Epoch 12/100
21/21 [==============================] - 4s 170ms/step - loss: 0.1131 - binary_accuracy: 0.9593 - precision: 0.9884 - recall: 0.9564 - val_loss: 6.0620 - val_binary_accuracy: 0.3556 - val_precision: 1.0000 - val_recall: 0.1133
Epoch 13/100
21/21 [==============================] - 4s 202ms/step - loss: 0.1113 - binary_accuracy: 0.9533 - precision: 0.9880 - recall: 0.9487 - val_loss: 1.0773 - val_binary_accuracy: 0.7355 - val_precision: 1.0000 - val_recall: 0.6360
Epoch 14/100
21/21 [==============================] - 4s 176ms/step - loss: 0.0925 - binary_accuracy: 0.9586 - precision: 0.9907 - recall: 0.9532 - val_loss: 0.7277 - val_binary_accuracy: 0.8391 - val_precision: 0.9983 - val_recall: 0.7800
Epoch 15/100
21/21 [==============================] - 4s 191ms/step - loss: 0.0921 - binary_accuracy: 0.9602 - precision: 0.9907 - recall: 0.9554 - val_loss: 1.0204 - val_binary_accuracy: 0.7626 - val_precision: 1.0000 - val_recall: 0.6733
Epoch 16/100
21/21 [==============================] - 4s 183ms/step - loss: 0.0953 - binary_accuracy: 0.9581 - precision: 0.9910 - recall: 0.9522 - val_loss: 1.1532 - val_binary_accuracy: 0.7054 - val_precision: 0.9978 - val_recall: 0.5960
Epoch 17/100
21/21 [==============================] - 4s 172ms/step - loss: 0.0844 - binary_accuracy: 0.9648 - precision: 0.9920 - recall: 0.9602 - val_loss: 1.8251 - val_binary_accuracy: 0.6211 - val_precision: 1.0000 - val_recall: 0.4787
Epoch 18/100
21/21 [==============================] - 4s 198ms/step - loss: 0.0907 - binary_accuracy: 0.9662 - precision: 0.9940 - recall: 0.9602 - val_loss: 6.3741 - val_binary_accuracy: 0.6143 - val_precision: 0.9972 - val_recall: 0.4707
CPU times: user 53.9 s, sys: 5.78 s, total: 59.7 s
Wall time: 2min 18s
Visualizing model performance
Let's plot the model accuracy and loss for the training and validation sets. Note that no random seed is specified for this notebook, so your results may vary slightly.
fig, ax = plt.subplots(1, 4, figsize=(20, 3))
ax = ax.ravel()

for i, met in enumerate(["precision", "recall", "binary_accuracy", "loss"]):
    ax[i].plot(history.history[met])
    ax[i].plot(history.history["val_" + met])
    ax[i].set_title("Model {}".format(met))
    ax[i].set_xlabel("epochs")
    ax[i].set_ylabel(met)
    ax[i].legend(["train", "val"])
We see that the accuracy for our model is around 95%.
Predict and evaluate results
Let's evaluate the model on our test data!
model.evaluate(test_ds, return_dict=True)
4/4 [==============================] - 3s 708ms/step - loss: 0.9718 - binary_accuracy: 0.7901 - precision: 0.7524 - recall: 0.9897
{'binary_accuracy': 0.7900640964508057, 'loss': 0.9717951416969299, 'precision': 0.752436637878418, 'recall': 0.9897436499595642}

4/4 [==============================] - 7s 967ms/step - loss: 0.7763 - binary_accuracy: 0.8221 - precision: 0.7900 - recall: 0.9744
{'binary_accuracy': 0.822115421295166, 'loss': 0.7762584090232849, 'precision': 0.7900207042694092, 'recall': 0.9743589758872986}
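We see that the accuracy on our test data is lower than the accuracy on our validation set. This may indicate overfitting.
Our recall is greater than our precision, which indicates that almost all pneumonia images are correctly identified but some normal images are falsely identified as pneumonia. We should aim to increase our precision.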
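for image, label in test_ds.take(1):
    plt.imshow(image[0] / 255.0)
    # Cast the boolean label to int before using it as a list index.
    plt.title(CLASS_NAMES[int(label[0].numpy())])

# predict returns one sigmoid score per image; keep the first image's score as a float.
prediction = float(model.predict(test_ds.take(1))[0][0])
scores = [1 - prediction, prediction]

for score, name in zip(scores, CLASS_NAMES):
    print("This image is %.2f percent %s" % ((100 * score), name))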
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:3: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index
  This is separate from the ipykernel package so we can avoid doing imports until
This image is 47.19 percent NORMAL
This image is 52.81 percent PNEUMONIA

/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:3: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index
  This is separate from the ipykernel package so we can avoid doing imports until
This image is 68.40 percent NORMAL
This image is 31.60 percent PNEUMONIA
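The two percentages come from a single sigmoid output p: the model's score for PNEUMONIA is p and the score for NORMAL is 1 - p. A minimal sketch of that mapping follows, using 0.5281 as the score implied by the first output above. The function name `class_scores` is ours.

```python
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]


def class_scores(p_pneumonia: float) -> dict:
    """Map one sigmoid output to a percentage for each class."""
    scores = [1.0 - p_pneumonia, p_pneumonia]
    return {name: 100.0 * s for name, s in zip(CLASS_NAMES, scores)}


# 0.5281 is the pneumonia score implied by the first output above.
result = class_scores(0.5281)
for name, pct in result.items():
    print("This image is %.2f percent %s" % (pct, name))
```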