ホーム » Sonnet » Sonnet 2.0 : Tutorials : MLP で MNIST を予測する

Sonnet 2.0 : Tutorials : MLP で MNIST を予測する

Sonnet 2.0 : Tutorials : MLP で MNIST を予測する (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 06/10/2020

* 本ページは、Sonnet の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

MLP で MNIST を予測する

import sys
assert sys.version_info >= (3, 6), "Sonnet 2 requires Python >=3.6"
!pip install dm-sonnet tqdm
Requirement already satisfied: dm-sonnet==2.0.0b0 in /usr/local/lib/python3.6/dist-packages (2.0.0b0)
Requirement already satisfied: gast==0.2.2 in /usr/local/lib/python3.6/dist-packages (0.2.2)
Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (4.28.1)
Requirement already satisfied: wrapt>=1.11.1 in /tensorflow-2.0.0-rc1/python3.6 (from dm-sonnet==2.0.0b0) (1.11.2)
Requirement already satisfied: numpy>=1.16.3 in /tensorflow-2.0.0-rc1/python3.6 (from dm-sonnet==2.0.0b0) (1.17.2)
Requirement already satisfied: absl-py>=0.7.1 in /tensorflow-2.0.0-rc1/python3.6 (from dm-sonnet==2.0.0b0) (0.8.0)
Requirement already satisfied: six>=1.12.0 in /tensorflow-2.0.0-rc1/python3.6 (from dm-sonnet==2.0.0b0) (1.12.0)
import sonnet as snt
import tensorflow as tf
import tensorflow_datasets as tfds
print("TensorFlow version: {}".format(tf.__version__))
print("    Sonnet version: {}".format(snt.__version__))
TensorFlow version: 2.0.0-rc1
    Sonnet version: 2.0.0b0

最後に利用可能な GPU を素早く見ましょう :

!grep Model: /proc/driver/nvidia/gpus/*/information | awk '{$1="";print$0}'
 Tesla K80

 

データセット

データセットを容易に反復できる状態で得る必要があります。TensorFlow Datasets パッケージはこのために単純な API を提供します。それはデータセットをダウンロードして私達のために GPU 上で速やかに処理できるように準備します。モデルがそれを見る前にデータセットを変更する私達自身の前処理関数を追加することもできます :

batch_size = 100

def process_batch(images, labels):
  images = tf.squeeze(images, axis=[-1])
  images = tf.cast(images, dtype=tf.float32)
  images = ((images / 255.) - .5) * 2.
  return images, labels

def mnist(split):
  dataset = tfds.load("mnist", split=split, as_supervised=True)
  dataset = dataset.map(process_batch)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  dataset = dataset.cache()
  return dataset

mnist_train = mnist("train").shuffle(10)
mnist_test = mnist("test")

MNIST は 28×28 グレースケール手書き数字を含みます。一つ見てみましょう :

import matplotlib.pyplot as plt

images, _ = next(iter(mnist_test))
plt.imshow(images[0]);

 

Sonnet

次のステップはモデルを定義することです。Sonnet では TensorFlow variables (tf.Variable) を含む総ては snt.Module を拡張しています、これは低位ニューラルネットワーク・コンポーネント (e.g. snt.Linear, snt.Conv2D)、サブコンポーネントを含むより大きなネット (e.g. snt.nets.MLP)、optimizer (e.g. snt.optimizers.Adam) そして貴方が考えられるものは何でも含みます。

モジュールはパラメータ (そして BatchNorm の移動平均をストアするためのような他の目的のために使用される Variable) をストアするための単純な抽象を提供します。

与えられたモジュールのための総てのパラメータを見つけるには、単純に: module.variables を行ないます。これはこのモジュール、あるいはそれが参照する任意のモジュールのために存在する総てのパラメータのタプルを返します :

 

モデルを構築する

Sonnet では snt.Modules からニューラルネットワークを構築します。この場合多層パーセプトロンを、入力を ReLU 非線形を伴う幾つかの完全結合層を通してロジットを計算する __call__ メソッドを持つ新しいクラスとして構築します。

class MLP(snt.Module):

  def __init__(self):
    super(MLP, self).__init__()
    self.flatten = snt.Flatten()
    self.hidden1 = snt.Linear(1024, name="hidden1")
    self.hidden2 = snt.Linear(1024, name="hidden2")
    self.logits = snt.Linear(10, name="logits")

  def __call__(self, images):
    output = self.flatten(images)
    output = tf.nn.relu(self.hidden1(output))
    output = tf.nn.relu(self.hidden2(output))
    output = self.logits(output)
    return output

今はクラスのインスタンスを作成します、その重みはランダムに初期化されます。この MLP をそれが MNIST データセットの数字を認識することを学習するように訓練します。

mlp = MLP()
mlp
MLP()

 

モデルを使用する

サンプル入力をモデルに供給してそれが何を予測するかを見ましょう。モデルはランダムに初期化されましたので、それが正しいクラスを予測する 1/10 のチャンスがあります!

images, labels = next(iter(mnist_test))
logits = mlp(images)
  
prediction = tf.argmax(logits[0]).numpy()
actual = labels[0].numpy()
print("Predicted class: {} actual class: {}".format(prediction, actual))
plt.imshow(images[0]);
Predicted class: 0 actual class: 6

 

モデルを訓練する

モデルを訓練するためには optimizer が必要です。この単純なサンプルのために SGD optimizer で実装されている確率的勾配降下を使用します。勾配を計算するために tf.GradientTape を使用します、これは通して逆伝播することを望む計算のためだけに勾配を選択的に記録することを可能にします :

#@title Utility function to show progress bar.
from tqdm import tqdm

# MNIST training set has 60k images.
num_images = 60000

def progress_bar(generator):
  return tqdm(
      generator,
      unit='images',
      unit_scale=batch_size,
      total=(num_images // batch_size) * num_epochs)
opt = snt.optimizers.SGD(learning_rate=0.1)

num_epochs = 10

def step(images, labels):
  """Performs one optimizer step on a single mini-batch."""
  with tf.GradientTape() as tape:
    logits = mlp(images)
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                          labels=labels)
    loss = tf.reduce_mean(loss)

  params = mlp.trainable_variables
  grads = tape.gradient(loss, params)
  opt.apply(grads, params)
  return loss

for images, labels in progress_bar(mnist_train.repeat(num_epochs)):
  loss = step(images, labels)

print("\n\nFinal loss: {}".format(loss.numpy()))
100%|██████████| 600000/600000 [01:02<00:00, 9660.48images/s] 

Final loss: 0.039747316390275955

 

モデルを評価する

モデルがこのデータセットに対してどのくらい上手くやるのかの感触を得るためにモデルの非常に単純な解析を行ないます :

total = 0
correct = 0
for images, labels in mnist_test:
  predictions = tf.argmax(mlp(images), axis=1)
  correct += tf.math.count_nonzero(tf.equal(predictions, labels))
  total += images.shape[0]

print("Got %d/%d (%.02f%%) correct" % (correct, total, correct / total * 100.))
Got 9767/10000 (97.67%) correct

結果を少しだけより良く理解するために、モデルがどこで数字を正しく識別したかの小さいサンプルを見ましょう :

#@title Utility function to show a sample of images.
def sample(correct, rows, cols):
  n = 0

  f, ax = plt.subplots(rows, cols)
  if rows > 1:    
    ax = tf.nest.flatten([tuple(ax[i]) for i in range(rows)])
  f.set_figwidth(14)
  f.set_figheight(4 * rows)


  for images, labels in mnist_test:
    predictions = tf.argmax(mlp(images), axis=1)
    eq = tf.equal(predictions, labels)
    for i, x in enumerate(eq):
      if x.numpy() == correct:
        label = labels[i]
        prediction = predictions[i]
        image = images[i]

        ax[n].imshow(image)
        ax[n].set_title("Prediction:{}\nActual:{}".format(prediction, label))

        n += 1
        if n == (rows * cols):
          break

    if n == (rows * cols):
      break
sample(correct=True, rows=1, cols=5)

今はそれがどこで入力を間違って分類するかを見ましょう。MNIST は幾つか寧ろ疑わしい手書きを持ちます。下のサンプルの幾つかは少し曖昧であることに貴方は同意することでしょう :

sample(correct=False, rows=2, cols=5)

 

以上






AI導入支援 #2 ウェビナー

スモールスタートを可能としたAI導入支援   Vol.2
[無料 WEB セミナー] [詳細]
「画像認識 AI PoC スターターパック」の紹介
既に AI 技術を実ビジネスで活用し、成果を上げている日本企業も多く存在しており、競争優位なビジネスを展開しております。
しかしながら AI を導入したくとも PoC (概念実証) だけでも高額な費用がかかり取組めていない企業も少なくないようです。A I導入時には欠かせない PoC を手軽にしかも短期間で認知度を確認可能とするサービの紹介と共に、AI 技術の特性と具体的な導入プロセスに加え運用時のポイントについても解説いたします。
日時:2021年10月13日(水)
会場:WEBセミナー
共催:クラスキャット、日本FLOW(株)
後援:働き方改革推進コンソーシアム
参加費: 無料 (事前登録制)
人工知能開発支援
◆ クラスキャットは 人工知能研究開発支援 サービスを提供しています :
  • テクニカルコンサルティングサービス
  • 実証実験 (プロトタイプ構築)
  • アプリケーションへの実装
  • 人工知能研修サービス
◆ お問合せ先 ◆
(株)クラスキャット
セールス・インフォメーション
E-Mail:sales-info@classcat.com