Keras 2 : examples : Efficient Sub-Pixel CNN を使用した画像超解像 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 12/18/2021 (keras 2.7.0)
* 本ページは、Keras の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
- Code examples : Computer Vision : Image Super-Resolution using an Efficient Sub-Pixel CNN (Author: Xingyu Long)
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

-  人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
 
-  人工知能研修サービス
 
- PoC(概念実証)を失敗させないための支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Web: www.classcat.com ; ClassCatJP
 

Keras 2 : examples : Efficient Sub-Pixel CNN を使用した画像超解像
Description: Efficient sub-pixel モデルを使用して BSDS500 上で超解像を実装する。
イントロダクション
Shi, 2016 により提案された ESPCN (Efficient Sub-Pixel CNN) は、画像の低解像度バージョンが与えられたときに高解像度バージョンを再構築するモデルです。それは効率的な「サブピクセル畳み込み」層を活用し、これは画像のアップスケーリング・フィルタの配列を学習します。
このコードサンプルでは、論文からのモデルを実装して小さいデータセット, BSDS500 で訓練します。
セットアップ
import tensorflow as tf
import os
import math
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.preprocessing.image import array_to_img
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.preprocessing import image_dataset_from_directory
from IPython.display import display
データ: BSDS500 のロード
データセットのダウンロード
データセットを取得するために組込みの keras.utils.get_file ユティリティを使用します。
dataset_url = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
data_dir = keras.utils.get_file(origin=dataset_url, fname="BSR", untar=True)
root_dir = os.path.join(data_dir, "BSDS500/data")
image_dataset_from_directory を通して訓練と検証データセットを作成します。
crop_size = 300
upscale_factor = 3
input_size = crop_size // upscale_factor
batch_size = 8
train_ds = image_dataset_from_directory(
    root_dir,
    batch_size=batch_size,
    image_size=(crop_size, crop_size),
    validation_split=0.2,
    subset="training",
    seed=1337,
    label_mode=None,
)
valid_ds = image_dataset_from_directory(
    root_dir,
    batch_size=batch_size,
    image_size=(crop_size, crop_size),
    validation_split=0.2,
    subset="validation",
    seed=1337,
    label_mode=None,
)
Found 500 files belonging to 2 classes. Using 400 files for training. Found 500 files belonging to 2 classes. Using 100 files for validation.
画像を範囲 [0, 1] の値を取るようにリスケールします。
def scaling(input_image):
    input_image = input_image / 255.0
    return input_image
# Scale from (0, 255) to (0, 1)
train_ds = train_ds.map(scaling)
valid_ds = valid_ds.map(scaling)
幾つかサンプル画像を可視化しましょう :
for batch in train_ds.take(1):
    for img in batch:
        display(array_to_img(img))








このサンプルの最後に視覚評価のために使用するテスト画像のデータセットのパスを準備します。
dataset = os.path.join(root_dir, "images")
test_path = os.path.join(dataset, "test")
test_img_paths = sorted(
    [
        os.path.join(test_path, fname)
        for fname in os.listdir(test_path)
        if fname.endswith(".jpg")
    ]
)
画像のクロップとリサイズ
画像データを加工処理しましょう。最初に、画像を RGB カラー空間から YUV カラー空間に変換します。
入力データ (低解像度画像) に対して、画像をクロップし、y チャネル (輝度) を取得し、そしてそれを面積法 (PIL を使用する場合 BICUBIC を使用) でリサイズします。人間は輝度の変化により敏感ですので、YUV カラー空間の輝度チャネルだけを考えます。
ターゲットデータ (高解像度画像) については、画像を単にクロップして y チャネルを取得します。
# Use TF Ops to process.
def process_input(input, input_size, upscale_factor):
    input = tf.image.rgb_to_yuv(input)
    last_dimension_axis = len(input.shape) - 1
    y, u, v = tf.split(input, 3, axis=last_dimension_axis)
    return tf.image.resize(y, [input_size, input_size], method="area")
def process_target(input):
    input = tf.image.rgb_to_yuv(input)
    last_dimension_axis = len(input.shape) - 1
    y, u, v = tf.split(input, 3, axis=last_dimension_axis)
    return y
train_ds = train_ds.map(
    lambda x: (process_input(x, input_size, upscale_factor), process_target(x))
)
train_ds = train_ds.prefetch(buffer_size=32)
valid_ds = valid_ds.map(
    lambda x: (process_input(x, input_size, upscale_factor), process_target(x))
)
valid_ds = valid_ds.prefetch(buffer_size=32)
入力とターゲットデータを見てみましょう。
for batch in train_ds.take(1):
    for img in batch[0]:
        display(array_to_img(img))
    for img in batch[1]:
        display(array_to_img(img))
















