ホーム » TensorFlow » TensorFlow : MNIST データ・ダウンロード (コード解説)

TensorFlow : MNIST データ・ダウンロード (コード解説)

TensorFlow : コード解説 : MNIST データダウンロード

* TensorFlow : Tutorials : MNIST データ・ダウンロード に、数式排除/コード重視の方針で詳細な解説を加筆したものです。

 

MNIST データ

MNIST は機械学習の古典的な分類問題です。0 から 9 までの数字について手書き数字のグレースケール 28×28 ピクセル画像を見て画像がどの数字を表しているかを決定します。

MNIST Digits

詳細は Yann LeCun’s MNIST または Chris Olah’s visualizations of MNIST を参照。

目標とチュートリアル・ファイル

MNIST チュートリアル全体は以下に含まれていますが :
        コード: tensorflow/examples/tutorials/mnist/

ここでの目標は MNIST を使った手書き数字分類に必要なデータセット・ファイルをダウンロードすることで、次のファイルを参照します。
また、必要に応じて引用します :

ファイル

目的

input_data.py

訓練/評価するための MNIST データセットをダウンロードするコード
 

MNIST データファイルのダウンロード

Yann LeCun’s MNIST はダウンロードのための訓練/テスト・データをホストしてくれています。

ファイル

目的

train-images-idx3-ubyte.gz

訓練セット画像 – 55000 訓練画像、5000 検証画像

train-labels-idx1-ubyte.gz

訓練セット・ラベル

t10k-images-idx3-ubyte.gz

テストセット画像 – 10000 画像

t10k-labels-idx1-ubyte.gz

テストセット・ラベル

input_data.py の maybe_download() 関数がこれらのファイルが local にダウンロードされたことを保証します。

SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'

def maybe_download(filename, work_directory):
  """(ローカルにまだ落とされていない場合は)Yann's web サイトからデータをダウンロードします。"""
  if not tf.gfile.Exists(work_directory):
    tf.gfile.MakeDirs(work_directory)
  filepath = os.path.join(work_directory, filename)
  if not tf.gfile.Exists(filepath):
    filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
    with tf.gfile.GFile(filepath) as f:
      size = f.Size()
    print('Successfully downloaded', filename, size, 'bytes.')
  return filepath

フォルダ名は fully_connected_feed.py の冒頭の flag 変数で指定されています。

flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
 

MNIST データファイルのアンパックと形状変更

ファイルそのものは標準画像フォーマットではありません。
Yann LeCun’s MNIST の後半に掲載されている仕様に従って、input_data.py の extract_images()extract_labels() 関数によりアンパックされます。

[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000803(2051) magic number
0004     32 bit integer  60000            number of images
0008     32 bit integer  28               number of rows
0012     32 bit integer  28               number of columns
0016     unsigned byte   ??               pixel
0017     unsigned byte   ??               pixel
........
xxxx     unsigned byte   ??               pixel

Pixels are organized row-wise.
Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black). 
def extract_images(filename):
  """画像を 4次元 uint8 の numpy 配列 [index, y, x, depth] に抽出します。"""
  print('Extracting', filename)
  with tf.gfile.Open(filename, 'rb') as f, gzip.GzipFile(fileobj=f) as bytestream:
    magic = _read32(bytestream)
    if magic != 2051:
      raise ValueError(
          'Invalid magic number %d in MNIST image file: %s' %
          (magic, filename))
    num_images = _read32(bytestream)
    rows = _read32(bytestream)
    cols = _read32(bytestream)
    buf = bytestream.read(rows * cols * num_images)
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    data = data.reshape(num_images, rows, cols, 1)
    return data

extract_images() でアンパックされた画像データは、
DataSet クラスで形状変更 (reshape) され 2 次元テンソル: [画像 index, ピクセル index] に抽出されます。

ここで各エントリは特定の画像の特定のピクセルの強度 (intensity) で、[0, 255] から [0.0, 1.0] にスケール変更されています。
「画像 index」はデータセットの画像に該当し、0 からデータセットのサイズまで数えられます。
そして「ピクセル index」はその画像の特定のピクセルに相当し、0 から画像のピクセル数までの範囲を取ります。

class DataSet(object):

  def __init__(self, images, labels, fake_data=False, one_hot=False,
               dtype=tf.float32):
    """Construct a DataSet.
    one_hot arg is used only if fake_data is true.  `dtype` can be either
    `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
    `[0, 1]`.
    """
    dtype = tf.as_dtype(dtype).base_dtype
    if dtype not in (tf.uint8, tf.float32):
      raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
                      dtype)
    if fake_data:
      self._num_examples = 10000
      self.one_hot = one_hot
    else:
      assert images.shape[0] == labels.shape[0], (
          'images.shape: %s labels.shape: %s' % (images.shape,
                                                 labels.shape))
      self._num_examples = images.shape[0]

      # 形状 [num examples, rows, columns, depth] から形状 [num examples, rows*columns] (depth == 1 と仮定) に変換。
      assert images.shape[3] == 1
      images = images.reshape(images.shape[0],
                              images.shape[1] * images.shape[2])
      if dtype == tf.float32:
        # [0, 255] -> [0.0, 1.0] へ変換。
        images = images.astype(numpy.float32)
        images = numpy.multiply(images, 1.0 / 255.0)
    self._images = images
    self._labels = labels
    self._epochs_completed = 0
    self._index_in_epoch = 0

