ホーム » 「Sonnet 2.0」タグがついた投稿

タグアーカイブ: Sonnet 2.0

Sonnet 2.0 : Tutorials : VQ-VAE 訓練サンプル

Sonnet 2.0 : Tutorials : VQ-VAE 訓練サンプル (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 06/10/2020

* 本ページは、Sonnet の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

VQ-VAE 訓練サンプル

TF 2 / Sonnet 2 を使用して、https://arxiv.org/abs/1711.00937 で指定されるモデルをどのように訓練するかの実演です。

Mac と Linux 上、単純に各セルを順番に実行してください。

!pip install dm-sonnet dm-tree
Requirement already satisfied: dm-sonnet in /tmp/sonnet-nb-env/lib/python3.7/site-packages (2.0.0)
Requirement already satisfied: dm-tree in /tmp/sonnet-nb-env/lib/python3.7/site-packages (0.1.5)
Requirement already satisfied: six>=1.12.0 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (1.14.0)
Requirement already satisfied: tabulate>=0.7.5 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (0.8.7)
Requirement already satisfied: absl-py>=0.7.1 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (0.9.0)
Requirement already satisfied: numpy>=1.16.3 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (1.18.3)
Requirement already satisfied: wrapt>=1.11.1 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (1.12.1)
import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import tree

try:
  import sonnet.v2 as snt
  tf.enable_v2_behavior()
except ImportError:
  import sonnet as snt

print("TensorFlow version {}".format(tf.__version__))
print("Sonnet version {}".format(snt.__version__))
TensorFlow version 2.1.0
Sonnet version 2.0.0

 

Cifar10 データをダウンロードする

これはインターネットへの接続を必要として ~160MB をダウンロードします。

cifar10 = tfds.as_numpy(tfds.load("cifar10:3.0.2", split="train+test", batch_size=-1))
cifar10.pop("id", None)
cifar10.pop("label")
tree.map_structure(lambda x: f'{x.dtype.name}{list(x.shape)}', cifar10)
{'image': 'uint8[60000, 32, 32, 3]'}

 

データを Numpy にロードする

下で平均二乗誤差を正規化するために訓練セット全体の分散を計算します。

train_data_dict = tree.map_structure(lambda x: x[:40000], cifar10)
valid_data_dict = tree.map_structure(lambda x: x[40000:50000], cifar10)
test_data_dict = tree.map_structure(lambda x: x[50000:], cifar10)
def cast_and_normalise_images(data_dict):
  """Convert images to floating point with the range [-0.5, 0.5]"""
  images = data_dict['image']
  data_dict['image'] = (tf.cast(images, tf.float32) / 255.0) - 0.5
  return data_dict

train_data_variance = np.var(train_data_dict['image'] / 255.0)
print('train data variance: %s' % train_data_variance)
train data variance: 0.06327039811675479

 

エンコーダ & デコーダ・アーキテクチャ

  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super(ResidualStack, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    self._layers = []
    for i in range(num_residual_layers):
      conv3 = snt.Conv2D(
          output_channels=num_residual_hiddens,
          kernel_shape=(3, 3),
          stride=(1, 1),
          name="res3x3_%d" % i)
      conv1 = snt.Conv2D(
          output_channels=num_hiddens,
          kernel_shape=(1, 1),
          stride=(1, 1),
          name="res1x1_%d" % i)
      self._layers.append((conv3, conv1))

  def __call__(self, inputs):
    h = inputs
    for conv3, conv1 in self._layers:
      conv3_out = conv3(tf.nn.relu(h))
      conv1_out = conv1(tf.nn.relu(conv3_out))
      h += conv1_out
    return tf.nn.relu(h)  # Resnet V1 style


class Encoder(snt.Module):
  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super(Encoder, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    self._enc_1 = snt.Conv2D(
        output_channels=self._num_hiddens // 2,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="enc_1")
    self._enc_2 = snt.Conv2D(
        output_channels=self._num_hiddens,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="enc_2")
    self._enc_3 = snt.Conv2D(
        output_channels=self._num_hiddens,
        kernel_shape=(3, 3),
        stride=(1, 1),
        name="enc_3")
    self._residual_stack = ResidualStack(
        self._num_hiddens,
        self._num_residual_layers,
        self._num_residual_hiddens)

  def __call__(self, x):
    h = tf.nn.relu(self._enc_1(x))
    h = tf.nn.relu(self._enc_2(h))
    h = tf.nn.relu(self._enc_3(h))
    return self._residual_stack(h)


class Decoder(snt.Module):
  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super(Decoder, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    self._dec_1 = snt.Conv2D(
        output_channels=self._num_hiddens,
        kernel_shape=(3, 3),
        stride=(1, 1),
        name="dec_1")
    self._residual_stack = ResidualStack(
        self._num_hiddens,
        self._num_residual_layers,
        self._num_residual_hiddens)
    self._dec_2 = snt.Conv2DTranspose(
        output_channels=self._num_hiddens // 2,
        output_shape=None,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="dec_2")
    self._dec_3 = snt.Conv2DTranspose(
        output_channels=3,
        output_shape=None,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="dec_3")
    
  def __call__(self, x):
    h = self._dec_1(x)
    h = self._residual_stack(h)
    h = tf.nn.relu(self._dec_2(h))
    x_recon = self._dec_3(h)
    return x_recon
    

class VQVAEModel(snt.Module):
  def __init__(self, encoder, decoder, vqvae, pre_vq_conv1, 
               data_variance, name=None):
    super(VQVAEModel, self).__init__(name=name)
    self._encoder = encoder
    self._decoder = decoder
    self._vqvae = vqvae
    self._pre_vq_conv1 = pre_vq_conv1
    self._data_variance = data_variance

  def __call__(self, inputs, is_training):
    z = self._pre_vq_conv1(self._encoder(inputs))
    vq_output = self._vqvae(z, is_training=is_training)
    x_recon = self._decoder(vq_output['quantize'])
    recon_error = tf.reduce_mean((x_recon - inputs) ** 2) / self._data_variance
    loss = recon_error + vq_output['loss']
    return {
        'z': z,
        'x_recon': x_recon,
        'loss': loss,
        'recon_error': recon_error,
        'vq_output': vq_output,
    }

 

モデルを構築して訓練する

%%time

# Set hyper-parameters.
batch_size = 32
image_size = 32

# 100k steps should take < 30 minutes on a modern (>= 2017) GPU.
# 10k steps gives reasonable accuracy with VQVAE on Cifar10.
num_training_updates = 10000

num_hiddens = 128
num_residual_hiddens = 32
num_residual_layers = 2
# These hyper-parameters define the size of the model (number of parameters and layers).
# The hyper-parameters in the paper were (For ImageNet):
# batch_size = 128
# image_size = 128
# num_hiddens = 128
# num_residual_hiddens = 32
# num_residual_layers = 2

# This value is not that important, usually 64 works.
# This will not change the capacity in the information-bottleneck.
embedding_dim = 64

# The higher this value, the higher the capacity in the information bottleneck.
num_embeddings = 512

# commitment_cost should be set appropriately. It's often useful to try a couple
# of values. It mostly depends on the scale of the reconstruction cost
# (log p(x|z)). So if the reconstruction cost is 100x higher, the
# commitment_cost should also be multiplied with the same amount.
commitment_cost = 0.25

# Use EMA updates for the codebook (instead of the Adam optimizer).
# This typically converges faster, and makes the model less dependent on choice
# of the optimizer. In the VQ-VAE paper EMA updates were not used (but was
# developed afterwards). See Appendix of the paper for more details.
vq_use_ema = True

# This is only used for EMA updates.
decay = 0.99

learning_rate = 3e-4


# # Data Loading.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_data_dict)
    .map(cast_and_normalise_images)
    .shuffle(10000)
    .repeat(-1)  # repeat indefinitely
    .batch(batch_size, drop_remainder=True)
    .prefetch(-1))

valid_dataset = (
    tf.data.Dataset.from_tensor_slices(valid_data_dict)
    .map(cast_and_normalise_images)
    .repeat(1)  # 1 epoch
    .batch(batch_size)
    .prefetch(-1))

# # Build modules.
encoder = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)
decoder = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)
pre_vq_conv1 = snt.Conv2D(output_channels=embedding_dim,
    kernel_shape=(1, 1),
    stride=(1, 1),
    name="to_vq")

if vq_use_ema:
  vq_vae = snt.nets.VectorQuantizerEMA(
      embedding_dim=embedding_dim,
      num_embeddings=num_embeddings,
      commitment_cost=commitment_cost,
      decay=decay)
else:
  vq_vae = snt.nets.VectorQuantizer(
      embedding_dim=embedding_dim,
      num_embeddings=num_embeddings,
      commitment_cost=commitment_cost)
  
model = VQVAEModel(encoder, decoder, vq_vae, pre_vq_conv1,
                   data_variance=train_data_variance)

optimizer = snt.optimizers.Adam(learning_rate=learning_rate)

@tf.function
def train_step(data):
  with tf.GradientTape() as tape:
    model_output = model(data['image'], is_training=True)
  trainable_variables = model.trainable_variables
  grads = tape.gradient(model_output['loss'], trainable_variables)
  optimizer.apply(grads, trainable_variables)

  return model_output

train_losses = []
train_recon_errors = []
train_perplexities = []
train_vqvae_loss = []

for step_index, data in enumerate(train_dataset):
  train_results = train_step(data)
  train_losses.append(train_results['loss'])
  train_recon_errors.append(train_results['recon_error'])
  train_perplexities.append(train_results['vq_output']['perplexity'])
  train_vqvae_loss.append(train_results['vq_output']['loss'])

  if (step_index + 1) % 100 == 0:
    print('%d train loss: %f ' % (step_index + 1,
                                   np.mean(train_losses[-100:])) +
          ('recon_error: %.3f ' % np.mean(train_recon_errors[-100:])) +
          ('perplexity: %.3f ' % np.mean(train_perplexities[-100:])) +
          ('vqvae loss: %.3f' % np.mean(train_vqvae_loss[-100:])))
  if step_index == num_training_updates:
    break
WARNING:tensorflow:AutoGraph could not transform <function train_step at 0x7f1016cb5f80> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: Unable to locate the source code of <function train_step at 0x7f1016cb5f80>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code
WARNING:tensorflow:AutoGraph could not transform <function train_step at 0x7f1016cb5f80> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: Unable to locate the source code of <function train_step at 0x7f1016cb5f80>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code
WARNING: AutoGraph could not transform <function train_step at 0x7f1016cb5f80> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: Unable to locate the source code of <function train_step at 0x7f1016cb5f80>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code
WARNING:tensorflow:From /tmp/sonnet-nb-env/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1786: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /tmp/sonnet-nb-env/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1786: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
100 train loss: 0.523625 recon_error: 0.483 perplexity: 10.356 vqvae loss: 0.041
200 train loss: 0.248232 recon_error: 0.223 perplexity: 18.294 vqvae loss: 0.026
300 train loss: 0.215068 recon_error: 0.190 perplexity: 23.106 vqvae loss: 0.025
400 train loss: 0.191891 recon_error: 0.164 perplexity: 29.139 vqvae loss: 0.028
500 train loss: 0.180945 recon_error: 0.147 perplexity: 34.253 vqvae loss: 0.033
600 train loss: 0.167115 recon_error: 0.134 perplexity: 39.961 vqvae loss: 0.033
700 train loss: 0.157724 recon_error: 0.124 perplexity: 46.521 vqvae loss: 0.033
800 train loss: 0.153761 recon_error: 0.119 perplexity: 53.559 vqvae loss: 0.035
900 train loss: 0.145033 recon_error: 0.112 perplexity: 62.442 vqvae loss: 0.033
1000 train loss: 0.137589 recon_error: 0.105 perplexity: 71.831 vqvae loss: 0.033
1100 train loss: 0.133044 recon_error: 0.101 perplexity: 79.135 vqvae loss: 0.032
1200 train loss: 0.129990 recon_error: 0.098 perplexity: 87.959 vqvae loss: 0.032
1300 train loss: 0.126507 recon_error: 0.095 perplexity: 96.704 vqvae loss: 0.031
1400 train loss: 0.122403 recon_error: 0.092 perplexity: 104.202 vqvae loss: 0.031
1500 train loss: 0.122003 recon_error: 0.091 perplexity: 112.476 vqvae loss: 0.031
1600 train loss: 0.120192 recon_error: 0.089 perplexity: 122.269 vqvae loss: 0.032
1700 train loss: 0.117041 recon_error: 0.086 perplexity: 129.887 vqvae loss: 0.031
1800 train loss: 0.115004 recon_error: 0.083 perplexity: 138.603 vqvae loss: 0.032
1900 train loss: 0.114134 recon_error: 0.082 perplexity: 147.545 vqvae loss: 0.032
2000 train loss: 0.112840 recon_error: 0.081 perplexity: 153.993 vqvae loss: 0.032
2100 train loss: 0.108815 recon_error: 0.077 perplexity: 161.729 vqvae loss: 0.031
2200 train loss: 0.108596 recon_error: 0.078 perplexity: 171.971 vqvae loss: 0.031
2300 train loss: 0.108132 recon_error: 0.077 perplexity: 181.157 vqvae loss: 0.031
2400 train loss: 0.106273 recon_error: 0.076 perplexity: 186.200 vqvae loss: 0.031
2500 train loss: 0.105936 recon_error: 0.075 perplexity: 194.301 vqvae loss: 0.031
2600 train loss: 0.103880 recon_error: 0.073 perplexity: 201.674 vqvae loss: 0.030
2700 train loss: 0.101655 recon_error: 0.072 perplexity: 207.131 vqvae loss: 0.030
2800 train loss: 0.102564 recon_error: 0.072 perplexity: 216.983 vqvae loss: 0.030
2900 train loss: 0.101613 recon_error: 0.072 perplexity: 219.649 vqvae loss: 0.030
3000 train loss: 0.101227 recon_error: 0.071 perplexity: 226.789 vqvae loss: 0.030
3100 train loss: 0.100786 recon_error: 0.071 perplexity: 235.522 vqvae loss: 0.030
3200 train loss: 0.100130 recon_error: 0.070 perplexity: 243.282 vqvae loss: 0.030
3300 train loss: 0.097764 recon_error: 0.067 perplexity: 249.584 vqvae loss: 0.030
3400 train loss: 0.100630 recon_error: 0.069 perplexity: 260.551 vqvae loss: 0.031
3500 train loss: 0.099929 recon_error: 0.068 perplexity: 266.012 vqvae loss: 0.032
3600 train loss: 0.099245 recon_error: 0.067 perplexity: 272.031 vqvae loss: 0.032
3700 train loss: 0.097812 recon_error: 0.066 perplexity: 279.691 vqvae loss: 0.032
3800 train loss: 0.097137 recon_error: 0.064 perplexity: 284.240 vqvae loss: 0.033
3900 train loss: 0.099217 recon_error: 0.066 perplexity: 293.507 vqvae loss: 0.034
4000 train loss: 0.098570 recon_error: 0.065 perplexity: 300.891 vqvae loss: 0.034
4100 train loss: 0.099238 recon_error: 0.065 perplexity: 306.762 vqvae loss: 0.034
4200 train loss: 0.098172 recon_error: 0.064 perplexity: 311.918 vqvae loss: 0.035
4300 train loss: 0.096449 recon_error: 0.063 perplexity: 316.246 vqvae loss: 0.034
4400 train loss: 0.096487 recon_error: 0.062 perplexity: 319.591 vqvae loss: 0.034
4500 train loss: 0.096092 recon_error: 0.062 perplexity: 322.313 vqvae loss: 0.034
4600 train loss: 0.096474 recon_error: 0.062 perplexity: 324.620 vqvae loss: 0.035
4700 train loss: 0.097075 recon_error: 0.063 perplexity: 324.357 vqvae loss: 0.035
4800 train loss: 0.094709 recon_error: 0.060 perplexity: 326.024 vqvae loss: 0.034
4900 train loss: 0.096557 recon_error: 0.061 perplexity: 327.701 vqvae loss: 0.035
5000 train loss: 0.096185 recon_error: 0.061 perplexity: 326.664 vqvae loss: 0.035
5100 train loss: 0.095646 recon_error: 0.060 perplexity: 327.617 vqvae loss: 0.035
5200 train loss: 0.094689 recon_error: 0.059 perplexity: 328.692 vqvae loss: 0.035
5300 train loss: 0.097047 recon_error: 0.061 perplexity: 327.988 vqvae loss: 0.036
5400 train loss: 0.096259 recon_error: 0.060 perplexity: 327.075 vqvae loss: 0.036
5500 train loss: 0.094588 recon_error: 0.059 perplexity: 327.083 vqvae loss: 0.036
5600 train loss: 0.095947 recon_error: 0.060 perplexity: 328.213 vqvae loss: 0.036
5700 train loss: 0.095466 recon_error: 0.059 perplexity: 329.375 vqvae loss: 0.036
5800 train loss: 0.094849 recon_error: 0.059 perplexity: 326.821 vqvae loss: 0.036
5900 train loss: 0.093799 recon_error: 0.058 perplexity: 328.409 vqvae loss: 0.036
6000 train loss: 0.095373 recon_error: 0.059 perplexity: 326.791 vqvae loss: 0.036
6100 train loss: 0.093989 recon_error: 0.059 perplexity: 325.959 vqvae loss: 0.035
6200 train loss: 0.095549 recon_error: 0.059 perplexity: 330.829 vqvae loss: 0.036
6300 train loss: 0.094730 recon_error: 0.058 perplexity: 330.906 vqvae loss: 0.036
6400 train loss: 0.095038 recon_error: 0.058 perplexity: 329.353 vqvae loss: 0.037
6500 train loss: 0.095891 recon_error: 0.059 perplexity: 330.197 vqvae loss: 0.037
6600 train loss: 0.094342 recon_error: 0.058 perplexity: 331.240 vqvae loss: 0.036
6700 train loss: 0.095096 recon_error: 0.058 perplexity: 330.618 vqvae loss: 0.037
6800 train loss: 0.095581 recon_error: 0.059 perplexity: 324.493 vqvae loss: 0.037
6900 train loss: 0.094467 recon_error: 0.058 perplexity: 328.868 vqvae loss: 0.037
7000 train loss: 0.092967 recon_error: 0.057 perplexity: 328.276 vqvae loss: 0.036
7100 train loss: 0.094339 recon_error: 0.058 perplexity: 327.318 vqvae loss: 0.037
7200 train loss: 0.095227 recon_error: 0.058 perplexity: 326.306 vqvae loss: 0.037
7300 train loss: 0.093832 recon_error: 0.057 perplexity: 328.262 vqvae loss: 0.037
7400 train loss: 0.093331 recon_error: 0.057 perplexity: 327.987 vqvae loss: 0.037
7500 train loss: 0.094718 recon_error: 0.058 perplexity: 328.948 vqvae loss: 0.037
7600 train loss: 0.094199 recon_error: 0.058 perplexity: 328.468 vqvae loss: 0.037
7700 train loss: 0.094603 recon_error: 0.058 perplexity: 327.501 vqvae loss: 0.037
7800 train loss: 0.092299 recon_error: 0.056 perplexity: 327.630 vqvae loss: 0.037
7900 train loss: 0.095228 recon_error: 0.058 perplexity: 329.946 vqvae loss: 0.037
8000 train loss: 0.094291 recon_error: 0.058 perplexity: 326.790 vqvae loss: 0.037
8100 train loss: 0.094481 recon_error: 0.057 perplexity: 328.667 vqvae loss: 0.037
8200 train loss: 0.093992 recon_error: 0.057 perplexity: 329.655 vqvae loss: 0.037
8300 train loss: 0.093976 recon_error: 0.057 perplexity: 323.950 vqvae loss: 0.037
8400 train loss: 0.093422 recon_error: 0.057 perplexity: 324.523 vqvae loss: 0.036
8500 train loss: 0.092898 recon_error: 0.056 perplexity: 325.402 vqvae loss: 0.037
8600 train loss: 0.094298 recon_error: 0.057 perplexity: 329.251 vqvae loss: 0.037
8700 train loss: 0.094489 recon_error: 0.057 perplexity: 331.027 vqvae loss: 0.037
8800 train loss: 0.093022 recon_error: 0.056 perplexity: 327.495 vqvae loss: 0.037
8900 train loss: 0.093427 recon_error: 0.057 perplexity: 328.008 vqvae loss: 0.037
9000 train loss: 0.094884 recon_error: 0.058 perplexity: 327.057 vqvae loss: 0.037
9100 train loss: 0.093559 recon_error: 0.056 perplexity: 331.800 vqvae loss: 0.037
9200 train loss: 0.093282 recon_error: 0.056 perplexity: 328.689 vqvae loss: 0.037
9300 train loss: 0.092217 recon_error: 0.056 perplexity: 323.903 vqvae loss: 0.036
9400 train loss: 0.093902 recon_error: 0.057 perplexity: 326.350 vqvae loss: 0.037
9500 train loss: 0.093772 recon_error: 0.057 perplexity: 325.627 vqvae loss: 0.037
9600 train loss: 0.093123 recon_error: 0.056 perplexity: 327.352 vqvae loss: 0.037
9700 train loss: 0.092934 recon_error: 0.056 perplexity: 328.674 vqvae loss: 0.037
9800 train loss: 0.093284 recon_error: 0.056 perplexity: 329.437 vqvae loss: 0.037
9900 train loss: 0.094147 recon_error: 0.057 perplexity: 330.146 vqvae loss: 0.037
10000 train loss: 0.092876 recon_error: 0.056 perplexity: 326.349 vqvae loss: 0.037
CPU times: user 1h 47min 46s, sys: 14min 12s, total: 2h 1min 59s
Wall time: 4min 29s

 

損失をプロットする

f = plt.figure(figsize=(16,8))
ax = f.add_subplot(1,2,1)
ax.plot(train_recon_errors)
ax.set_yscale('log')
ax.set_title('NMSE.')

ax = f.add_subplot(1,2,2)
ax.plot(train_perplexities)
ax.set_title('Average codebook usage (perplexity).')
Text(0.5, 1.0, 'Average codebook usage (perplexity).')

 

再構築を見る

# Reconstructions
train_batch = next(iter(train_dataset))
valid_batch = next(iter(valid_dataset))

# Put data through the model with is_training=False, so that in the case of 
# using EMA the codebook is not updated.
train_reconstructions = model(train_batch['image'],
                              is_training=False)['x_recon'].numpy()
valid_reconstructions = model(valid_batch['image'],
                              is_training=False)['x_recon'].numpy()


def convert_batch_to_image_grid(image_batch):
  reshaped = (image_batch.reshape(4, 8, 32, 32, 3)
              .transpose(0, 2, 1, 3, 4)
              .reshape(4 * 32, 8 * 32, 3))
  return reshaped + 0.5



f = plt.figure(figsize=(16,8))
ax = f.add_subplot(2,2,1)
ax.imshow(convert_batch_to_image_grid(train_batch['image'].numpy()),
          interpolation='nearest')
ax.set_title('training data originals')
plt.axis('off')

ax = f.add_subplot(2,2,2)
ax.imshow(convert_batch_to_image_grid(train_reconstructions),
          interpolation='nearest')
ax.set_title('training data reconstructions')
plt.axis('off')

ax = f.add_subplot(2,2,3)
ax.imshow(convert_batch_to_image_grid(valid_batch['image'].numpy()),
          interpolation='nearest')
ax.set_title('validation data originals')
plt.axis('off')

ax = f.add_subplot(2,2,4)
ax.imshow(convert_batch_to_image_grid(valid_reconstructions),
          interpolation='nearest')
ax.set_title('validation data reconstructions')
plt.axis('off')
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
(-0.5, 255.5, 127.5, -0.5)

 

以上






Sonnet 2.0 : Tutorials : snt.distribute で分散訓練 (CIFAR-10)

Sonnet 2.0 : Tutorials : snt.distribute で分散訓練 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 06/10/2020

* 本ページは、Sonnet の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

snt.distribute で分散訓練

イントロダクション

このチュートリアルは Sonnet 2 “Hello, world!” サンプル (MLP on MNIST) を既に完了していることを仮定しています。

このチュートリアルでは、より大きなモデルとより大きなデータセットで物事をスケールアップしていきます、そして計算をマルチデバイスに渡り分散していきます。

 

import sys
assert sys.version_info >= (3, 6), "Sonnet 2 requires Python >=3.6"
!pip install dm-sonnet tqdm
import sonnet as snt
import tensorflow as tf
import tensorflow_datasets as tfds
print("TensorFlow version: {}".format(tf.__version__))
print("    Sonnet version: {}".format(snt.__version__))

最後に利用可能な GPU を素早く見ましょう :

!grep Model: /proc/driver/nvidia/gpus/*/information | awk '{$1="";print$0}'

 

分散ストラテジー

幾つかのデバイスに渡り計算を分散するためのストラテジーが必要です。Google Colab は単一 GPU を提供するだけですのでそれを 4 つの仮想 GPU に分割します :

physical_gpus = tf.config.experimental.list_physical_devices("GPU")
physical_gpus
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
tf.config.experimental.set_virtual_device_configuration(
    physical_gpus[0],
    [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2000)] * 4
)
gpus = tf.config.experimental.list_logical_devices("GPU")
gpus
[LogicalDevice(name='/job:localhost/replica:0/task:0/device:GPU:0', device_type='GPU'),
 LogicalDevice(name='/job:localhost/replica:0/task:0/device:GPU:1', device_type='GPU'),
 LogicalDevice(name='/job:localhost/replica:0/task:0/device:GPU:2', device_type='GPU'),
 LogicalDevice(name='/job:localhost/replica:0/task:0/device:GPU:3', device_type='GPU')]

Sonnet optimizer を使用するとき、snt.distribute からの Replicator か TpuReplicator を利用しなければあんりません、あるいは tf.distribute.OneDeviceStrategy を利用できます。Replicator は MirroredStrategy と等値でそして TpuReplicator は TPUStrategy と等値です。

strategy = snt.distribute.Replicator(
    ["/device:GPU:{}".format(i) for i in range(4)],
    tf.distribute.ReductionToOneDevice("GPU:0"))

 

データセット

基本的には MNIST サンプルと同じですが、今回は CIFAR-10 を使用しています。CIFAR-10 は 10 の異なるクラス (飛行機、自動車、鳥、猫、鹿、犬、蛙、馬、船そしてトラック) にある 32×32 ピクセルカラー画像を含みます。

# NOTE: This is the batch size across all GPUs.
batch_size = 100 * 4

def process_batch(images, labels):
  images = tf.cast(images, dtype=tf.float32)
  images = ((images / 255.) - .5) * 2.
  return images, labels

def cifar10(split):
  dataset = tfds.load("cifar10", split=split, as_supervised=True)
  dataset = dataset.map(process_batch)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  dataset = dataset.cache()
  return dataset

cifar10_train = cifar10("train").shuffle(10)
cifar10_test = cifar10("test")

 

モデル & Optimizer

都合良く、snt.nets にこのデータセットのために特に設計された事前ビルドされたモデルがあります。

作成された任意の変数が正しく分散されることを確実にするために、モデルと optimizer は strategy スコープ内で構築しなければなりません。代わりに、tf.distribute.experimental_set_strategy を使用してプログラム全体のためのスコープに入ることもでできるでしょう。

learning_rate = 0.1

with strategy.scope():
  model = snt.nets.Cifar10ConvNet()
  optimizer = snt.optimizers.Momentum(learning_rate, 0.9)

 

モデルを訓練する

Sonnet optimizer はできる限り綺麗でそして単純であるように設計されています。それらは分散実行を扱うためのどのようなコードも含みません。従ってそれはコードの 2, 3 の追加行を必要とします。

異なるデバイス上で計算された勾配を集めなければなりません。これは ReplicaContext.all_reduce を使用して成されます。

Replicator / TpuReplicator を使用するとき values が総てのレプリカで同一で在り続けることを確かなものにすることはユーザの責任であることに注意してください。

def step(images, labels):
  """Performs a single training step, returning the cross-entropy loss."""
  with tf.GradientTape() as tape:
    logits = model(images, is_training=True)["logits"]
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                       logits=logits))

  grads = tape.gradient(loss, model.trainable_variables)

  # Aggregate the gradients from the full batch.
  replica_ctx = tf.distribute.get_replica_context()
  grads = replica_ctx.all_reduce("mean", grads)

  optimizer.apply(grads, model.trainable_variables)
  return loss

@tf.function
def train_step(images, labels):
  per_replica_loss = strategy.run(step, args=(images, labels))
  return strategy.reduce("sum", per_replica_loss, axis=None)

def train_epoch(dataset):
  """Performs one epoch of training, returning the mean cross-entropy loss."""
  total_loss = 0.0
  num_batches = 0

  # Loop over the entire training set.
  for images, labels in dataset:
    total_loss += train_step(images, labels).numpy()
    num_batches += 1

  return total_loss / num_batches

cifar10_train_dist = strategy.experimental_distribute_dataset(cifar10_train)

for epoch in range(20):
  print("Training epoch", epoch, "...", end=" ")
  print("loss :=", train_epoch(cifar10_train_dist))

 

モデルを評価する

バッチ次元に渡り削減するために strategy.reduce による axis パラメータの使用方法に注意してください。

num_cifar10_test_examples = 10000

def is_predicted(images, labels):
  logits = model(images, is_training=False)["logits"]
  # The reduction over the batch happens in `strategy.reduce`, below.
  return tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.int32)

cifar10_test_dist = strategy.experimental_distribute_dataset(cifar10_test)

@tf.function
def evaluate():
  """Returns the top-1 accuracy over the entire test set."""
  total_correct = 0

  for images, labels in cifar10_test_dist:
    per_replica_correct = strategy.run(is_predicted, args=(images, labels))
    total_correct += strategy.reduce("sum", per_replica_correct, axis=0)

  return tf.cast(total_correct, tf.float32) / num_cifar10_test_examples

print("Testing...", end=" ")
print("top-1 accuracy =", evaluate().numpy())
 

以上






Sonnet 2.0 : Tutorials : MLP で MNIST を予測する

Sonnet 2.0 : Tutorials : MLP で MNIST を予測する (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 06/10/2020

* 本ページは、Sonnet の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

MLP で MNIST を予測する

import sys
assert sys.version_info >= (3, 6), "Sonnet 2 requires Python >=3.6"
!pip install dm-sonnet tqdm
Requirement already satisfied: dm-sonnet==2.0.0b0 in /usr/local/lib/python3.6/dist-packages (2.0.0b0)
Requirement already satisfied: gast==0.2.2 in /usr/local/lib/python3.6/dist-packages (0.2.2)
Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (4.28.1)
Requirement already satisfied: wrapt>=1.11.1 in /tensorflow-2.0.0-rc1/python3.6 (from dm-sonnet==2.0.0b0) (1.11.2)
Requirement already satisfied: numpy>=1.16.3 in /tensorflow-2.0.0-rc1/python3.6 (from dm-sonnet==2.0.0b0) (1.17.2)
Requirement already satisfied: absl-py>=0.7.1 in /tensorflow-2.0.0-rc1/python3.6 (from dm-sonnet==2.0.0b0) (0.8.0)
Requirement already satisfied: six>=1.12.0 in /tensorflow-2.0.0-rc1/python3.6 (from dm-sonnet==2.0.0b0) (1.12.0)
import sonnet as snt
import tensorflow as tf
import tensorflow_datasets as tfds
print("TensorFlow version: {}".format(tf.__version__))
print("    Sonnet version: {}".format(snt.__version__))
TensorFlow version: 2.0.0-rc1
    Sonnet version: 2.0.0b0

最後に利用可能な GPU を素早く見ましょう :

!grep Model: /proc/driver/nvidia/gpus/*/information | awk '{$1="";print$0}'
 Tesla K80

 

データセット

データセットを容易に反復できる状態で得る必要があります。TensorFlow Datasets パッケージはこのために単純な API を提供します。それはデータセットをダウンロードして私達のために GPU 上で速やかに処理できるように準備します。モデルがそれを見る前にデータセットを変更する私達自身の前処理関数を追加することもできます :

batch_size = 100

def process_batch(images, labels):
  images = tf.squeeze(images, axis=[-1])
  images = tf.cast(images, dtype=tf.float32)
  images = ((images / 255.) - .5) * 2.
  return images, labels

def mnist(split):
  dataset = tfds.load("mnist", split=split, as_supervised=True)
  dataset = dataset.map(process_batch)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  dataset = dataset.cache()
  return dataset

mnist_train = mnist("train").shuffle(10)
mnist_test = mnist("test")

MNIST は 28×28 グレースケール手書き数字を含みます。一つ見てみましょう :

import matplotlib.pyplot as plt

images, _ = next(iter(mnist_test))
plt.imshow(images[0]);

 

Sonnet

次のステップはモデルを定義することです。Sonnet では TensorFlow variables (tf.Variable) を含む総ては snt.Module を拡張しています、これは低位ニューラルネットワーク・コンポーネント (e.g. snt.Linear, snt.Conv2D)、サブコンポーネントを含むより大きなネット (e.g. snt.nets.MLP)、optimizer (e.g. snt.optimizers.Adam) そして貴方が考えられるものは何でも含みます。

モジュールはパラメータ (そして BatchNorm の移動平均をストアするためのような他の目的のために使用される Variable) をストアするための単純な抽象を提供します。

与えられたモジュールのための総てのパラメータを見つけるには、単純に: module.variables を行ないます。これはこのモジュール、あるいはそれが参照する任意のモジュールのために存在する総てのパラメータのタプルを返します :

 

モデルを構築する

Sonnet では snt.Modules からニューラルネットワークを構築します。この場合多層パーセプトロンを、入力を ReLU 非線形を伴う幾つかの完全結合層を通してロジットを計算する __call__ メソッドを持つ新しいクラスとして構築します。

class MLP(snt.Module):

  def __init__(self):
    super(MLP, self).__init__()
    self.flatten = snt.Flatten()
    self.hidden1 = snt.Linear(1024, name="hidden1")
    self.hidden2 = snt.Linear(1024, name="hidden2")
    self.logits = snt.Linear(10, name="logits")

  def __call__(self, images):
    output = self.flatten(images)
    output = tf.nn.relu(self.hidden1(output))
    output = tf.nn.relu(self.hidden2(output))
    output = self.logits(output)
    return output

今はクラスのインスタンスを作成します、その重みはランダムに初期化されます。この MLP をそれが MNIST データセットの数字を認識することを学習するように訓練します。

mlp = MLP()
mlp
MLP()

 

モデルを使用する

サンプル入力をモデルに供給してそれが何を予測するかを見ましょう。モデルはランダムに初期化されましたので、それが正しいクラスを予測する 1/10 のチャンスがあります!

images, labels = next(iter(mnist_test))
logits = mlp(images)
  
prediction = tf.argmax(logits[0]).numpy()
actual = labels[0].numpy()
print("Predicted class: {} actual class: {}".format(prediction, actual))
plt.imshow(images[0]);
Predicted class: 0 actual class: 6

 

モデルを訓練する

モデルを訓練するためには optimizer が必要です。この単純なサンプルのために SGD optimizer で実装されている確率的勾配降下を使用します。勾配を計算するために tf.GradientTape を使用します、これは通して逆伝播することを望む計算のためだけに勾配を選択的に記録することを可能にします :

#@title Utility function to show progress bar.
from tqdm import tqdm

# MNIST training set has 60k images.
num_images = 60000

def progress_bar(generator):
  return tqdm(
      generator,
      unit='images',
      unit_scale=batch_size,
      total=(num_images // batch_size) * num_epochs)
opt = snt.optimizers.SGD(learning_rate=0.1)

num_epochs = 10

def step(images, labels):
  """Performs one optimizer step on a single mini-batch."""
  with tf.GradientTape() as tape:
    logits = mlp(images)
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                          labels=labels)
    loss = tf.reduce_mean(loss)

  params = mlp.trainable_variables
  grads = tape.gradient(loss, params)
  opt.apply(grads, params)
  return loss

for images, labels in progress_bar(mnist_train.repeat(num_epochs)):
  loss = step(images, labels)

print("\n\nFinal loss: {}".format(loss.numpy()))
100%|██████████| 600000/600000 [01:02<00:00, 9660.48images/s] 

Final loss: 0.039747316390275955

 

モデルを評価する

モデルがこのデータセットに対してどのくらい上手くやるのかの感触を得るためにモデルの非常に単純な解析を行ないます :

total = 0
correct = 0
for images, labels in mnist_test:
  predictions = tf.argmax(mlp(images), axis=1)
  correct += tf.math.count_nonzero(tf.equal(predictions, labels))
  total += images.shape[0]

print("Got %d/%d (%.02f%%) correct" % (correct, total, correct / total * 100.))
Got 9767/10000 (97.67%) correct

結果を少しだけより良く理解するために、モデルがどこで数字を正しく識別したかの小さいサンプルを見ましょう :

#@title Utility function to show a sample of images.
def sample(correct, rows, cols):
  n = 0

  f, ax = plt.subplots(rows, cols)
  if rows > 1:    
    ax = tf.nest.flatten([tuple(ax[i]) for i in range(rows)])
  f.set_figwidth(14)
  f.set_figheight(4 * rows)


  for images, labels in mnist_test:
    predictions = tf.argmax(mlp(images), axis=1)
    eq = tf.equal(predictions, labels)
    for i, x in enumerate(eq):
      if x.numpy() == correct:
        label = labels[i]
        prediction = predictions[i]
        image = images[i]

        ax[n].imshow(image)
        ax[n].set_title("Prediction:{}\nActual:{}".format(prediction, label))

        n += 1
        if n == (rows * cols):
          break

    if n == (rows * cols):
      break
sample(correct=True, rows=1, cols=5)

今はそれがどこで入力を間違って分類するかを見ましょう。MNIST は幾つか寧ろ疑わしい手書きを持ちます。下のサンプルの幾つかは少し曖昧であることに貴方は同意することでしょう :

sample(correct=False, rows=2, cols=5)

 

以上






Sonnet 2.0 : イントロダクション, Getting Started & 直列化

Sonnet 2.0 : イントロダクション, Getting Started & 直列化 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 06/10/2020

* 本ページは、Sonnet の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

イントロダクション

Sonnet は機械学習研究のための単純で、構成可能な抽象を提供するために設計された、TensorFlow 2 上に構築されたライブラリです。

Sonnet は DeepMind の研究者により設計されて構築されました。それは多くの様々な目的のためにニューラルネットワークを構築するために利用可能です (教師なし/あり学習、強化学習, …)。それが私達の組織のためには成功的な抽象であると見出しています、you might too !

より具体的には、Sonnet は単一の概念: snt.Module を中心とする単純でしかしパワフルなプログラミングモデルを提供します。Module はパラメータ、他のモジュールそしてユーザ入力にある関数を適用するメソッドへの参照を保持できます。Sonnet は多くの事前定義されたモジュール (e.g. snt.Linear, snt.Conv2D, snt.BatchNorm) と幾つかの事前定義されたネットワーク・モジュール (e.g. snt.nets.MLP) とともに出荷されますが、ユーザはまた自身のモジュールを構築することも奨励されます。

多くのフレームワークとは違い、Sonnet は貴方のモジュールをどのように使用するかについて非常に意固地ではありません。Module は自己充足的で他の一つから完全に分離されるように設計されています。Sonnet は訓練フレームワークとともに出荷されませんので、ユーザは自身のものを構築するか他の人により構築されたものを採用することが推奨されます。

Sonnet はまた理解するために単純であるようにも設計されていて、私達のコードは (願わくば!) 明瞭で焦点が合っています。デフォルト (e.g. 初期パラメータ値のためのデフォルト) を選択したとこでは何故かを指摘することを試みます。

 

Getting Started

サンプル

Sonnet を試す最も容易な方法は Google Colab を利用することです、これは GPU か TPU に装着された free な Python ノートブックを供給します。

 

インストール

始めるには TensorFlow 2.0 と Sonnet 2 をインストールします :

To get started install TensorFlow 2.0 and Sonnet 2:

$ pip install tensorflow-gpu tensorflow-probability
$ pip install dm-sonnet

正しくインストールされたかを検証するに以下を実行できます :

import tensorflow as tf
import sonnet as snt

print("TensorFlow version {}".format(tf.__version__))
print("Sonnet version {}".format(snt.__version__))

 

既存のモジュールを使用する

Sonnet は自明に利用可能な幾つかの組込みモジュールとともに出荷されます。例えば MLP を定義するためにモジュールのシークエンスを呼び出すために snt.Sequential モジュールを利用できて、与えられたモジュールの出力を次のモジュールの入力として渡せます。計算を実際に定義するために snt.Linear と tf.nn.relu を利用できます :

mlp = snt.Sequential([
    snt.Linear(1024),
    tf.nn.relu,
    snt.Linear(10),
])

モジュールを使用するにはそれを「呼び出す」必要があります。Sequential モジュール (そして殆どのモジュール) は __call__ メソッドを定義します、これはそれらを名前で呼び出せることを意味します :

logits = mlp(tf.random.normal([batch_size, input_size]))

貴方のモジュールのためのパラメータ総てをリクエストすることも非常に一般的です。Sonnet の殆どのモジュールはそれらのパラメータをある入力で最初に呼び出されるとき作成します (何故ならば殆どの場合パラメータの shape は入力の関数であるからです)。Sonnet モジュールはパラメータにアクセスするために 2 つのプロパティを提供します。

variables プロパティは与えられたモジュールにより参照される 総ての tf.Variables を返します :

all_variables = mlp.variables

注目すべき点は tf.Variables は単に貴方のモデルのパラメータのために使用されるだけではないことです。例えばそれらは snt.BatchNorm で使用されるメトリクスの状態を保持するために使用されます。殆どの場合ユーザはモジュール変数を取得してそれらを更新されるために optimizer に渡します。この場合非訓練可能変数は典型的にはそのリストにあるべきではありません、何故ならばそれらは異なるメカニズムを通して更新されるからです。TensorFlow は変数を「訓練可能」 (モデルのパラメータ) vs. 非訓練可能 (他の変数) として印をつける組込みメカニズムを持ちます。Sonnet はモジュールから総ての訓練可能な変数を集めるメカニズムを提供します、これは多分貴方が optimizer に渡すことを望むものです :

model_parameters = mlp.trainable_variables

 

貴方自身のモジュールを構築する

Sonnet はユーザに自身のモジュールを定義するために snt.Module をサブクラス化することを強く奨励します。MyLinear と呼ばれる単純な線形層を作成することから始めましょう :

class MyLinear(snt.Module):

  def __init__(self, output_size, name=None):
    super(MyLinear, self).__init__(name=name)
    self.output_size = output_size

  @snt.once
  def _initialize(self, x):
    initial_w = tf.random.normal([x.shape[1], self.output_size])
    self.w = tf.Variable(initial_w, name="w")
    self.b = tf.Variable(tf.zeros([self.output_size]), name="b")

  def __call__(self, x):
    self._initialize(x)
    return tf.matmul(x, self.w) + self.b

このモジュールの使用は自明です :

mod = MyLinear(32)
mod(tf.ones([batch_size, input_size]))

snt.Module をサブクラス化することにより多くの素晴らしいプロパティをただで得ます。例えば __repr__ のデフォルト実装です、これはコンストラクタ引数を示します (デバッグと内省のために非常に有用です) :

>>> print(repr(mod))
MyLinear(output_size=10)

variables と trainable_variables プロパティもまた得ます :

>>> mod.variables
(<tf.Variable 'my_linear/b:0' shape=(10,) ...)>,
 <tf.Variable 'my_linear/w:0' shape=(1, 10) ...)>)

上の variables 上の my_linear prefix に気付くかもしれません。これは Sonnet モジュールもメソッドが呼び出されるときはいつでもモジュール名前空間に入るためです。モジュール名前空間に入ることにより消費する TensorBoard のようなツールのための遥かにより有用なグラフを提供します (e.g. my_linear 内で発生する総ての演算は my_linear と呼ばれるグループにあります)。更にモジュールは今では TensorFlow チェックポイントと saved モデルをサポートします、これは後でカバーされる進んだ特徴です。

 

シリアライゼーション

Sonnet は複数のシリアライゼーション形式をサポートします。サポートする最も単純な形式は Python の pickle です、そして総ての組込みモジュールは同じ Python プロセスで pickle を通してセーブ/ロードできることを確実にするためにテストされています。一般には pickle の利用は推奨されません、それは TensorFlow の多くのパートで上手くサポートされません、そして経験的に非常に不安定である可能性があります。

 

TensorFlow チェックポイント

参照: https://www.tensorflow.org/guide/checkpoint

訓練の間に定期的にパラメータ値をセーブするために TensorFlow チェックポイントが利用できます。これは、貴方のプログラムがクラッシュするか停止する場合に訓練の進捗をセーブするために有用であり得ます。Sonnet は TensorFlow チェックポイントとともにきれいに動作するように設計されています :

checkpoint_root = "/tmp/checkpoints"
checkpoint_name = "example"
save_prefix = os.path.join(checkpoint_root, checkpoint_name)

my_module = create_my_sonnet_module()  # Can be anything extending snt.Module.

# A `Checkpoint` object manages checkpointing of the TensorFlow state associated
# with the objects passed to it's constructor. Note that Checkpoint supports
# restore on create, meaning that the variables of `my_module` do **not** need
# to be created before you restore from a checkpoint (their value will be
# restored when they are created).
checkpoint = tf.train.Checkpoint(module=my_module)

# Most training scripts will want to restore from a checkpoint if one exists. This
# would be the case if you interrupted your training (e.g. to use your GPU for
# something else, or in a cloud environment if your instance is preempted).
latest = tf.train.latest_checkpoint(checkpoint_root)
if latest is not None:
  checkpoint.restore(latest)

for step_num in range(num_steps):
  train(my_module)

  # During training we will occasionally save the values of weights. Note that
  # this is a blocking call and can be slow (typically we are writing to the
  # slowest storage on the machine). If you have a more reliable setup it might be
  # appropriate to save less frequently.
  if step_num and not step_num % 1000:
    checkpoint.save(save_prefix)

# Make sure to save your final values!!
checkpoint.save(save_prefix)

 

TensorFlow Saved モデル

参照: https://www.tensorflow.org/guide/saved_model

TensorFlow saved モデルはネットワークのコピーをセーブするために利用できます、これはそのための Python ソースから切り離されています。これは計算を記述する TensorFlow グラフと重みの値を含むチェックポイントをセーブすることにより可能になります。saved モデルを作成するために行なう最初のことはセーブすることを望む snt.Module を作成することです :

my_module = snt.nets.MLP([1024, 1024, 10])
my_module(tf.ones([1, input_size]))

次に、エクスポートすることを望むモデルの特定のパーツを記述するもう一つのモジュールを作成する必要があります。(元のモデルを in-place で変更するよりも) これを行なうことを勧めます、そうすれば実際にエクスポートされるものに渡る極め細かい制御を持ちます。これは非常に大きい saved モデルを作成することを回避するために典型的には重要です、そしてそのようなものとして貴方が望むモデルのパーツを共有するだけです (e.g. GAN のために generator を共有することを望むだけで discriminator は private に保持します)。

@tf.function(input_signature=[tf.TensorSpec([None, input_size])])
def inference(x):
  return my_module(x)

to_save = snt.Module()
to_save.inference = inference
to_save.all_variables = list(my_module.variables)
tf.saved_model.save(to_save, "/tmp/example_saved_model")

今は /tmp/example_saved_model フォルダで saved モデルを持ちます :

$ ls -lh /tmp/example_saved_model
total 24K
drwxrwsr-t 2 tomhennigan 154432098 4.0K Apr 28 00:14 assets
-rw-rw-r-- 1 tomhennigan 154432098  14K Apr 28 00:15 saved_model.pb
drwxrwsr-t 2 tomhennigan 154432098 4.0K Apr 28 00:15 variables

このモデルのロードは単純で、saved モデルを構築したどのような Python コードもなしに異なるマシン上で成されます :

loaded = tf.saved_model.load("/tmp/example_saved_model")

# Use the inference method. Note this doesn't run the Python code from `to_save`
# but instead uses the TensorFlow Graph that is part of the saved model.
loaded.inference(tf.ones([1, input_size]))

# The all_variables property can be used to retrieve the restored variables.
assert len(loaded.all_variables) > 0

ロードされたオブジェクトは Sonnet モジュールではないことに注意してください、それは前のブロックで追加した特定のメソッド (e.g. inference) とプロパティ (e.g. all_variables) を持つコンテナ・オブジェクトです。

 

分散訓練

サンプル: https://github.com/deepmind/sonnet/blob/v2/examples/distributed_cifar10.ipynb

Sonnet は カスタム TensorFlow 分散ストラテジー を使用して分散訓練のためのサポートを持ちます。

Sonnet と tf.keras を使用する分散訓練の間の主要な違いは Sonnet モジュールと optimizer は分散ストラテジーで動作するとき異なる動作をしないことです (e.g. 勾配を平均したりバッチ norm スタッツを同期しません)。ユーザは訓練のこれらの様相の完全な制御にあるべきでライブラリに焼き固められるべきではないと信じます。ここでのトレードオフはこれらの特徴を貴方の訓練スクリプトで実装するか (optimizer を適用する前に勾配を総て減じるために典型的にはこれは単に 2 行のコードです) 明示的に分散 aware なモジュール (e.g. snt.distribute.CrossReplicaBatchNorm) と交換する必要があります。

分散 Cifar-10 サンプルは Sonnet でマルチ GPU 訓練を行なうことをウォークスルーします。

 

以上






AI導入支援 #2 ウェビナー

スモールスタートを可能としたAI導入支援   Vol.2
[無料 WEB セミナー] [詳細]
「画像認識 AI PoC スターターパック」の紹介
既に AI 技術を実ビジネスで活用し、成果を上げている日本企業も多く存在しており、競争優位なビジネスを展開しております。
しかしながら AI を導入したくとも PoC (概念実証) だけでも高額な費用がかかり取組めていない企業も少なくないようです。A I導入時には欠かせない PoC を手軽にしかも短期間で認知度を確認可能とするサービの紹介と共に、AI 技術の特性と具体的な導入プロセスに加え運用時のポイントについても解説いたします。
日時:2021年10月13日(水)
会場:WEBセミナー
共催:クラスキャット、日本FLOW(株)
後援:働き方改革推進コンソーシアム
参加費: 無料 (事前登録制)
人工知能開発支援
◆ クラスキャットは 人工知能研究開発支援 サービスを提供しています :
  • テクニカルコンサルティングサービス
  • 実証実験 (プロトタイプ構築)
  • アプリケーションへの実装
  • 人工知能研修サービス
◆ お問合せ先 ◆
(株)クラスキャット
セールス・インフォメーション
E-Mail:sales-info@classcat.com