モデルの構築
論文と比較して、もう一つの層を追加して tanh の代わりに relu 活性化関数を使用しています。それはモデルをより少ないエポック訓練してさえもより良いパフォーマンスを実現します。
def get_model(upscale_factor=3, channels=1):
    conv_args = {
        "activation": "relu",
        "kernel_initializer": "Orthogonal",
        "padding": "same",
    }
    inputs = keras.Input(shape=(None, None, channels))
    x = layers.Conv2D(64, 5, **conv_args)(inputs)
    x = layers.Conv2D(64, 3, **conv_args)(x)
    x = layers.Conv2D(32, 3, **conv_args)(x)
    x = layers.Conv2D(channels * (upscale_factor ** 2), 3, **conv_args)(x)
    outputs = tf.nn.depth_to_space(x, upscale_factor)
    return keras.Model(inputs, outputs)
ユティリティ関数の定義
結果を監視するために幾つかのユティリティ関数を定義する必要があります :
- plot_results : 画像をセーブしてプロットします。
- get_lowres_image : 画像をその低解像度バージョンに変換します。
- upscale_image : 低解像度画像をモデルにより再構築された高解像度画像に変えます。この関数では、モデルへの入力として YUV カラー空間からの y チャネルを使用して、それから RGB 画像を得るために出力を他のチャネルと結合します。
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
import PIL
def plot_results(img, prefix, title):
    """Plot the result with zoom-in area."""
    img_array = img_to_array(img)
    img_array = img_array.astype("float32") / 255.0
    # Create a new figure with a default 111 subplot.
    fig, ax = plt.subplots()
    im = ax.imshow(img_array[::-1], origin="lower")
    plt.title(title)
    # zoom-factor: 2.0, location: upper-left
    axins = zoomed_inset_axes(ax, 2, loc=2)
    axins.imshow(img_array[::-1], origin="lower")
    # Specify the limits.
    x1, x2, y1, y2 = 200, 300, 100, 200
    # Apply the x-limits.
    axins.set_xlim(x1, x2)
    # Apply the y-limits.
    axins.set_ylim(y1, y2)
    plt.yticks(visible=False)
    plt.xticks(visible=False)
    # Make the line.
    mark_inset(ax, axins, loc1=1, loc2=3, fc="none", ec="blue")
    plt.savefig(str(prefix) + "-" + title + ".png")
    plt.show()