ラベル・データは各サンプルのための値としてのクラス識別子で 1 次元テンソル: [画像インデックス] に抽出されます。
訓練セット・ラベルは従って形状 [55000] になります。

[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000801(2049) magic number (MSB first)
0004     32 bit integer  60000            number of items
0008     unsigned byte   ??               label
0009     unsigned byte   ??               label
........
xxxx     unsigned byte   ??               label

The labels values are 0 to 9.
def extract_labels(filename, one_hot=False):
  """ラベルを 1次元 uint8 の numpy 配列 [index] に抽出します。"""
  print('Extracting', filename)
  with tf.gfile.Open(filename, 'rb') as f, gzip.GzipFile(fileobj=f) as bytestream:
    magic = _read32(bytestream)
    if magic != 2049:
      raise ValueError(
          'Invalid magic number %d in MNIST label file: %s' %
          (magic, filename))
    num_items = _read32(bytestream)
    buf = bytestream.read(num_items)
    labels = numpy.frombuffer(buf, dtype=numpy.uint8)
    if one_hot:
      return dense_to_one_hot(labels)
    return labels
 

MNIST データセット・オブジェクト

read_data_sets() は train-* ファイルの 60000 サンプルを、訓練用の 55000 サンプルと検証用の 5000 サンプルに分割します。
データセットの 28×28 ピクセル・グレースケール画像全てについて画像サイズは 784、そして訓練セット画像のための出力テンソルは 形状 [55000, 784] になります。

def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32):
  ...
 VALIDATION_SIZE = 5000
  ...
  validation_images = train_images[:VALIDATION_SIZE]
  validation_labels = train_labels[:VALIDATION_SIZE]
  train_images = train_images[VALIDATION_SIZE:]
  train_labels = train_labels[VALIDATION_SIZE:]

結局、input_data.py のコードは、次のデータセットのために画像/ラベルをダウンロードし、アンパックし、形状変更していることになります :

データセット

目的

data_sets.train :

55000 画像とラベル、主要な訓練のため。

data_sets.validation :

5000 画像とラベル、訓練精度の繰り返し検証のため。

data_sets.test :

10000 画像とラベル、訓練精度の最終テストのため。

read_data_sets() 関数は、これらの3つのデータセットそれぞれための DataSet インスタンスで辞書 (dictionary) を返します。DataSet.next_batch() メソッドは、実行中の TensorFlow セッションに供給される画像とラベルの batch_size リストからなるタプルを取得するために使用できます。

class DataSet(object):
  ...

  def next_batch(self, batch_size, fake_data=False):
    """Return the next `batch_size` examples from this data set."""
    if fake_data:
      fake_image = [1] * 784
      if self.one_hot:
        fake_label = [1] + [0] * 9
      else:
        fake_label = 0
      return [fake_image for _ in xrange(batch_size)], [
          fake_label for _ in xrange(batch_size)]
    start = self._index_in_epoch
    self._index_in_epoch += batch_size
    if self._index_in_epoch > self._num_examples:
      # Finished epoch
      self._epochs_completed += 1
      # Shuffle the data
      perm = numpy.arange(self._num_examples)
      numpy.random.shuffle(perm)
      self._images = self._images[perm]
      self._labels = self._labels[perm]
      # Start next epoch
      start = 0
      self._index_in_epoch = batch_size
      assert batch_size <= self._num_examples
    end = self._index_in_epoch
    return self._images[start:end], self._labels[start:end]
images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size)
 

以上

AI導入支援 #2 ウェビナー

スモールスタートを可能としたAI導入支援   Vol.2
[無料 WEB セミナー] [詳細]
「画像認識 AI PoC スターターパック」の紹介
既に AI 技術を実ビジネスで活用し、成果を上げている日本企業も多く存在しており、競争優位なビジネスを展開しております。
しかしながら AI を導入したくとも PoC (概念実証) だけでも高額な費用がかかり取組めていない企業も少なくないようです。A I導入時には欠かせない PoC を手軽にしかも短期間で認知度を確認可能とするサービの紹介と共に、AI 技術の特性と具体的な導入プロセスに加え運用時のポイントについても解説いたします。
日時:2021年10月13日(水)
会場:WEBセミナー
共催:クラスキャット、日本FLOW(株)
後援:働き方改革推進コンソーシアム
参加費: 無料 (事前登録制)
人工知能開発支援
◆ クラスキャットは 人工知能研究開発支援 サービスを提供しています :
  • テクニカルコンサルティングサービス
  • 実証実験 (プロトタイプ構築)
  • アプリケーションへの実装
  • 人工知能研修サービス
◆ お問合せ先 ◆
(株)クラスキャット
セールス・インフォメーション
E-Mail:sales-info@classcat.com