TensorFlow : コード解説 : MNIST データダウンロード
* TensorFlow : Tutorials : MNIST データ・ダウンロード に、数式排除/コード重視の方針で詳細な解説を加筆したものです。
MNIST データ
MNIST は機械学習の古典的な分類問題です。0 から 9 までの数字について手書き数字のグレースケール 28×28 ピクセル画像を見て画像がどの数字を表しているかを決定します。

詳細は Yann LeCun’s MNIST または Chris Olah’s visualizations of MNIST を参照。
目標とチュートリアル・ファイル
MNIST チュートリアル全体は以下に含まれていますが :
コード: tensorflow/examples/tutorials/mnist/
ここでの目標は MNIST を使った手書き数字分類に必要なデータセット・ファイルをダウンロードすることで、次のファイルを参照します。
また、必要に応じて引用します :
ファイル | 目的 |
---|---|
input_data.py | 訓練/評価するための MNIST データセットをダウンロードするコード |
MNIST データファイルのダウンロード
Yann LeCun’s MNIST はダウンロードのための訓練/テスト・データをホストしてくれています。
ファイル | 目的 |
---|---|
train-images-idx3-ubyte.gz | 訓練セット画像 – 55000 訓練画像、5000 検証画像 |
train-labels-idx1-ubyte.gz | 訓練セット・ラベル |
t10k-images-idx3-ubyte.gz | テストセット画像 – 10000 画像 |
t10k-labels-idx1-ubyte.gz | テストセット・ラベル |
input_data.py の maybe_download() 関数がこれらのファイルが local にダウンロードされたことを保証します。
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' def maybe_download(filename, work_directory): """(ローカルにまだ落とされていない場合は)Yann's web サイトからデータをダウンロードします。""" if not tf.gfile.Exists(work_directory): tf.gfile.MakeDirs(work_directory) filepath = os.path.join(work_directory, filename) if not tf.gfile.Exists(filepath): filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath) with tf.gfile.GFile(filepath) as f: size = f.Size() print('Successfully downloaded', filename, size, 'bytes.') return filepath
フォルダ名は fully_connected_feed.py の冒頭の flag 変数で指定されています。
flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
MNIST データファイルのアンパックと形状変更
ファイルそのものは標準画像フォーマットではありません。
Yann LeCun’s MNIST の後半に掲載されている仕様に従って、input_data.py の extract_images() と extract_labels() 関数によりアンパックされます。
[offset] [type] [value] [description] 0000 32 bit integer 0x00000803(2051) magic number 0004 32 bit integer 60000 number of images 0008 32 bit integer 28 number of rows 0012 32 bit integer 28 number of columns 0016 unsigned byte ?? pixel 0017 unsigned byte ?? pixel ........ xxxx unsigned byte ?? pixel Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).
def extract_images(filename): """画像を 4次元 uint8 の numpy 配列 [index, y, x, depth] に抽出します。""" print('Extracting', filename) with tf.gfile.Open(filename, 'rb') as f, gzip.GzipFile(fileobj=f) as bytestream: magic = _read32(bytestream) if magic != 2051: raise ValueError( 'Invalid magic number %d in MNIST image file: %s' % (magic, filename)) num_images = _read32(bytestream) rows = _read32(bytestream) cols = _read32(bytestream) buf = bytestream.read(rows * cols * num_images) data = numpy.frombuffer(buf, dtype=numpy.uint8) data = data.reshape(num_images, rows, cols, 1) return data
extract_images() でアンパックされた画像データは、
DataSet クラスで形状変更 (reshape) され 2 次元テンソル: [画像 index, ピクセル index] に抽出されます。
ここで各エントリは特定の画像の特定のピクセルの強度 (intensity) で、[0, 255] から [0.0, 1.0] にスケール変更されています。
「画像 index」はデータセットの画像に該当し、0 からデータセットのサイズまで数えられます。
そして「ピクセル index」はその画像の特定のピクセルに相当し、0 から画像のピクセル数までの範囲を取ります。
class DataSet(object): def __init__(self, images, labels, fake_data=False, one_hot=False, dtype=tf.float32): """Construct a DataSet. one_hot arg is used only if fake_data is true. `dtype` can be either `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into `[0, 1]`. """ dtype = tf.as_dtype(dtype).base_dtype if dtype not in (tf.uint8, tf.float32): raise TypeError('Invalid image dtype %r, expected uint8 or float32' % dtype) if fake_data: self._num_examples = 10000 self.one_hot = one_hot else: assert images.shape[0] == labels.shape[0], ( 'images.shape: %s labels.shape: %s' % (images.shape, labels.shape)) self._num_examples = images.shape[0] # 形状 [num examples, rows, columns, depth] から形状 [num examples, rows*columns] (depth == 1 と仮定) に変換。 assert images.shape[3] == 1 images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]) if dtype == tf.float32: # [0, 255] -> [0.0, 1.0] へ変換。 images = images.astype(numpy.float32) images = numpy.multiply(images, 1.0 / 255.0) self._images = images self._labels = labels self._epochs_completed = 0 self._index_in_epoch = 0
ラベル・データは各サンプルのための値としてのクラス識別子で 1 次元テンソル: [画像インデックス] に抽出されます。
訓練セット・ラベルは従って形状 [55000] になります。
[offset] [type] [value] [description] 0000 32 bit integer 0x00000801(2049) magic number (MSB first) 0004 32 bit integer 60000 number of items 0008 unsigned byte ?? label 0009 unsigned byte ?? label ........ xxxx unsigned byte ?? label The labels values are 0 to 9.
def extract_labels(filename, one_hot=False): """ラベルを 1次元 uint8 の numpy 配列 [index] に抽出します。""" print('Extracting', filename) with tf.gfile.Open(filename, 'rb') as f, gzip.GzipFile(fileobj=f) as bytestream: magic = _read32(bytestream) if magic != 2049: raise ValueError( 'Invalid magic number %d in MNIST label file: %s' % (magic, filename)) num_items = _read32(bytestream) buf = bytestream.read(num_items) labels = numpy.frombuffer(buf, dtype=numpy.uint8) if one_hot: return dense_to_one_hot(labels) return labels
MNIST データセット・オブジェクト
read_data_sets() は train-* ファイルの 60000 サンプルを、訓練用の 55000 サンプルと検証用の 5000 サンプルに分割します。
データセットの 28×28 ピクセル・グレースケール画像全てについて画像サイズは 784、そして訓練セット画像のための出力テンソルは 形状 [55000, 784] になります。
def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32): ... VALIDATION_SIZE = 5000 ... validation_images = train_images[:VALIDATION_SIZE] validation_labels = train_labels[:VALIDATION_SIZE] train_images = train_images[VALIDATION_SIZE:] train_labels = train_labels[VALIDATION_SIZE:]
結局、input_data.py のコードは、次のデータセットのために画像/ラベルをダウンロードし、アンパックし、形状変更していることになります :
データセット | 目的 |
---|---|
data_sets.train : | 55000 画像とラベル、主要な訓練のため。 |
data_sets.validation : | 5000 画像とラベル、訓練精度の繰り返し検証のため。 |
data_sets.test : | 10000 画像とラベル、訓練精度の最終テストのため。 |
read_data_sets() 関数は、これらの3つのデータセットそれぞれための DataSet インスタンスで辞書 (dictionary) を返します。DataSet.next_batch() メソッドは、実行中の TensorFlow セッションに供給される画像とラベルの batch_size リストからなるタプルを取得するために使用できます。
class DataSet(object): ... def next_batch(self, batch_size, fake_data=False): """Return the next `batch_size` examples from this data set.""" if fake_data: fake_image = [1] * 784 if self.one_hot: fake_label = [1] + [0] * 9 else: fake_label = 0 return [fake_image for _ in xrange(batch_size)], [ fake_label for _ in xrange(batch_size)] start = self._index_in_epoch self._index_in_epoch += batch_size if self._index_in_epoch > self._num_examples: # Finished epoch self._epochs_completed += 1 # Shuffle the data perm = numpy.arange(self._num_examples) numpy.random.shuffle(perm) self._images = self._images[perm] self._labels = self._labels[perm] # Start next epoch start = 0 self._index_in_epoch = batch_size assert batch_size <= self._num_examples end = self._index_in_epoch return self._images[start:end], self._labels[start:end]
images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size)
以上