def get_lowres_image(img, upscale_factor):
    """Return low-resolution image to use as model input."""
    return img.resize(
        (img.size[0] // upscale_factor, img.size[1] // upscale_factor),
        PIL.Image.BICUBIC,
    )
def upscale_image(model, img):
    """Predict the result based on input image and restore the image as RGB."""
    ycbcr = img.convert("YCbCr")
    y, cb, cr = ycbcr.split()
    y = img_to_array(y)
    y = y.astype("float32") / 255.0
    input = np.expand_dims(y, axis=0)
    out = model.predict(input)
    out_img_y = out[0]
    out_img_y *= 255.0
    # Restore the image in RGB color space.
    out_img_y = out_img_y.clip(0, 255)
    out_img_y = out_img_y.reshape((np.shape(out_img_y)[0], np.shape(out_img_y)[1]))
    out_img_y = PIL.Image.fromarray(np.uint8(out_img_y), mode="L")
    out_img_cb = cb.resize(out_img_y.size, PIL.Image.BICUBIC)
    out_img_cr = cr.resize(out_img_y.size, PIL.Image.BICUBIC)
    out_img = PIL.Image.merge("YCbCr", (out_img_y, out_img_cb, out_img_cr)).convert(
        "RGB"
    )
    return out_img
訓練を監視するためにコールバックを定義する
ESPCNCallback オブジェクトは PSNR メトリックを計算して表示します。これは超解像度パフォーマンスを評価するために使用する主要なメトリックです。
class ESPCNCallback(keras.callbacks.Callback):
    def __init__(self):
        super(ESPCNCallback, self).__init__()
        self.test_img = get_lowres_image(load_img(test_img_paths[0]), upscale_factor)
    # Store PSNR value in each epoch.
    def on_epoch_begin(self, epoch, logs=None):
        self.psnr = []
    def on_epoch_end(self, epoch, logs=None):
        print("Mean PSNR for epoch: %.2f" % (np.mean(self.psnr)))
        if epoch % 20 == 0:
            prediction = upscale_image(self.model, self.test_img)
            plot_results(prediction, "epoch-" + str(epoch), "prediction")
    def on_test_batch_end(self, batch, logs=None):
        self.psnr.append(10 * math.log10(1 / logs["loss"]))
ModelCheckpoint と EarlyStopping コールバックを定義します。
early_stopping_callback = keras.callbacks.EarlyStopping(monitor="loss", patience=10)
checkpoint_filepath = "/tmp/checkpoint"
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor="loss",
    mode="min",
    save_best_only=True,
)
model = get_model(upscale_factor=upscale_factor, channels=1)
model.summary()
callbacks = [ESPCNCallback(), early_stopping_callback, model_checkpoint_callback]
loss_fn = keras.losses.MeanSquaredError()
optimizer = keras.optimizers.Adam(learning_rate=0.001)
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, None, None, 1)] 0 _________________________________________________________________ conv2d (Conv2D) (None, None, None, 64) 1664 _________________________________________________________________ conv2d_1 (Conv2D) (None, None, None, 64) 36928 _________________________________________________________________ conv2d_2 (Conv2D) (None, None, None, 32) 18464 _________________________________________________________________ conv2d_3 (Conv2D) (None, None, None, 9) 2601 _________________________________________________________________ tf.nn.depth_to_space (TFOpLa (None, None, None, 1) 0 ================================================================= Total params: 59,657 Trainable params: 59,657 Non-trainable params: 0 _________________________________________________________________
モデルの訓練
epochs = 100
model.compile(
    optimizer=optimizer, loss=loss_fn,
)
model.fit(
    train_ds, epochs=epochs, callbacks=callbacks, validation_data=valid_ds, verbose=2
)
# The model weights (that are considered the best) are loaded into the model.
model.load_weights(checkpoint_filepath)
(訳者注: ログが長いので以下は実験結果のみを記載します)
Epoch 1/100 Mean PSNR for epoch: 22.1550/50 - 12s - loss: 0.0311 - val_loss: 0.0061 - 12s/epoch - 234ms/step
Epoch 21/100 Mean PSNR for epoch: 26.5350/50 - 3s - loss: 0.0026 - val_loss: 0.0023 - 3s/epoch - 56ms/step
Epoch 41/100 Mean PSNR for epoch: 26.2650/50 - 3s - loss: 0.0026 - val_loss: 0.0024 - 3s/epoch - 53ms/step
Epoch 61/100 Mean PSNR for epoch: 26.7750/50 - 2s - loss: 0.0025 - val_loss: 0.0022 - 2s/epoch - 49ms/step
Epoch 81/100 Mean PSNR for epoch: 26.9150/50 - 3s - loss: 0.0025 - val_loss: 0.0022 - 3s/epoch - 51ms/step
Epoch 100/100 Mean PSNR for epoch: 27.14 50/50 - 2s - loss: 0.0024 - val_loss: 0.0022 - 2s/epoch - 37ms/step CPU times: user 5min 26s, sys: 16.5 s, total: 5min 43s Wall time: 4min
モデル予測の実行と結果のプロット
Let’s compute the reconstructed version of a few images and save the results.
total_bicubic_psnr = 0.0
total_test_psnr = 0.0
for index, test_img_path in enumerate(test_img_paths[50:60]):
    img = load_img(test_img_path)
    lowres_input = get_lowres_image(img, upscale_factor)
    w = lowres_input.size[0] * upscale_factor
    h = lowres_input.size[1] * upscale_factor
    highres_img = img.resize((w, h))
    prediction = upscale_image(model, lowres_input)
    lowres_img = lowres_input.resize((w, h))
    lowres_img_arr = img_to_array(lowres_img)
    highres_img_arr = img_to_array(highres_img)
    predict_img_arr = img_to_array(prediction)
    bicubic_psnr = tf.image.psnr(lowres_img_arr, highres_img_arr, max_val=255)
    test_psnr = tf.image.psnr(predict_img_arr, highres_img_arr, max_val=255)
    total_bicubic_psnr += bicubic_psnr
    total_test_psnr += test_psnr
    print(
        "PSNR of low resolution image and high resolution image is %.4f" % bicubic_psnr
    )
    print("PSNR of predict and high resolution is %.4f" % test_psnr)
    plot_results(lowres_img, index, "lowres")
    plot_results(highres_img, index, "highres")
    plot_results(prediction, index, "prediction")
print("Avg. PSNR of lowres images is %.4f" % (total_bicubic_psnr / 10))
print("Avg. PSNR of reconstructions is %.4f" % (total_test_psnr / 10))
PSNR of low resolution image and high resolution image is 29.8502 PSNR of predict and high resolution is 30.4603


PSNR of low resolution image and high resolution image is 24.9783 PSNR of predict and high resolution is 26.0037


PSNR of low resolution image and high resolution image is 28.0314 PSNR of predict and high resolution is 28.3032


PSNR of low resolution image and high resolution image is 25.7630 PSNR of predict and high resolution is 26.3621


PSNR of low resolution image and high resolution image is 26.2512 PSNR of predict and high resolution is 27.1774


以上