ホーム » Sonnet

Sonnet」カテゴリーアーカイブ

Sonnet 2.0 : Tutorials : VQ-VAE 訓練サンプル

Sonnet 2.0 : Tutorials : VQ-VAE 訓練サンプル (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 06/10/2020

* 本ページは、Sonnet の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

VQ-VAE 訓練サンプル

TF 2 / Sonnet 2 を使用して、https://arxiv.org/abs/1711.00937 で指定されるモデルをどのように訓練するかの実演です。

Mac と Linux 上、単純に各セルを順番に実行してください。

!pip install dm-sonnet dm-tree
Requirement already satisfied: dm-sonnet in /tmp/sonnet-nb-env/lib/python3.7/site-packages (2.0.0)
Requirement already satisfied: dm-tree in /tmp/sonnet-nb-env/lib/python3.7/site-packages (0.1.5)
Requirement already satisfied: six>=1.12.0 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (1.14.0)
Requirement already satisfied: tabulate>=0.7.5 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (0.8.7)
Requirement already satisfied: absl-py>=0.7.1 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (0.9.0)
Requirement already satisfied: numpy>=1.16.3 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (1.18.3)
Requirement already satisfied: wrapt>=1.11.1 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (1.12.1)
import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import tree

try:
  import sonnet.v2 as snt
  tf.enable_v2_behavior()
except ImportError:
  import sonnet as snt

print("TensorFlow version {}".format(tf.__version__))
print("Sonnet version {}".format(snt.__version__))
TensorFlow version 2.1.0
Sonnet version 2.0.0

 

Cifar10 データをダウンロードする

これはインターネットへの接続を必要として ~160MB をダウンロードします。

cifar10 = tfds.as_numpy(tfds.load("cifar10:3.0.2", split="train+test", batch_size=-1))
cifar10.pop("id", None)
cifar10.pop("label")
tree.map_structure(lambda x: f'{x.dtype.name}{list(x.shape)}', cifar10)
{'image': 'uint8[60000, 32, 32, 3]'}

 

データを Numpy にロードする

下で平均二乗誤差を正規化するために訓練セット全体の分散を計算します。

train_data_dict = tree.map_structure(lambda x: x[:40000], cifar10)
valid_data_dict = tree.map_structure(lambda x: x[40000:50000], cifar10)
test_data_dict = tree.map_structure(lambda x: x[50000:], cifar10)
def cast_and_normalise_images(data_dict):
  """Convert images to floating point with the range [-0.5, 0.5]"""
  images = data_dict['image']
  data_dict['image'] = (tf.cast(images, tf.float32) / 255.0) - 0.5
  return data_dict

train_data_variance = np.var(train_data_dict['image'] / 255.0)
print('train data variance: %s' % train_data_variance)
train data variance: 0.06327039811675479

 

エンコーダ & デコーダ・アーキテクチャ

  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super(ResidualStack, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    self._layers = []
    for i in range(num_residual_layers):
      conv3 = snt.Conv2D(
          output_channels=num_residual_hiddens,
          kernel_shape=(3, 3),
          stride=(1, 1),
          name="res3x3_%d" % i)
      conv1 = snt.Conv2D(
          output_channels=num_hiddens,
          kernel_shape=(1, 1),
          stride=(1, 1),
          name="res1x1_%d" % i)
      self._layers.append((conv3, conv1))

  def __call__(self, inputs):
    h = inputs
    for conv3, conv1 in self._layers:
      conv3_out = conv3(tf.nn.relu(h))
      conv1_out = conv1(tf.nn.relu(conv3_out))
      h += conv1_out
    return tf.nn.relu(h)  # Resnet V1 style


class Encoder(snt.Module):
  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super(Encoder, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    self._enc_1 = snt.Conv2D(
        output_channels=self._num_hiddens // 2,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="enc_1")
    self._enc_2 = snt.Conv2D(
        output_channels=self._num_hiddens,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="enc_2")
    self._enc_3 = snt.Conv2D(
        output_channels=self._num_hiddens,
        kernel_shape=(3, 3),
        stride=(1, 1),
        name="enc_3")
    self._residual_stack = ResidualStack(
        self._num_hiddens,
        self._num_residual_layers,
        self._num_residual_hiddens)

  def __call__(self, x):
    h = tf.nn.relu(self._enc_1(x))
    h = tf.nn.relu(self._enc_2(h))
    h = tf.nn.relu(self._enc_3(h))
    return self._residual_stack(h)


class Decoder(snt.Module):
  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super(Decoder, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    self._dec_1 = snt.Conv2D(
        output_channels=self._num_hiddens,
        kernel_shape=(3, 3),
        stride=(1, 1),
        name="dec_1")
    self._residual_stack = ResidualStack(
        self._num_hiddens,
        self._num_residual_layers,
        self._num_residual_hiddens)
    self._dec_2 = snt.Conv2DTranspose(
        output_channels=self._num_hiddens // 2,
        output_shape=None,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="dec_2")
    self._dec_3 = snt.Conv2DTranspose(
        output_channels=3,
        output_shape=None,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="dec_3")
    
  def __call__(self, x):
    h = self._dec_1(x)
    h = self._residual_stack(h)
    h = tf.nn.relu(self._dec_2(h))
    x_recon = self._dec_3(h)
    return x_recon
    

class VQVAEModel(snt.Module):
  def __init__(self, encoder, decoder, vqvae, pre_vq_conv1, 
               data_variance, name=None):
    super(VQVAEModel, self).__init__(name=name)
    self._encoder = encoder
    self._decoder = decoder
    self._vqvae = vqvae
    self._pre_vq_conv1 = pre_vq_conv1
    self._data_variance = data_variance

  def __call__(self, inputs, is_training):
    z = self._pre_vq_conv1(self._encoder(inputs))
    vq_output = self._vqvae(z, is_training=is_training)
    x_recon = self._decoder(vq_output['quantize'])
    recon_error = tf.reduce_mean((x_recon - inputs) ** 2) / self._data_variance
    loss = recon_error + vq_output['loss']
    return {
        'z': z,
        'x_recon': x_recon,
        'loss': loss,
        'recon_error': recon_error,
        'vq_output': vq_output,
    }

 

モデルを構築して訓練する

%%time

# Set hyper-parameters.
batch_size = 32
image_size = 32

# 100k steps should take < 30 minutes on a modern (>= 2017) GPU.
# 10k steps gives reasonable accuracy with VQVAE on Cifar10.
num_training_updates = 10000

num_hiddens = 128
num_residual_hiddens = 32
num_residual_layers = 2
# These hyper-parameters define the size of the model (number of parameters and layers).
# The hyper-parameters in the paper were (For ImageNet):
# batch_size = 128
# image_size = 128
# num_hiddens = 128
# num_residual_hiddens = 32
# num_residual_layers = 2

# This value is not that important, usually 64 works.
# This will not change the capacity in the information-bottleneck.
embedding_dim = 64

# The higher this value, the higher the capacity in the information bottleneck.
num_embeddings = 512

# commitment_cost should be set appropriately. It's often useful to try a couple
# of values. It mostly depends on the scale of the reconstruction cost
# (log p(x|z)). So if the reconstruction cost is 100x higher, the
# commitment_cost should also be multiplied with the same amount.
commitment_cost = 0.25

# Use EMA updates for the codebook (instead of the Adam optimizer).
# This typically converges faster, and makes the model less dependent on choice
# of the optimizer. In the VQ-VAE paper EMA updates were not used (but was
# developed afterwards). See Appendix of the paper for more details.
vq_use_ema = True

# This is only used for EMA updates.
decay = 0.99

learning_rate = 3e-4


# # Data Loading.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_data_dict)
    .map(cast_and_normalise_images)
    .shuffle(10000)
    .repeat(-1)  # repeat indefinitely
    .batch(batch_size, drop_remainder=True)
    .prefetch(-1))

valid_dataset = (
    tf.data.Dataset.from_tensor_slices(valid_data_dict)
    .map(cast_and_normalise_images)
    .repeat(1)  # 1 epoch
    .batch(batch_size)
    .prefetch(-1))

# # Build modules.
encoder = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)
decoder = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)
pre_vq_conv1 = snt.Conv2D(output_channels=embedding_dim,
    kernel_shape=(1, 1),
    stride=(1, 1),
    name="to_vq")

if vq_use_ema:
  vq_vae = snt.nets.VectorQuantizerEMA(
      embedding_dim=embedding_dim,
      num_embeddings=num_embeddings,
      commitment_cost=commitment_cost,
      decay=decay)
else:
  vq_vae = snt.nets.VectorQuantizer(
      embedding_dim=embedding_dim,
      num_embeddings=num_embeddings,
      commitment_cost=commitment_cost)
  
model = VQVAEModel(encoder, decoder, vq_vae, pre_vq_conv1,
                   data_variance=train_data_variance)

optimizer = snt.optimizers.Adam(learning_rate=learning_rate)

@tf.function
def train_step(data):
  with tf.GradientTape() as tape:
    model_output = model(data['image'], is_training=True)
  trainable_variables = model.trainable_variables
  grads = tape.gradient(model_output['loss'], trainable_variables)
  optimizer.apply(grads, trainable_variables)

  return model_output

train_losses = []
train_recon_errors = []
train_perplexities = []
train_vqvae_loss = []

for step_index, data in enumerate(train_dataset):
  train_results = train_step(data)
  train_losses.append(train_results['loss'])
  train_recon_errors.append(train_results['recon_error'])
  train_perplexities.append(train_results['vq_output']['perplexity'])
  train_vqvae_loss.append(train_results['vq_output']['loss'])

  if (step_index + 1) % 100 == 0:
    print('%d train loss: %f ' % (step_index + 1,
                                   np.mean(train_losses[-100:])) +
          ('recon_error: %.3f ' % np.mean(train_recon_errors[-100:])) +
          ('perplexity: %.3f ' % np.mean(train_perplexities[-100:])) +
          ('vqvae loss: %.3f' % np.mean(train_vqvae_loss[-100:])))
  if step_index == num_training_updates:
    break
WARNING:tensorflow:AutoGraph could not transform <function train_step at 0x7f1016cb5f80> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: Unable to locate the source code of <function train_step at 0x7f1016cb5f80>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code
WARNING:tensorflow:AutoGraph could not transform <function train_step at 0x7f1016cb5f80> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: Unable to locate the source code of <function train_step at 0x7f1016cb5f80>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code
WARNING: AutoGraph could not transform <function train_step at 0x7f1016cb5f80> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: Unable to locate the source code of <function train_step at 0x7f1016cb5f80>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code
WARNING:tensorflow:From /tmp/sonnet-nb-env/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1786: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /tmp/sonnet-nb-env/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1786: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
100 train loss: 0.523625 recon_error: 0.483 perplexity: 10.356 vqvae loss: 0.041
200 train loss: 0.248232 recon_error: 0.223 perplexity: 18.294 vqvae loss: 0.026
300 train loss: 0.215068 recon_error: 0.190 perplexity: 23.106 vqvae loss: 0.025
400 train loss: 0.191891 recon_error: 0.164 perplexity: 29.139 vqvae loss: 0.028
500 train loss: 0.180945 recon_error: 0.147 perplexity: 34.253 vqvae loss: 0.033
600 train loss: 0.167115 recon_error: 0.134 perplexity: 39.961 vqvae loss: 0.033
700 train loss: 0.157724 recon_error: 0.124 perplexity: 46.521 vqvae loss: 0.033
800 train loss: 0.153761 recon_error: 0.119 perplexity: 53.559 vqvae loss: 0.035
900 train loss: 0.145033 recon_error: 0.112 perplexity: 62.442 vqvae loss: 0.033
1000 train loss: 0.137589 recon_error: 0.105 perplexity: 71.831 vqvae loss: 0.033
1100 train loss: 0.133044 recon_error: 0.101 perplexity: 79.135 vqvae loss: 0.032
1200 train loss: 0.129990 recon_error: 0.098 perplexity: 87.959 vqvae loss: 0.032
1300 train loss: 0.126507 recon_error: 0.095 perplexity: 96.704 vqvae loss: 0.031
1400 train loss: 0.122403 recon_error: 0.092 perplexity: 104.202 vqvae loss: 0.031
1500 train loss: 0.122003 recon_error: 0.091 perplexity: 112.476 vqvae loss: 0.031
1600 train loss: 0.120192 recon_error: 0.089 perplexity: 122.269 vqvae loss: 0.032
1700 train loss: 0.117041 recon_error: 0.086 perplexity: 129.887 vqvae loss: 0.031
1800 train loss: 0.115004 recon_error: 0.083 perplexity: 138.603 vqvae loss: 0.032
1900 train loss: 0.114134 recon_error: 0.082 perplexity: 147.545 vqvae loss: 0.032
2000 train loss: 0.112840 recon_error: 0.081 perplexity: 153.993 vqvae loss: 0.032
2100 train loss: 0.108815 recon_error: 0.077 perplexity: 161.729 vqvae loss: 0.031
2200 train loss: 0.108596 recon_error: 0.078 perplexity: 171.971 vqvae loss: 0.031
2300 train loss: 0.108132 recon_error: 0.077 perplexity: 181.157 vqvae loss: 0.031
2400 train loss: 0.106273 recon_error: 0.076 perplexity: 186.200 vqvae loss: 0.031
2500 train loss: 0.105936 recon_error: 0.075 perplexity: 194.301 vqvae loss: 0.031
2600 train loss: 0.103880 recon_error: 0.073 perplexity: 201.674 vqvae loss: 0.030
2700 train loss: 0.101655 recon_error: 0.072 perplexity: 207.131 vqvae loss: 0.030
2800 train loss: 0.102564 recon_error: 0.072 perplexity: 216.983 vqvae loss: 0.030
2900 train loss: 0.101613 recon_error: 0.072 perplexity: 219.649 vqvae loss: 0.030
3000 train loss: 0.101227 recon_error: 0.071 perplexity: 226.789 vqvae loss: 0.030
3100 train loss: 0.100786 recon_error: 0.071 perplexity: 235.522 vqvae loss: 0.030
3200 train loss: 0.100130 recon_error: 0.070 perplexity: 243.282 vqvae loss: 0.030
3300 train loss: 0.097764 recon_error: 0.067 perplexity: 249.584 vqvae loss: 0.030
3400 train loss: 0.100630 recon_error: 0.069 perplexity: 260.551 vqvae loss: 0.031
3500 train loss: 0.099929 recon_error: 0.068 perplexity: 266.012 vqvae loss: 0.032
3600 train loss: 0.099245 recon_error: 0.067 perplexity: 272.031 vqvae loss: 0.032
3700 train loss: 0.097812 recon_error: 0.066 perplexity: 279.691 vqvae loss: 0.032
3800 train loss: 0.097137 recon_error: 0.064 perplexity: 284.240 vqvae loss: 0.033
3900 train loss: 0.099217 recon_error: 0.066 perplexity: 293.507 vqvae loss: 0.034
4000 train loss: 0.098570 recon_error: 0.065 perplexity: 300.891 vqvae loss: 0.034
4100 train loss: 0.099238 recon_error: 0.065 perplexity: 306.762 vqvae loss: 0.034
4200 train loss: 0.098172 recon_error: 0.064 perplexity: 311.918 vqvae loss: 0.035
4300 train loss: 0.096449 recon_error: 0.063 perplexity: 316.246 vqvae loss: 0.034
4400 train loss: 0.096487 recon_error: 0.062 perplexity: 319.591 vqvae loss: 0.034
4500 train loss: 0.096092 recon_error: 0.062 perplexity: 322.313 vqvae loss: 0.034
4600 train loss: 0.096474 recon_error: 0.062 perplexity: 324.620 vqvae loss: 0.035
4700 train loss: 0.097075 recon_error: 0.063 perplexity: 324.357 vqvae loss: 0.035
4800 train loss: 0.094709 recon_error: 0.060 perplexity: 326.024 vqvae loss: 0.034
4900 train loss: 0.096557 recon_error: 0.061 perplexity: 327.701 vqvae loss: 0.035
5000 train loss: 0.096185 recon_error: 0.061 perplexity: 326.664 vqvae loss: 0.035
5100 train loss: 0.095646 recon_error: 0.060 perplexity: 327.617 vqvae loss: 0.035
5200 train loss: 0.094689 recon_error: 0.059 perplexity: 328.692 vqvae loss: 0.035
5300 train loss: 0.097047 recon_error: 0.061 perplexity: 327.988 vqvae loss: 0.036
5400 train loss: 0.096259 recon_error: 0.060 perplexity: 327.075 vqvae loss: 0.036
5500 train loss: 0.094588 recon_error: 0.059 perplexity: 327.083 vqvae loss: 0.036
5600 train loss: 0.095947 recon_error: 0.060 perplexity: 328.213 vqvae loss: 0.036
5700 train loss: 0.095466 recon_error: 0.059 perplexity: 329.375 vqvae loss: 0.036
5800 train loss: 0.094849 recon_error: 0.059 perplexity: 326.821 vqvae loss: 0.036
5900 train loss: 0.093799 recon_error: 0.058 perplexity: 328.409 vqvae loss: 0.036
6000 train loss: 0.095373 recon_error: 0.059 perplexity: 326.791 vqvae loss: 0.036
6100 train loss: 0.093989 recon_error: 0.059 perplexity: 325.959 vqvae loss: 0.035
6200 train loss: 0.095549 recon_error: 0.059 perplexity: 330.829 vqvae loss: 0.036
6300 train loss: 0.094730 recon_error: 0.058 perplexity: 330.906 vqvae loss: 0.036
6400 train loss: 0.095038 recon_error: 0.058 perplexity: 329.353 vqvae loss: 0.037
6500 train loss: 0.095891 recon_error: 0.059 perplexity: 330.197 vqvae loss: 0.037
6600 train loss: 0.094342 recon_error: 0.058 perplexity: 331.240 vqvae loss: 0.036
6700 train loss: 0.095096 recon_error: 0.058 perplexity: 330.618 vqvae loss: 0.037
6800 train loss: 0.095581 recon_error: 0.059 perplexity: 324.493 vqvae loss: 0.037
6900 train loss: 0.094467 recon_error: 0.058 perplexity: 328.868 vqvae loss: 0.037
7000 train loss: 0.092967 recon_error: 0.057 perplexity: 328.276 vqvae loss: 0.036
7100 train loss: 0.094339 recon_error: 0.058 perplexity: 327.318 vqvae loss: 0.037
7200 train loss: 0.095227 recon_error: 0.058 perplexity: 326.306 vqvae loss: 0.037
7300 train loss: 0.093832 recon_error: 0.057 perplexity: 328.262 vqvae loss: 0.037
7400 train loss: 0.093331 recon_error: 0.057 perplexity: 327.987 vqvae loss: 0.037
7500 train loss: 0.094718 recon_error: 0.058 perplexity: 328.948 vqvae loss: 0.037
7600 train loss: 0.094199 recon_error: 0.058 perplexity: 328.468 vqvae loss: 0.037
7700 train loss: 0.094603 recon_error: 0.058 perplexity: 327.501 vqvae loss: 0.037
7800 train loss: 0.092299 recon_error: 0.056 perplexity: 327.630 vqvae loss: 0.037
7900 train loss: 0.095228 recon_error: 0.058 perplexity: 329.946 vqvae loss: 0.037
8000 train loss: 0.094291 recon_error: 0.058 perplexity: 326.790 vqvae loss: 0.037
8100 train loss: 0.094481 recon_error: 0.057 perplexity: 328.667 vqvae loss: 0.037
8200 train loss: 0.093992 recon_error: 0.057 perplexity: 329.655 vqvae loss: 0.037
8300 train loss: 0.093976 recon_error: 0.057 perplexity: 323.950 vqvae loss: 0.037
8400 train loss: 0.093422 recon_error: 0.057 perplexity: 324.523 vqvae loss: 0.036
8500 train loss: 0.092898 recon_error: 0.056 perplexity: 325.402 vqvae loss: 0.037
8600 train loss: 0.094298 recon_error: 0.057 perplexity: 329.251 vqvae loss: 0.037
8700 train loss: 0.094489 recon_error: 0.057 perplexity: 331.027 vqvae loss: 0.037
8800 train loss: 0.093022 recon_error: 0.056 perplexity: 327.495 vqvae loss: 0.037
8900 train loss: 0.093427 recon_error: 0.057 perplexity: 328.008 vqvae loss: 0.037
9000 train loss: 0.094884 recon_error: 0.058 perplexity: 327.057 vqvae loss: 0.037
9100 train loss: 0.093559 recon_error: 0.056 perplexity: 331.800 vqvae loss: 0.037
9200 train loss: 0.093282 recon_error: 0.056 perplexity: 328.689 vqvae loss: 0.037
9300 train loss: 0.092217 recon_error: 0.056 perplexity: 323.903 vqvae loss: 0.036
9400 train loss: 0.093902 recon_error: 0.057 perplexity: 326.350 vqvae loss: 0.037
9500 train loss: 0.093772 recon_error: 0.057 perplexity: 325.627 vqvae loss: 0.037
9600 train loss: 0.093123 recon_error: 0.056 perplexity: 327.352 vqvae loss: 0.037
9700 train loss: 0.092934 recon_error: 0.056 perplexity: 328.674 vqvae loss: 0.037
9800 train loss: 0.093284 recon_error: 0.056 perplexity: 329.437 vqvae loss: 0.037
9900 train loss: 0.094147 recon_error: 0.057 perplexity: 330.146 vqvae loss: 0.037
10000 train loss: 0.092876 recon_error: 0.056 perplexity: 326.349 vqvae loss: 0.037
CPU times: user 1h 47min 46s, sys: 14min 12s, total: 2h 1min 59s
Wall time: 4min 29s

 

損失をプロットする

f = plt.figure(figsize=(16,8))
ax = f.add_subplot(1,2,1)
ax.plot(train_recon_errors)
ax.set_yscale('log')
ax.set_title('NMSE.')

ax = f.add_subplot(1,2,2)
ax.plot(train_perplexities)
ax.set_title('Average codebook usage (perplexity).')
Text(0.5, 1.0, 'Average codebook usage (perplexity).')

 

再構築を見る

# Reconstructions
train_batch = next(iter(train_dataset))
valid_batch = next(iter(valid_dataset))

# Put data through the model with is_training=False, so that in the case of 
# using EMA the codebook is not updated.
train_reconstructions = model(train_batch['image'],
                              is_training=False)['x_recon'].numpy()
valid_reconstructions = model(valid_batch['image'],
                              is_training=False)['x_recon'].numpy()


def convert_batch_to_image_grid(image_batch):
  reshaped = (image_batch.reshape(4, 8, 32, 32, 3)
              .transpose(0, 2, 1, 3, 4)
              .reshape(4 * 32, 8 * 32, 3))
  return reshaped + 0.5



f = plt.figure(figsize=(16,8))
ax = f.add_subplot(2,2,1)
ax.imshow(convert_batch_to_image_grid(train_batch['image'].numpy()),
          interpolation='nearest')
ax.set_title('training data originals')
plt.axis('off')

ax = f.add_subplot(2,2,2)
ax.imshow(convert_batch_to_image_grid(train_reconstructions),
          interpolation='nearest')
ax.set_title('training data reconstructions')
plt.axis('off')

ax = f.add_subplot(2,2,3)
ax.imshow(convert_batch_to_image_grid(valid_batch['image'].numpy()),
          interpolation='nearest')
ax.set_title('validation data originals')
plt.axis('off')

ax = f.add_subplot(2,2,4)
ax.imshow(convert_batch_to_image_grid(valid_reconstructions),
          interpolation='nearest')
ax.set_title('validation data reconstructions')
plt.axis('off')
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
(-0.5, 255.5, 127.5, -0.5)

 

以上






Sonnet 2.0 : Tutorials : snt.distribute で分散訓練 (CIFAR-10)

Sonnet 2.0 : Tutorials : snt.distribute で分散訓練 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 06/10/2020

* 本ページは、Sonnet の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

snt.distribute で分散訓練

イントロダクション

このチュートリアルは Sonnet 2 “Hello, world!” サンプル (MLP on MNIST) を既に完了していることを仮定しています。

このチュートリアルでは、より大きなモデルとより大きなデータセットで物事をスケールアップしていきます、そして計算をマルチデバイスに渡り分散していきます。

 

import sys
assert sys.version_info >= (3, 6), "Sonnet 2 requires Python >=3.6"
!pip install dm-sonnet tqdm
import sonnet as snt
import tensorflow as tf
import tensorflow_datasets as tfds
print("TensorFlow version: {}".format(tf.__version__))
print("    Sonnet version: {}".format(snt.__version__))

最後に利用可能な GPU を素早く見ましょう :

!grep Model: /proc/driver/nvidia/gpus/*/information | awk '{$1="";print$0}'

 

分散ストラテジー

幾つかのデバイスに渡り計算を分散するためのストラテジーが必要です。Google Colab は単一 GPU を提供するだけですのでそれを 4 つの仮想 GPU に分割します :

physical_gpus = tf.config.experimental.list_physical_devices("GPU")
physical_gpus
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
tf.config.experimental.set_virtual_device_configuration(
    physical_gpus[0],
    [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2000)] * 4
)
gpus = tf.config.experimental.list_logical_devices("GPU")
gpus
[LogicalDevice(name='/job:localhost/replica:0/task:0/device:GPU:0', device_type='GPU'),
 LogicalDevice(name='/job:localhost/replica:0/task:0/device:GPU:1', device_type='GPU'),
 LogicalDevice(name='/job:localhost/replica:0/task:0/device:GPU:2', device_type='GPU'),
 LogicalDevice(name='/job:localhost/replica:0/task:0/device:GPU:3', device_type='GPU')]

Sonnet optimizer を使用するとき、snt.distribute からの Replicator か TpuReplicator を利用しなければあんりません、あるいは tf.distribute.OneDeviceStrategy を利用できます。Replicator は MirroredStrategy と等値でそして TpuReplicator は TPUStrategy と等値です。

strategy = snt.distribute.Replicator(
    ["/device:GPU:{}".format(i) for i in range(4)],
    tf.distribute.ReductionToOneDevice("GPU:0"))

 

データセット

基本的には MNIST サンプルと同じですが、今回は CIFAR-10 を使用しています。CIFAR-10 は 10 の異なるクラス (飛行機、自動車、鳥、猫、鹿、犬、蛙、馬、船そしてトラック) にある 32×32 ピクセルカラー画像を含みます。

# NOTE: This is the batch size across all GPUs.
batch_size = 100 * 4

def process_batch(images, labels):
  images = tf.cast(images, dtype=tf.float32)
  images = ((images / 255.) - .5) * 2.
  return images, labels

def cifar10(split):
  dataset = tfds.load("cifar10", split=split, as_supervised=True)
  dataset = dataset.map(process_batch)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  dataset = dataset.cache()
  return dataset

cifar10_train = cifar10("train").shuffle(10)
cifar10_test = cifar10("test")

 

モデル & Optimizer

都合良く、snt.nets にこのデータセットのために特に設計された事前ビルドされたモデルがあります。

作成された任意の変数が正しく分散されることを確実にするために、モデルと optimizer は strategy スコープ内で構築しなければなりません。代わりに、tf.distribute.experimental_set_strategy を使用してプログラム全体のためのスコープに入ることもでできるでしょう。

learning_rate = 0.1

with strategy.scope():
  model = snt.nets.Cifar10ConvNet()
  optimizer = snt.optimizers.Momentum(learning_rate, 0.9)

 

モデルを訓練する

Sonnet optimizer はできる限り綺麗でそして単純であるように設計されています。それらは分散実行を扱うためのどのようなコードも含みません。従ってそれはコードの 2, 3 の追加行を必要とします。

異なるデバイス上で計算された勾配を集めなければなりません。これは ReplicaContext.all_reduce を使用して成されます。

Replicator / TpuReplicator を使用するとき values が総てのレプリカで同一で在り続けることを確かなものにすることはユーザの責任であることに注意してください。

def step(images, labels):
  """Performs a single training step, returning the cross-entropy loss."""
  with tf.GradientTape() as tape:
    logits = model(images, is_training=True)["logits"]
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                       logits=logits))

  grads = tape.gradient(loss, model.trainable_variables)

  # Aggregate the gradients from the full batch.
  replica_ctx = tf.distribute.get_replica_context()
  grads = replica_ctx.all_reduce("mean", grads)

  optimizer.apply(grads, model.trainable_variables)
  return loss

@tf.function
def train_step(images, labels):
  per_replica_loss = strategy.run(step, args=(images, labels))
  return strategy.reduce("sum", per_replica_loss, axis=None)

def train_epoch(dataset):
  """Performs one epoch of training, returning the mean cross-entropy loss."""
  total_loss = 0.0
  num_batches = 0

  # Loop over the entire training set.
  for images, labels in dataset:
    total_loss += train_step(images, labels).numpy()
    num_batches += 1

  return total_loss / num_batches

cifar10_train_dist = strategy.experimental_distribute_dataset(cifar10_train)

for epoch in range(20):
  print("Training epoch", epoch, "...", end=" ")
  print("loss :=", train_epoch(cifar10_train_dist))

 

モデルを評価する

バッチ次元に渡り削減するために strategy.reduce による axis パラメータの使用方法に注意してください。

num_cifar10_test_examples = 10000

def is_predicted(images, labels):
  logits = model(images, is_training=False)["logits"]
  # The reduction over the batch happens in `strategy.reduce`, below.
  return tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.int32)

cifar10_test_dist = strategy.experimental_distribute_dataset(cifar10_test)

@tf.function
def evaluate():
  """Returns the top-1 accuracy over the entire test set."""
  total_correct = 0

  for images, labels in cifar10_test_dist:
    per_replica_correct = strategy.run(is_predicted, args=(images, labels))
    total_correct += strategy.reduce("sum", per_replica_correct, axis=0)

  return tf.cast(total_correct, tf.float32) / num_cifar10_test_examples

print("Testing...", end=" ")
print("top-1 accuracy =", evaluate().numpy())
 

以上






Sonnet 2.0 : Tutorials : MLP で MNIST を予測する

Sonnet 2.0 : Tutorials : MLP で MNIST を予測する (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 06/10/2020

* 本ページは、Sonnet の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

MLP で MNIST を予測する

import sys
assert sys.version_info >= (3, 6), "Sonnet 2 requires Python >=3.6"
!pip install dm-sonnet tqdm
Requirement already satisfied: dm-sonnet==2.0.0b0 in /usr/local/lib/python3.6/dist-packages (2.0.0b0)
Requirement already satisfied: gast==0.2.2 in /usr/local/lib/python3.6/dist-packages (0.2.2)
Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (4.28.1)
Requirement already satisfied: wrapt>=1.11.1 in /tensorflow-2.0.0-rc1/python3.6 (from dm-sonnet==2.0.0b0) (1.11.2)
Requirement already satisfied: numpy>=1.16.3 in /tensorflow-2.0.0-rc1/python3.6 (from dm-sonnet==2.0.0b0) (1.17.2)
Requirement already satisfied: absl-py>=0.7.1 in /tensorflow-2.0.0-rc1/python3.6 (from dm-sonnet==2.0.0b0) (0.8.0)
Requirement already satisfied: six>=1.12.0 in /tensorflow-2.0.0-rc1/python3.6 (from dm-sonnet==2.0.0b0) (1.12.0)
import sonnet as snt
import tensorflow as tf
import tensorflow_datasets as tfds
print("TensorFlow version: {}".format(tf.__version__))
print("    Sonnet version: {}".format(snt.__version__))
TensorFlow version: 2.0.0-rc1
    Sonnet version: 2.0.0b0

最後に利用可能な GPU を素早く見ましょう :

!grep Model: /proc/driver/nvidia/gpus/*/information | awk '{$1="";print$0}'
 Tesla K80

 

データセット

データセットを容易に反復できる状態で得る必要があります。TensorFlow Datasets パッケージはこのために単純な API を提供します。それはデータセットをダウンロードして私達のために GPU 上で速やかに処理できるように準備します。モデルがそれを見る前にデータセットを変更する私達自身の前処理関数を追加することもできます :

batch_size = 100

def process_batch(images, labels):
  images = tf.squeeze(images, axis=[-1])
  images = tf.cast(images, dtype=tf.float32)
  images = ((images / 255.) - .5) * 2.
  return images, labels

def mnist(split):
  dataset = tfds.load("mnist", split=split, as_supervised=True)
  dataset = dataset.map(process_batch)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  dataset = dataset.cache()
  return dataset

mnist_train = mnist("train").shuffle(10)
mnist_test = mnist("test")

MNIST は 28×28 グレースケール手書き数字を含みます。一つ見てみましょう :

import matplotlib.pyplot as plt

images, _ = next(iter(mnist_test))
plt.imshow(images[0]);

 

Sonnet

次のステップはモデルを定義することです。Sonnet では TensorFlow variables (tf.Variable) を含む総ては snt.Module を拡張しています、これは低位ニューラルネットワーク・コンポーネント (e.g. snt.Linear, snt.Conv2D)、サブコンポーネントを含むより大きなネット (e.g. snt.nets.MLP)、optimizer (e.g. snt.optimizers.Adam) そして貴方が考えられるものは何でも含みます。

モジュールはパラメータ (そして BatchNorm の移動平均をストアするためのような他の目的のために使用される Variable) をストアするための単純な抽象を提供します。

与えられたモジュールのための総てのパラメータを見つけるには、単純に: module.variables を行ないます。これはこのモジュール、あるいはそれが参照する任意のモジュールのために存在する総てのパラメータのタプルを返します :

 

モデルを構築する

Sonnet では snt.Modules からニューラルネットワークを構築します。この場合多層パーセプトロンを、入力を ReLU 非線形を伴う幾つかの完全結合層を通してロジットを計算する __call__ メソッドを持つ新しいクラスとして構築します。

class MLP(snt.Module):

  def __init__(self):
    super(MLP, self).__init__()
    self.flatten = snt.Flatten()
    self.hidden1 = snt.Linear(1024, name="hidden1")
    self.hidden2 = snt.Linear(1024, name="hidden2")
    self.logits = snt.Linear(10, name="logits")

  def __call__(self, images):
    output = self.flatten(images)
    output = tf.nn.relu(self.hidden1(output))
    output = tf.nn.relu(self.hidden2(output))
    output = self.logits(output)
    return output

今はクラスのインスタンスを作成します、その重みはランダムに初期化されます。この MLP をそれが MNIST データセットの数字を認識することを学習するように訓練します。

mlp = MLP()
mlp
MLP()

 

モデルを使用する

サンプル入力をモデルに供給してそれが何を予測するかを見ましょう。モデルはランダムに初期化されましたので、それが正しいクラスを予測する 1/10 のチャンスがあります!

images, labels = next(iter(mnist_test))
logits = mlp(images)
  
prediction = tf.argmax(logits[0]).numpy()
actual = labels[0].numpy()
print("Predicted class: {} actual class: {}".format(prediction, actual))
plt.imshow(images[0]);
Predicted class: 0 actual class: 6

 

モデルを訓練する

モデルを訓練するためには optimizer が必要です。この単純なサンプルのために SGD optimizer で実装されている確率的勾配降下を使用します。勾配を計算するために tf.GradientTape を使用します、これは通して逆伝播することを望む計算のためだけに勾配を選択的に記録することを可能にします :

#@title Utility function to show progress bar.
from tqdm import tqdm

# MNIST training set has 60k images.
num_images = 60000

def progress_bar(generator):
  return tqdm(
      generator,
      unit='images',
      unit_scale=batch_size,
      total=(num_images // batch_size) * num_epochs)
opt = snt.optimizers.SGD(learning_rate=0.1)

num_epochs = 10

def step(images, labels):
  """Performs one optimizer step on a single mini-batch."""
  with tf.GradientTape() as tape:
    logits = mlp(images)
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                          labels=labels)
    loss = tf.reduce_mean(loss)

  params = mlp.trainable_variables
  grads = tape.gradient(loss, params)
  opt.apply(grads, params)
  return loss

for images, labels in progress_bar(mnist_train.repeat(num_epochs)):
  loss = step(images, labels)

print("\n\nFinal loss: {}".format(loss.numpy()))
100%|██████████| 600000/600000 [01:02<00:00, 9660.48images/s] 

Final loss: 0.039747316390275955

 

モデルを評価する

モデルがこのデータセットに対してどのくらい上手くやるのかの感触を得るためにモデルの非常に単純な解析を行ないます :

total = 0
correct = 0
for images, labels in mnist_test:
  predictions = tf.argmax(mlp(images), axis=1)
  correct += tf.math.count_nonzero(tf.equal(predictions, labels))
  total += images.shape[0]

print("Got %d/%d (%.02f%%) correct" % (correct, total, correct / total * 100.))
Got 9767/10000 (97.67%) correct

結果を少しだけより良く理解するために、モデルがどこで数字を正しく識別したかの小さいサンプルを見ましょう :

#@title Utility function to show a sample of images.
def sample(correct, rows, cols):
  n = 0

  f, ax = plt.subplots(rows, cols)
  if rows > 1:    
    ax = tf.nest.flatten([tuple(ax[i]) for i in range(rows)])
  f.set_figwidth(14)
  f.set_figheight(4 * rows)


  for images, labels in mnist_test:
    predictions = tf.argmax(mlp(images), axis=1)
    eq = tf.equal(predictions, labels)
    for i, x in enumerate(eq):
      if x.numpy() == correct:
        label = labels[i]
        prediction = predictions[i]
        image = images[i]

        ax[n].imshow(image)
        ax[n].set_title("Prediction:{}\nActual:{}".format(prediction, label))

        n += 1
        if n == (rows * cols):
          break

    if n == (rows * cols):
      break
sample(correct=True, rows=1, cols=5)

今はそれがどこで入力を間違って分類するかを見ましょう。MNIST は幾つか寧ろ疑わしい手書きを持ちます。下のサンプルの幾つかは少し曖昧であることに貴方は同意することでしょう :

sample(correct=False, rows=2, cols=5)

 

以上






Sonnet 2.0 : イントロダクション, Getting Started & 直列化

Sonnet 2.0 : イントロダクション, Getting Started & 直列化 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 06/10/2020

* 本ページは、Sonnet の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

イントロダクション

Sonnet は機械学習研究のための単純で、構成可能な抽象を提供するために設計された、TensorFlow 2 上に構築されたライブラリです。

Sonnet は DeepMind の研究者により設計されて構築されました。それは多くの様々な目的のためにニューラルネットワークを構築するために利用可能です (教師なし/あり学習、強化学習, …)。それが私達の組織のためには成功的な抽象であると見出しています、you might too !

より具体的には、Sonnet は単一の概念: snt.Module を中心とする単純でしかしパワフルなプログラミングモデルを提供します。Module はパラメータ、他のモジュールそしてユーザ入力にある関数を適用するメソッドへの参照を保持できます。Sonnet は多くの事前定義されたモジュール (e.g. snt.Linear, snt.Conv2D, snt.BatchNorm) と幾つかの事前定義されたネットワーク・モジュール (e.g. snt.nets.MLP) とともに出荷されますが、ユーザはまた自身のモジュールを構築することも奨励されます。

多くのフレームワークとは違い、Sonnet は貴方のモジュールをどのように使用するかについて非常に意固地ではありません。Module は自己充足的で他の一つから完全に分離されるように設計されています。Sonnet は訓練フレームワークとともに出荷されませんので、ユーザは自身のものを構築するか他の人により構築されたものを採用することが推奨されます。

Sonnet はまた理解するために単純であるようにも設計されていて、私達のコードは (願わくば!) 明瞭で焦点が合っています。デフォルト (e.g. 初期パラメータ値のためのデフォルト) を選択したとこでは何故かを指摘することを試みます。

 

Getting Started

サンプル

Sonnet を試す最も容易な方法は Google Colab を利用することです、これは GPU か TPU に装着された free な Python ノートブックを供給します。

 

インストール

始めるには TensorFlow 2.0 と Sonnet 2 をインストールします :

To get started install TensorFlow 2.0 and Sonnet 2:

$ pip install tensorflow-gpu tensorflow-probability
$ pip install dm-sonnet

正しくインストールされたかを検証するに以下を実行できます :

import tensorflow as tf
import sonnet as snt

print("TensorFlow version {}".format(tf.__version__))
print("Sonnet version {}".format(snt.__version__))

 

既存のモジュールを使用する

Sonnet は自明に利用可能な幾つかの組込みモジュールとともに出荷されます。例えば MLP を定義するためにモジュールのシークエンスを呼び出すために snt.Sequential モジュールを利用できて、与えられたモジュールの出力を次のモジュールの入力として渡せます。計算を実際に定義するために snt.Linear と tf.nn.relu を利用できます :

mlp = snt.Sequential([
    snt.Linear(1024),
    tf.nn.relu,
    snt.Linear(10),
])

モジュールを使用するにはそれを「呼び出す」必要があります。Sequential モジュール (そして殆どのモジュール) は __call__ メソッドを定義します、これはそれらを名前で呼び出せることを意味します :

logits = mlp(tf.random.normal([batch_size, input_size]))

貴方のモジュールのためのパラメータ総てをリクエストすることも非常に一般的です。Sonnet の殆どのモジュールはそれらのパラメータをある入力で最初に呼び出されるとき作成します (何故ならば殆どの場合パラメータの shape は入力の関数であるからです)。Sonnet モジュールはパラメータにアクセスするために 2 つのプロパティを提供します。

variables プロパティは与えられたモジュールにより参照される 総ての tf.Variables を返します :

all_variables = mlp.variables

注目すべき点は tf.Variables は単に貴方のモデルのパラメータのために使用されるだけではないことです。例えばそれらは snt.BatchNorm で使用されるメトリクスの状態を保持するために使用されます。殆どの場合ユーザはモジュール変数を取得してそれらを更新されるために optimizer に渡します。この場合非訓練可能変数は典型的にはそのリストにあるべきではありません、何故ならばそれらは異なるメカニズムを通して更新されるからです。TensorFlow は変数を「訓練可能」 (モデルのパラメータ) vs. 非訓練可能 (他の変数) として印をつける組込みメカニズムを持ちます。Sonnet はモジュールから総ての訓練可能な変数を集めるメカニズムを提供します、これは多分貴方が optimizer に渡すことを望むものです :

model_parameters = mlp.trainable_variables

 

貴方自身のモジュールを構築する

Sonnet はユーザに自身のモジュールを定義するために snt.Module をサブクラス化することを強く奨励します。MyLinear と呼ばれる単純な線形層を作成することから始めましょう :

class MyLinear(snt.Module):

  def __init__(self, output_size, name=None):
    super(MyLinear, self).__init__(name=name)
    self.output_size = output_size

  @snt.once
  def _initialize(self, x):
    initial_w = tf.random.normal([x.shape[1], self.output_size])
    self.w = tf.Variable(initial_w, name="w")
    self.b = tf.Variable(tf.zeros([self.output_size]), name="b")

  def __call__(self, x):
    self._initialize(x)
    return tf.matmul(x, self.w) + self.b

このモジュールの使用は自明です :

mod = MyLinear(32)
mod(tf.ones([batch_size, input_size]))

snt.Module をサブクラス化することにより多くの素晴らしいプロパティをただで得ます。例えば __repr__ のデフォルト実装です、これはコンストラクタ引数を示します (デバッグと内省のために非常に有用です) :

>>> print(repr(mod))
MyLinear(output_size=10)

variables と trainable_variables プロパティもまた得ます :

>>> mod.variables
(<tf.Variable 'my_linear/b:0' shape=(10,) ...)>,
 <tf.Variable 'my_linear/w:0' shape=(1, 10) ...)>)

上の variables 上の my_linear prefix に気付くかもしれません。これは Sonnet モジュールもメソッドが呼び出されるときはいつでもモジュール名前空間に入るためです。モジュール名前空間に入ることにより消費する TensorBoard のようなツールのための遥かにより有用なグラフを提供します (e.g. my_linear 内で発生する総ての演算は my_linear と呼ばれるグループにあります)。更にモジュールは今では TensorFlow チェックポイントと saved モデルをサポートします、これは後でカバーされる進んだ特徴です。

 

シリアライゼーション

Sonnet は複数のシリアライゼーション形式をサポートします。サポートする最も単純な形式は Python の pickle です、そして総ての組込みモジュールは同じ Python プロセスで pickle を通してセーブ/ロードできることを確実にするためにテストされています。一般には pickle の利用は推奨されません、それは TensorFlow の多くのパートで上手くサポートされません、そして経験的に非常に不安定である可能性があります。

 

TensorFlow チェックポイント

参照: https://www.tensorflow.org/guide/checkpoint

訓練の間に定期的にパラメータ値をセーブするために TensorFlow チェックポイントが利用できます。これは、貴方のプログラムがクラッシュするか停止する場合に訓練の進捗をセーブするために有用であり得ます。Sonnet は TensorFlow チェックポイントとともにきれいに動作するように設計されています :

checkpoint_root = "/tmp/checkpoints"
checkpoint_name = "example"
save_prefix = os.path.join(checkpoint_root, checkpoint_name)

my_module = create_my_sonnet_module()  # Can be anything extending snt.Module.

# A `Checkpoint` object manages checkpointing of the TensorFlow state associated
# with the objects passed to it's constructor. Note that Checkpoint supports
# restore on create, meaning that the variables of `my_module` do **not** need
# to be created before you restore from a checkpoint (their value will be
# restored when they are created).
checkpoint = tf.train.Checkpoint(module=my_module)

# Most training scripts will want to restore from a checkpoint if one exists. This
# would be the case if you interrupted your training (e.g. to use your GPU for
# something else, or in a cloud environment if your instance is preempted).
latest = tf.train.latest_checkpoint(checkpoint_root)
if latest is not None:
  checkpoint.restore(latest)

for step_num in range(num_steps):
  train(my_module)

  # During training we will occasionally save the values of weights. Note that
  # this is a blocking call and can be slow (typically we are writing to the
  # slowest storage on the machine). If you have a more reliable setup it might be
  # appropriate to save less frequently.
  if step_num and not step_num % 1000:
    checkpoint.save(save_prefix)

# Make sure to save your final values!!
checkpoint.save(save_prefix)

 

TensorFlow Saved モデル

参照: https://www.tensorflow.org/guide/saved_model

TensorFlow saved モデルはネットワークのコピーをセーブするために利用できます、これはそのための Python ソースから切り離されています。これは計算を記述する TensorFlow グラフと重みの値を含むチェックポイントをセーブすることにより可能になります。saved モデルを作成するために行なう最初のことはセーブすることを望む snt.Module を作成することです :

my_module = snt.nets.MLP([1024, 1024, 10])
my_module(tf.ones([1, input_size]))

次に、エクスポートすることを望むモデルの特定のパーツを記述するもう一つのモジュールを作成する必要があります。(元のモデルを in-place で変更するよりも) これを行なうことを勧めます、そうすれば実際にエクスポートされるものに渡る極め細かい制御を持ちます。これは非常に大きい saved モデルを作成することを回避するために典型的には重要です、そしてそのようなものとして貴方が望むモデルのパーツを共有するだけです (e.g. GAN のために generator を共有することを望むだけで discriminator は private に保持します)。

@tf.function(input_signature=[tf.TensorSpec([None, input_size])])
def inference(x):
  return my_module(x)

to_save = snt.Module()
to_save.inference = inference
to_save.all_variables = list(my_module.variables)
tf.saved_model.save(to_save, "/tmp/example_saved_model")

今は /tmp/example_saved_model フォルダで saved モデルを持ちます :

$ ls -lh /tmp/example_saved_model
total 24K
drwxrwsr-t 2 tomhennigan 154432098 4.0K Apr 28 00:14 assets
-rw-rw-r-- 1 tomhennigan 154432098  14K Apr 28 00:15 saved_model.pb
drwxrwsr-t 2 tomhennigan 154432098 4.0K Apr 28 00:15 variables

このモデルのロードは単純で、saved モデルを構築したどのような Python コードもなしに異なるマシン上で成されます :

loaded = tf.saved_model.load("/tmp/example_saved_model")

# Use the inference method. Note this doesn't run the Python code from `to_save`
# but instead uses the TensorFlow Graph that is part of the saved model.
loaded.inference(tf.ones([1, input_size]))

# The all_variables property can be used to retrieve the restored variables.
assert len(loaded.all_variables) > 0

ロードされたオブジェクトは Sonnet モジュールではないことに注意してください、それは前のブロックで追加した特定のメソッド (e.g. inference) とプロパティ (e.g. all_variables) を持つコンテナ・オブジェクトです。

 

分散訓練

サンプル: https://github.com/deepmind/sonnet/blob/v2/examples/distributed_cifar10.ipynb

Sonnet は カスタム TensorFlow 分散ストラテジー を使用して分散訓練のためのサポートを持ちます。

Sonnet と tf.keras を使用する分散訓練の間の主要な違いは Sonnet モジュールと optimizer は分散ストラテジーで動作するとき異なる動作をしないことです (e.g. 勾配を平均したりバッチ norm スタッツを同期しません)。ユーザは訓練のこれらの様相の完全な制御にあるべきでライブラリに焼き固められるべきではないと信じます。ここでのトレードオフはこれらの特徴を貴方の訓練スクリプトで実装するか (optimizer を適用する前に勾配を総て減じるために典型的にはこれは単に 2 行のコードです) 明示的に分散 aware なモジュール (e.g. snt.distribute.CrossReplicaBatchNorm) と交換する必要があります。

分散 Cifar-10 サンプルは Sonnet でマルチ GPU 訓練を行なうことをウォークスルーします。

 

以上






TensorFlow : Graph Nets : グラフの最短経路を見つける

TensorFlow : Graph Nets : グラフの最短経路を見つける (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 06/07/2019

* 本ページは、deepmind/graph_nets の README.md 及び “Find the shortest path in a graph” を翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

Graph Nets : README.md

Graph Nets は TensorFlow と Sonnet でグラフネットワークを構築するための DeepMind のライブラリです。

グラフネットワークとは何でしょう?

グラフネットワークは入力としてグラフを取り出力としてグラフを返します。入力グラフはエッジ- (E)、ノード- (V) とグローバルレベル (u) 属性を持ちます。出力グラフは同じ構造、しかし更新された属性を持ちます。グラフネットワークは「グラフニューラルネットワーク」のより広いファミリーの一部です (Scarselli et al., 2009)。

グラフネットワークについてより学習するためには、私達の arXiv ペーパー: Relational inductive biases, deep learning, and graph networks を見てください。

 

使用方法サンプル

次のコードは単純なグラフネット・モジュールを構築してそれをデータに接続します。

import graph_nets as gn
import sonnet as snt

# Provide your own functions to generate graph-structured data.
input_graphs = get_graphs()

# Create the graph network.
graph_net_module = gn.modules.GraphNetwork(
    edge_model_fn=lambda: snt.nets.MLP([32, 32]),
    node_model_fn=lambda: snt.nets.MLP([32, 32]),
    global_model_fn=lambda: snt.nets.MLP([32, 32]))

# Pass the input graphs to the graph network, and return the output graphs.
output_graphs = graph_net_module(input_graphs)

 
 

Graph Nets : グラフの最短経路を見つける

「最短経路デモ」はランダムグラフを作成して、任意の 2 つのノードの間の最短経路上のノードとエッジをラベル付けするためにグラフネットワークを訓練します。メッセージパッシングのステップのシークエンスに渡り (各ステップのプロットで描かれるように)、モデルは最短経路のその予測を改良していきます。

このノートブックと伴うコードはグラフの 2 つのノード間に最短経路を予測することを学習するために Graph Nets ライブラリをどのように使用するかを示します。

開始と終了ノードが与えられたとき、ネットワークは最短経路のノードとエッジをラベル付けするために訓練されます。

訓練後、ネットワークの予測能力はその出力を真の最短経路と比較することにより示されます。それから汎化するためのネットワークの能力がテストされます、類似のしかしより巨大なグラフの最短経路を予測するためにそれを使用することによって。

 

インポート

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import itertools
import time

from graph_nets import graphs
from graph_nets import utils_np
from graph_nets import utils_tf
from graph_nets.demos import models
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from scipy import spatial
import tensorflow as tf

SEED = 1
np.random.seed(SEED)
tf.set_random_seed(SEED)

 

ヘルパー関数

DISTANCE_WEIGHT_NAME = "distance"  # The name for the distance edge attribute.


def pairwise(iterable):
  """s -> (s0,s1), (s1,s2), (s2, s3), ..."""
  a, b = itertools.tee(iterable)
  next(b, None)
  return zip(a, b)


def set_diff(seq0, seq1):
  """Return the set difference between 2 sequences as a list."""
  return list(set(seq0) - set(seq1))


def to_one_hot(indices, max_value, axis=-1):
  one_hot = np.eye(max_value)[indices]
  if axis not in (-1, one_hot.ndim):
    one_hot = np.moveaxis(one_hot, -1, axis)
  return one_hot


def get_node_dict(graph, attr):
  """Return a `dict` of node:attribute pairs from a graph."""
  return {k: v[attr] for k, v in graph.node.items()}


def generate_graph(rand,
                   num_nodes_min_max,
                   dimensions=2,
                   theta=1000.0,
                   rate=1.0):
  """Creates a connected graph.

  The graphs are geographic threshold graphs, but with added edges via a
  minimum spanning tree algorithm, to ensure all nodes are connected.

  Args:
    rand: A random seed for the graph generator. Default= None.
    num_nodes_min_max: A sequence [lower, upper) number of nodes per graph.
    dimensions: (optional) An `int` number of dimensions for the positions.
      Default= 2.
    theta: (optional) A `float` threshold parameters for the geographic
      threshold graph's threshold. Large values (1000+) make mostly trees. Try
      20-60 for good non-trees. Default=1000.0.
    rate: (optional) A rate parameter for the node weight exponential sampling
      distribution. Default= 1.0.

  Returns:
    The graph.
  """
  # Sample num_nodes.
  num_nodes = rand.randint(*num_nodes_min_max)

  # Create geographic threshold graph.
  pos_array = rand.uniform(size=(num_nodes, dimensions))
  pos = dict(enumerate(pos_array))
  weight = dict(enumerate(rand.exponential(rate, size=num_nodes)))
  geo_graph = nx.geographical_threshold_graph(
      num_nodes, theta, pos=pos, weight=weight)

  # Create minimum spanning tree across geo_graph's nodes.
  distances = spatial.distance.squareform(spatial.distance.pdist(pos_array))
  i_, j_ = np.meshgrid(range(num_nodes), range(num_nodes), indexing="ij")
  weighted_edges = list(zip(i_.ravel(), j_.ravel(), distances.ravel()))
  mst_graph = nx.Graph()
  mst_graph.add_weighted_edges_from(weighted_edges, weight=DISTANCE_WEIGHT_NAME)
  mst_graph = nx.minimum_spanning_tree(mst_graph, weight=DISTANCE_WEIGHT_NAME)
  # Put geo_graph's node attributes into the mst_graph.
  for i in mst_graph.nodes():
    mst_graph.node[i].update(geo_graph.node[i])

  # Compose the graphs.
  combined_graph = nx.compose_all((mst_graph, geo_graph.copy()))
  # Put all distance weights into edge attributes.
  for i, j in combined_graph.edges():
    combined_graph.get_edge_data(i, j).setdefault(DISTANCE_WEIGHT_NAME,
                                                  distances[i, j])
  return combined_graph, mst_graph, geo_graph


def add_shortest_path(rand, graph, min_length=1):
  """Samples a shortest path from A to B and adds attributes to indicate it.

  Args:
    rand: A random seed for the graph generator. Default= None.
    graph: A `nx.Graph`.
    min_length: (optional) An `int` minimum number of edges in the shortest
      path. Default= 1.

  Returns:
    The `nx.DiGraph` with the shortest path added.

  Raises:
    ValueError: All shortest paths are below the minimum length
  """
  # Map from node pairs to the length of their shortest path.
  pair_to_length_dict = {}
  try:
    # This is for compatibility with older networkx.
    lengths = nx.all_pairs_shortest_path_length(graph).items()
  except AttributeError:
    # This is for compatibility with newer networkx.
    lengths = list(nx.all_pairs_shortest_path_length(graph))
  for x, yy in lengths:
    for y, l in yy.items():
      if l >= min_length:
        pair_to_length_dict[x, y] = l
  if max(pair_to_length_dict.values()) < min_length:
    raise ValueError("All shortest paths are below the minimum length")
  # The node pairs which exceed the minimum length.
  node_pairs = list(pair_to_length_dict)

  # Computes probabilities per pair, to enforce uniform sampling of each
  # shortest path lengths.
  # The counts of pairs per length.
  counts = collections.Counter(pair_to_length_dict.values())
  prob_per_length = 1.0 / len(counts)
  probabilities = [
      prob_per_length / counts[pair_to_length_dict[x]] for x in node_pairs
  ]

  # Choose the start and end points.
  i = rand.choice(len(node_pairs), p=probabilities)
  start, end = node_pairs[i]
  path = nx.shortest_path(
      graph, source=start, target=end, weight=DISTANCE_WEIGHT_NAME)

  # Creates a directed graph, to store the directed path from start to end.
  digraph = graph.to_directed()

  # Add the "start", "end", and "solution" attributes to the nodes and edges.
  digraph.add_node(start, start=True)
  digraph.add_node(end, end=True)
  digraph.add_nodes_from(set_diff(digraph.nodes(), [start]), start=False)
  digraph.add_nodes_from(set_diff(digraph.nodes(), [end]), end=False)
  digraph.add_nodes_from(set_diff(digraph.nodes(), path), solution=False)
  digraph.add_nodes_from(path, solution=True)
  path_edges = list(pairwise(path))
  digraph.add_edges_from(set_diff(digraph.edges(), path_edges), solution=False)
  digraph.add_edges_from(path_edges, solution=True)

  return digraph


def graph_to_input_target(graph):
  """Returns 2 graphs with input and target feature vectors for training.

  Args:
    graph: An `nx.DiGraph` instance.

  Returns:
    The input `nx.DiGraph` instance.
    The target `nx.DiGraph` instance.

  Raises:
    ValueError: unknown node type
  """

  def create_feature(attr, fields):
    return np.hstack([np.array(attr[field], dtype=float) for field in fields])

  input_node_fields = ("pos", "weight", "start", "end")
  input_edge_fields = ("distance",)
  target_node_fields = ("solution",)
  target_edge_fields = ("solution",)

  input_graph = graph.copy()
  target_graph = graph.copy()

  solution_length = 0
  for node_index, node_feature in graph.nodes(data=True):
    input_graph.add_node(
        node_index, features=create_feature(node_feature, input_node_fields))
    target_node = to_one_hot(
        create_feature(node_feature, target_node_fields).astype(int), 2)[0]
    target_graph.add_node(node_index, features=target_node)
    solution_length += int(node_feature["solution"])
  solution_length /= graph.number_of_nodes()

  for receiver, sender, features in graph.edges(data=True):
    input_graph.add_edge(
        sender, receiver, features=create_feature(features, input_edge_fields))
    target_edge = to_one_hot(
        create_feature(features, target_edge_fields).astype(int), 2)[0]
    target_graph.add_edge(sender, receiver, features=target_edge)

  input_graph.graph["features"] = np.array([0.0])
  target_graph.graph["features"] = np.array([solution_length], dtype=float)

  return input_graph, target_graph


def generate_networkx_graphs(rand, num_examples, num_nodes_min_max, theta):
  """Generate graphs for training.

  Args:
    rand: A random seed (np.RandomState instance).
    num_examples: Total number of graphs to generate.
    num_nodes_min_max: A 2-tuple with the [lower, upper) number of nodes per
      graph. The number of nodes for a graph is uniformly sampled within this
      range.
    theta: (optional) A `float` threshold parameters for the geographic
      threshold graph's threshold. Default= the number of nodes.

  Returns:
    input_graphs: The list of input graphs.
    target_graphs: The list of output graphs.
    graphs: The list of generated graphs.
  """
  input_graphs = []
  target_graphs = []
  graphs = []
  for _ in range(num_examples):
    graph = generate_graph(rand, num_nodes_min_max, theta=theta)[0]
    graph = add_shortest_path(rand, graph)
    input_graph, target_graph = graph_to_input_target(graph)
    input_graphs.append(input_graph)
    target_graphs.append(target_graph)
    graphs.append(graph)
  return input_graphs, target_graphs, graphs


def create_placeholders(rand, batch_size, num_nodes_min_max, theta):
  """Creates placeholders for the model training and evaluation.

  Args:
    rand: A random seed (np.RandomState instance).
    batch_size: Total number of graphs per batch.
    num_nodes_min_max: A 2-tuple with the [lower, upper) number of nodes per
      graph. The number of nodes for a graph is uniformly sampled within this
      range.
    theta: A `float` threshold parameters for the geographic threshold graph's
      threshold. Default= the number of nodes.

  Returns:
    input_ph: The input graph's placeholders, as a graph namedtuple.
    target_ph: The target graph's placeholders, as a graph namedtuple.
  """
  # Create some example data for inspecting the vector sizes.
  input_graphs, target_graphs, _ = generate_networkx_graphs(
      rand, batch_size, num_nodes_min_max, theta)
  input_ph = utils_tf.placeholders_from_networkxs(input_graphs)
  target_ph = utils_tf.placeholders_from_networkxs(target_graphs)
  return input_ph, target_ph


def create_feed_dict(rand, batch_size, num_nodes_min_max, theta, input_ph,
                     target_ph):
  """Creates placeholders for the model training and evaluation.

  Args:
    rand: A random seed (np.RandomState instance).
    batch_size: Total number of graphs per batch.
    num_nodes_min_max: A 2-tuple with the [lower, upper) number of nodes per
      graph. The number of nodes for a graph is uniformly sampled within this
      range.
    theta: A `float` threshold parameters for the geographic threshold graph's
      threshold. Default= the number of nodes.
    input_ph: The input graph's placeholders, as a graph namedtuple.
    target_ph: The target graph's placeholders, as a graph namedtuple.

  Returns:
    feed_dict: The feed `dict` of input and target placeholders and data.
    raw_graphs: The `dict` of raw networkx graphs.
  """
  inputs, targets, raw_graphs = generate_networkx_graphs(
      rand, batch_size, num_nodes_min_max, theta)
  input_graphs = utils_np.networkxs_to_graphs_tuple(inputs)
  target_graphs = utils_np.networkxs_to_graphs_tuple(targets)
  feed_dict = {input_ph: input_graphs, target_ph: target_graphs}
  return feed_dict, raw_graphs


def compute_accuracy(target, output, use_nodes=True, use_edges=False):
  """Calculate model accuracy.

  Returns the number of correctly predicted shortest path nodes and the number
  of completely solved graphs (100% correct predictions).

  Args:
    target: A `graphs.GraphsTuple` that contains the target graph.
    output: A `graphs.GraphsTuple` that contains the output graph.
    use_nodes: A `bool` indicator of whether to compute node accuracy or not.
    use_edges: A `bool` indicator of whether to compute edge accuracy or not.

  Returns:
    correct: A `float` fraction of correctly labeled nodes/edges.
    solved: A `float` fraction of graphs that are completely correctly labeled.

  Raises:
    ValueError: Nodes or edges (or both) must be used
  """
  if not use_nodes and not use_edges:
    raise ValueError("Nodes or edges (or both) must be used")
  tdds = utils_np.graphs_tuple_to_data_dicts(target)
  odds = utils_np.graphs_tuple_to_data_dicts(output)
  cs = []
  ss = []
  for td, od in zip(tdds, odds):
    xn = np.argmax(td["nodes"], axis=-1)
    yn = np.argmax(od["nodes"], axis=-1)
    xe = np.argmax(td["edges"], axis=-1)
    ye = np.argmax(od["edges"], axis=-1)
    c = []
    if use_nodes:
      c.append(xn == yn)
    if use_edges:
      c.append(xe == ye)
    c = np.concatenate(c, axis=0)
    s = np.all(c)
    cs.append(c)
    ss.append(s)
  correct = np.mean(np.concatenate(cs, axis=0))
  solved = np.mean(np.stack(ss))
  return correct, solved


def create_loss_ops(target_op, output_ops):
  loss_ops = [
      tf.losses.softmax_cross_entropy(target_op.nodes, output_op.nodes) +
      tf.losses.softmax_cross_entropy(target_op.edges, output_op.edges)
      for output_op in output_ops
  ]
  return loss_ops


def make_all_runnable_in_session(*args):
  """Lets an iterable of TF graphs be output from a session as NP graphs."""
  return [utils_tf.make_runnable_in_session(a) for a in args]


class GraphPlotter(object):

  def __init__(self, ax, graph, pos):
    self._ax = ax
    self._graph = graph
    self._pos = pos
    self._base_draw_kwargs = dict(G=self._graph, pos=self._pos, ax=self._ax)
    self._solution_length = None
    self._nodes = None
    self._edges = None
    self._start_nodes = None
    self._end_nodes = None
    self._solution_nodes = None
    self._intermediate_solution_nodes = None
    self._solution_edges = None
    self._non_solution_nodes = None
    self._non_solution_edges = None
    self._ax.set_axis_off()

  @property
  def solution_length(self):
    if self._solution_length is None:
      self._solution_length = len(self._solution_edges)
    return self._solution_length

  @property
  def nodes(self):
    if self._nodes is None:
      self._nodes = self._graph.nodes()
    return self._nodes

  @property
  def edges(self):
    if self._edges is None:
      self._edges = self._graph.edges()
    return self._edges

  @property
  def start_nodes(self):
    if self._start_nodes is None:
      self._start_nodes = [
          n for n in self.nodes if self._graph.node[n].get("start", False)
      ]
    return self._start_nodes

  @property
  def end_nodes(self):
    if self._end_nodes is None:
      self._end_nodes = [
          n for n in self.nodes if self._graph.node[n].get("end", False)
      ]
    return self._end_nodes

  @property
  def solution_nodes(self):
    if self._solution_nodes is None:
      self._solution_nodes = [
          n for n in self.nodes if self._graph.node[n].get("solution", False)
      ]
    return self._solution_nodes

  @property
  def intermediate_solution_nodes(self):
    if self._intermediate_solution_nodes is None:
      self._intermediate_solution_nodes = [
          n for n in self.nodes
          if self._graph.node[n].get("solution", False) and
          not self._graph.node[n].get("start", False) and
          not self._graph.node[n].get("end", False)
      ]
    return self._intermediate_solution_nodes

  @property
  def solution_edges(self):
    if self._solution_edges is None:
      self._solution_edges = [
          e for e in self.edges
          if self._graph.get_edge_data(e[0], e[1]).get("solution", False)
      ]
    return self._solution_edges

  @property
  def non_solution_nodes(self):
    if self._non_solution_nodes is None:
      self._non_solution_nodes = [
          n for n in self.nodes
          if not self._graph.node[n].get("solution", False)
      ]
    return self._non_solution_nodes

  @property
  def non_solution_edges(self):
    if self._non_solution_edges is None:
      self._non_solution_edges = [
          e for e in self.edges
          if not self._graph.get_edge_data(e[0], e[1]).get("solution", False)
      ]
    return self._non_solution_edges

  def _make_draw_kwargs(self, **kwargs):
    kwargs.update(self._base_draw_kwargs)
    return kwargs

  def _draw(self, draw_function, zorder=None, **kwargs):
    draw_kwargs = self._make_draw_kwargs(**kwargs)
    collection = draw_function(**draw_kwargs)
    if collection is not None and zorder is not None:
      try:
        # This is for compatibility with older matplotlib.
        collection.set_zorder(zorder)
      except AttributeError:
        # This is for compatibility with newer matplotlib.
        collection[0].set_zorder(zorder)
    return collection

  def draw_nodes(self, **kwargs):
    """Useful kwargs: nodelist, node_size, node_color, linewidths."""
    if ("node_color" in kwargs and
        isinstance(kwargs["node_color"], collections.Sequence) and
        len(kwargs["node_color"]) in {3, 4} and
        not isinstance(kwargs["node_color"][0],
                       (collections.Sequence, np.ndarray))):
      num_nodes = len(kwargs.get("nodelist", self.nodes))
      kwargs["node_color"] = np.tile(
          np.array(kwargs["node_color"])[None], [num_nodes, 1])
    return self._draw(nx.draw_networkx_nodes, **kwargs)

  def draw_edges(self, **kwargs):
    """Useful kwargs: edgelist, width."""
    return self._draw(nx.draw_networkx_edges, **kwargs)

  def draw_graph(self,
                 node_size=200,
                 node_color=(0.4, 0.8, 0.4),
                 node_linewidth=1.0,
                 edge_width=1.0):
    # Plot nodes.
    self.draw_nodes(
        nodelist=self.nodes,
        node_size=node_size,
        node_color=node_color,
        linewidths=node_linewidth,
        zorder=20)
    # Plot edges.
    self.draw_edges(edgelist=self.edges, width=edge_width, zorder=10)

  def draw_graph_with_solution(self,
                               node_size=200,
                               node_color=(0.4, 0.8, 0.4),
                               node_linewidth=1.0,
                               edge_width=1.0,
                               start_color="w",
                               end_color="k",
                               solution_node_linewidth=3.0,
                               solution_edge_width=3.0):
    node_border_color = (0.0, 0.0, 0.0, 1.0)
    node_collections = {}
    # Plot start nodes.
    node_collections["start nodes"] = self.draw_nodes(
        nodelist=self.start_nodes,
        node_size=node_size,
        node_color=start_color,
        linewidths=solution_node_linewidth,
        edgecolors=node_border_color,
        zorder=100)
    # Plot end nodes.
    node_collections["end nodes"] = self.draw_nodes(
        nodelist=self.end_nodes,
        node_size=node_size,
        node_color=end_color,
        linewidths=solution_node_linewidth,
        edgecolors=node_border_color,
        zorder=90)
    # Plot intermediate solution nodes.
    if isinstance(node_color, dict):
      c = [node_color[n] for n in self.intermediate_solution_nodes]
    else:
      c = node_color
    node_collections["intermediate solution nodes"] = self.draw_nodes(
        nodelist=self.intermediate_solution_nodes,
        node_size=node_size,
        node_color=c,
        linewidths=solution_node_linewidth,
        edgecolors=node_border_color,
        zorder=80)
    # Plot solution edges.
    node_collections["solution edges"] = self.draw_edges(
        edgelist=self.solution_edges, width=solution_edge_width, zorder=70)
    # Plot non-solution nodes.
    if isinstance(node_color, dict):
      c = [node_color[n] for n in self.non_solution_nodes]
    else:
      c = node_color
    node_collections["non-solution nodes"] = self.draw_nodes(
        nodelist=self.non_solution_nodes,
        node_size=node_size,
        node_color=c,
        linewidths=node_linewidth,
        edgecolors=node_border_color,
        zorder=20)
    # Plot non-solution edges.
    node_collections["non-solution edges"] = self.draw_edges(
        edgelist=self.non_solution_edges, width=edge_width, zorder=10)
    # Set title as solution length.
    self._ax.set_title("Solution length: {}".format(self.solution_length))
    return node_collections

 

サンプルグラフを可視化する

seed = 1  #@param{type: 'integer'}
rand = np.random.RandomState(seed=seed)

num_examples = 15  #@param{type: 'integer'}
# Large values (1000+) make trees. Try 20-60 for good non-trees.
theta = 20  #@param{type: 'integer'}
num_nodes_min_max = (16, 17)

input_graphs, target_graphs, graphs = generate_networkx_graphs(
    rand, num_examples, num_nodes_min_max, theta)

num = min(num_examples, 16)
w = 3
h = int(np.ceil(num / w))
fig = plt.figure(40, figsize=(w * 4, h * 4))
fig.clf()
for j, graph in enumerate(graphs):
  ax = fig.add_subplot(h, w, j + 1)
  pos = get_node_dict(graph, "pos")
  plotter = GraphPlotter(ax, graph, pos)
  plotter.draw_graph_with_solution()

 

モデル訓練と評価をセットアップする

# The model we explore includes three components:
# - An "Encoder" graph net, which independently encodes the edge, node, and
#   global attributes (does not compute relations etc.).
# - A "Core" graph net, which performs N rounds of processing (message-passing)
#   steps. The input to the Core is the concatenation of the Encoder's output
#   and the previous output of the Core (labeled "Hidden(t)" below, where "t" is
#   the processing step).
# - A "Decoder" graph net, which independently decodes the edge, node, and
#   global attributes (does not compute relations etc.), on each
#   message-passing step.
#
#                     Hidden(t)   Hidden(t+1)
#                        |            ^
#           *---------*  |  *------*  |  *---------*
#           |         |  |  |      |  |  |         |
# Input --->| Encoder |  *->| Core |--*->| Decoder |---> Output(t)
#           |         |---->|      |     |         |
#           *---------*     *------*     *---------*
#
# The model is trained by supervised learning. Input graphs are procedurally
# generated, and output graphs have the same structure with the nodes and edges
# of the shortest path labeled (using 2-element 1-hot vectors). We could have
# predicted the shortest path only by labeling either the nodes or edges, and
# that does work, but we decided to predict both to demonstrate the flexibility
# of graph nets' outputs.
#
# The training loss is computed on the output of each processing step. The
# reason for this is to encourage the model to try to solve the problem in as
# few steps as possible. It also helps make the output of intermediate steps
# more interpretable.
#
# There's no need for a separate evaluate dataset because the inputs are
# never repeated, so the training loss is the measure of performance on graphs
# from the input distribution.
#
# We also evaluate how well the models generalize to graphs which are up to
# twice as large as those on which it was trained. The loss is computed only
# on the final processing step.
#
# Variables with the suffix _tr are training parameters, and variables with the
# suffix _ge are test/generalization parameters.
#
# After around 2000-5000 training iterations the model reaches near-perfect
# performance on graphs with between 8-16 nodes.

tf.reset_default_graph()

seed = 2
rand = np.random.RandomState(seed=seed)

# Model parameters.
# Number of processing (message-passing) steps.
num_processing_steps_tr = 10
num_processing_steps_ge = 10

# Data / training parameters.
num_training_iterations = 10000
theta = 20  # Large values (1000+) make trees. Try 20-60 for good non-trees.
batch_size_tr = 32
batch_size_ge = 100
# Number of nodes per graph sampled uniformly from this range.
num_nodes_min_max_tr = (8, 17)
num_nodes_min_max_ge = (16, 33)

# Data.
# Input and target placeholders.
input_ph, target_ph = create_placeholders(rand, batch_size_tr,
                                          num_nodes_min_max_tr, theta)

# Connect the data to the model.
# Instantiate the model.
model = models.EncodeProcessDecode(edge_output_size=2, node_output_size=2)
# A list of outputs, one per processing step.
output_ops_tr = model(input_ph, num_processing_steps_tr)
output_ops_ge = model(input_ph, num_processing_steps_ge)

# Training loss.
loss_ops_tr = create_loss_ops(target_ph, output_ops_tr)
# Loss across processing steps.
loss_op_tr = sum(loss_ops_tr) / num_processing_steps_tr
# Test/generalization loss.
loss_ops_ge = create_loss_ops(target_ph, output_ops_ge)
loss_op_ge = loss_ops_ge[-1]  # Loss from final processing step.

# Optimizer.
learning_rate = 1e-3
optimizer = tf.train.AdamOptimizer(learning_rate)
step_op = optimizer.minimize(loss_op_tr)

# Lets an iterable of TF graphs be output from a session as NP graphs.
input_ph, target_ph = make_all_runnable_in_session(input_ph, target_ph)

 

セッションをリセットする

# This cell resets the Tensorflow session, but keeps the same computational
# graph.

try:
  sess.close()
except NameError:
  pass
sess = tf.Session()
sess.run(tf.global_variables_initializer())

last_iteration = 0
logged_iterations = []
losses_tr = []
corrects_tr = []
solveds_tr = []
losses_ge = []
corrects_ge = []
solveds_ge = []

 

訓練を実行する

# You can interrupt this cell's training loop at any time, and visualize the
# intermediate results by running the next cell (below). You can then resume
# training by simply executing this cell again.

# How much time between logging and printing the current results.
log_every_seconds = 20

print("# (iteration number), T (elapsed seconds), "
      "Ltr (training loss), Lge (test/generalization loss), "
      "Ctr (training fraction nodes/edges labeled correctly), "
      "Str (training fraction examples solved correctly), "
      "Cge (test/generalization fraction nodes/edges labeled correctly), "
      "Sge (test/generalization fraction examples solved correctly)")

start_time = time.time()
last_log_time = start_time
for iteration in range(last_iteration, num_training_iterations):
  last_iteration = iteration
  feed_dict, _ = create_feed_dict(rand, batch_size_tr, num_nodes_min_max_tr,
                                  theta, input_ph, target_ph)
  train_values = sess.run({
      "step": step_op,
      "target": target_ph,
      "loss": loss_op_tr,
      "outputs": output_ops_tr
  },
                          feed_dict=feed_dict)
  the_time = time.time()
  elapsed_since_last_log = the_time - last_log_time
  if elapsed_since_last_log > log_every_seconds:
    last_log_time = the_time
    feed_dict, raw_graphs = create_feed_dict(
        rand, batch_size_ge, num_nodes_min_max_ge, theta, input_ph, target_ph)
    test_values = sess.run({
        "target": target_ph,
        "loss": loss_op_ge,
        "outputs": output_ops_ge
    },
                           feed_dict=feed_dict)
    correct_tr, solved_tr = compute_accuracy(
        train_values["target"], train_values["outputs"][-1], use_edges=True)
    correct_ge, solved_ge = compute_accuracy(
        test_values["target"], test_values["outputs"][-1], use_edges=True)
    elapsed = time.time() - start_time
    losses_tr.append(train_values["loss"])
    corrects_tr.append(correct_tr)
    solveds_tr.append(solved_tr)
    losses_ge.append(test_values["loss"])
    corrects_ge.append(correct_ge)
    solveds_ge.append(solved_ge)
    logged_iterations.append(iteration)
    print("# {:05d}, T {:.1f}, Ltr {:.4f}, Lge {:.4f}, Ctr {:.4f}, Str"
          " {:.4f}, Cge {:.4f}, Sge {:.4f}".format(
              iteration, elapsed, train_values["loss"], test_values["loss"],
              correct_tr, solved_tr, correct_ge, solved_ge))
# (iteration number), T (elapsed seconds), Ltr (training loss), Lge (test/generalization loss), Ctr (training fraction nodes/edges labeled correctly), Str (training fraction examples solved correctly), Cge (test/generalization fraction nodes/edges labeled correctly), Sge (test/generalization fraction examples solved correctly)
# 00029, T 23.8, Ltr 0.8731, Lge 0.6658, Ctr 0.8596, Str 0.0000, Cge 0.9481, Sge 0.0000
# 00078, T 42.1, Ltr 0.6341, Lge 0.4758, Ctr 0.9056, Str 0.0000, Cge 0.9549, Sge 0.0000
# 00133, T 62.5, Ltr 0.5034, Lge 0.3845, Ctr 0.9172, Str 0.0000, Cge 0.9625, Sge 0.0200
# 00189, T 82.7, Ltr 0.5162, Lge 0.3417, Ctr 0.9166, Str 0.1250, Cge 0.9664, Sge 0.0200
# 00244, T 103.0, Ltr 0.4486, Lge 0.3383, Ctr 0.9343, Str 0.1250, Cge 0.9685, Sge 0.1600
# 00299, T 123.1, Ltr 0.4963, Lge 0.3507, Ctr 0.9184, Str 0.2500, Cge 0.9637, Sge 0.1100
# 00354, T 143.4, Ltr 0.3223, Lge 0.2883, Ctr 0.9614, Str 0.4062, Cge 0.9721, Sge 0.3300
# 00407, T 163.5, Ltr 0.4604, Lge 0.3853, Ctr 0.9270, Str 0.2500, Cge 0.9585, Sge 0.0600
# 00461, T 183.7, Ltr 0.2822, Lge 0.2933, Ctr 0.9670, Str 0.5625, Cge 0.9702, Sge 0.3200
# 00517, T 203.8, Ltr 0.3703, Lge 0.2784, Ctr 0.9480, Str 0.4375, Cge 0.9698, Sge 0.2600
# 00571, T 224.1, Ltr 0.4301, Lge 0.2783, Ctr 0.9308, Str 0.3125, Cge 0.9723, Sge 0.2000
# 00626, T 244.1, Ltr 0.3287, Lge 0.2833, Ctr 0.9533, Str 0.4062, Cge 0.9687, Sge 0.2700
# 00682, T 264.4, Ltr 0.2802, Lge 0.2913, Ctr 0.9617, Str 0.5000, Cge 0.9703, Sge 0.3000
# 00736, T 284.7, Ltr 0.3474, Lge 0.2775, Ctr 0.9531, Str 0.5625, Cge 0.9704, Sge 0.1900
# 00790, T 305.1, Ltr 0.3098, Lge 0.3607, Ctr 0.9488, Str 0.4062, Cge 0.9690, Sge 0.1700
# 00844, T 324.9, Ltr 0.3092, Lge 0.2941, Ctr 0.9566, Str 0.4375, Cge 0.9702, Sge 0.2500
# 00899, T 345.1, Ltr 0.3805, Lge 0.2202, Ctr 0.9440, Str 0.2812, Cge 0.9770, Sge 0.3200
# 00953, T 365.0, Ltr 0.2927, Lge 0.2637, Ctr 0.9609, Str 0.5938, Cge 0.9707, Sge 0.1900
# 01008, T 385.0, Ltr 0.3164, Lge 0.2093, Ctr 0.9568, Str 0.4688, Cge 0.9749, Sge 0.3100
# 01063, T 405.1, Ltr 0.2704, Lge 0.2455, Ctr 0.9749, Str 0.5000, Cge 0.9719, Sge 0.2000
# 01117, T 425.2, Ltr 0.2696, Lge 0.2713, Ctr 0.9600, Str 0.6250, Cge 0.9719, Sge 0.2300
# 01172, T 445.3, Ltr 0.4089, Lge 0.2442, Ctr 0.9489, Str 0.5312, Cge 0.9727, Sge 0.2100
# 01227, T 465.3, Ltr 0.3053, Lge 0.2616, Ctr 0.9620, Str 0.5312, Cge 0.9733, Sge 0.2400
# 01280, T 485.2, Ltr 0.2292, Lge 0.2433, Ctr 0.9742, Str 0.6250, Cge 0.9703, Sge 0.3100
# 01335, T 505.6, Ltr 0.3238, Lge 0.2267, Ctr 0.9554, Str 0.5312, Cge 0.9748, Sge 0.3000
# 01390, T 525.7, Ltr 0.3662, Lge 0.2706, Ctr 0.9602, Str 0.5938, Cge 0.9720, Sge 0.2500
# 01445, T 545.8, Ltr 0.2444, Lge 0.2530, Ctr 0.9755, Str 0.6562, Cge 0.9732, Sge 0.2800
# 01498, T 565.8, Ltr 0.3119, Lge 0.3036, Ctr 0.9565, Str 0.5938, Cge 0.9708, Sge 0.2300
# 01552, T 585.9, Ltr 0.3058, Lge 0.2633, Ctr 0.9553, Str 0.4688, Cge 0.9717, Sge 0.2400
# 01606, T 606.1, Ltr 0.2392, Lge 0.2462, Ctr 0.9782, Str 0.6562, Cge 0.9726, Sge 0.3000
# 01661, T 626.4, Ltr 0.2917, Lge 0.2522, Ctr 0.9611, Str 0.5625, Cge 0.9725, Sge 0.3000
# 01716, T 646.7, Ltr 0.3049, Lge 0.2254, Ctr 0.9566, Str 0.5938, Cge 0.9749, Sge 0.3300
# 01771, T 666.8, Ltr 0.2509, Lge 0.2393, Ctr 0.9667, Str 0.6250, Cge 0.9747, Sge 0.3000
# 01824, T 687.1, Ltr 0.1827, Lge 0.1870, Ctr 0.9879, Str 0.7812, Cge 0.9781, Sge 0.3300
# 01879, T 707.2, Ltr 0.3511, Lge 0.2048, Ctr 0.9574, Str 0.6562, Cge 0.9775, Sge 0.4200
# 01935, T 727.6, Ltr 0.2784, Lge 0.2044, Ctr 0.9699, Str 0.5625, Cge 0.9752, Sge 0.2000
# 01990, T 747.7, Ltr 0.3216, Lge 0.1943, Ctr 0.9639, Str 0.6562, Cge 0.9768, Sge 0.2800
# 02044, T 768.0, Ltr 0.1950, Lge 0.1579, Ctr 0.9892, Str 0.8750, Cge 0.9820, Sge 0.4100
# 02098, T 788.0, Ltr 0.2075, Lge 0.1729, Ctr 0.9882, Str 0.7500, Cge 0.9799, Sge 0.3700
# 02153, T 808.0, Ltr 0.2118, Lge 0.1775, Ctr 0.9783, Str 0.7500, Cge 0.9804, Sge 0.3700
# 02207, T 828.4, Ltr 0.2426, Lge 0.1862, Ctr 0.9669, Str 0.6562, Cge 0.9797, Sge 0.3800
# 02262, T 848.7, Ltr 0.2076, Lge 0.1836, Ctr 0.9862, Str 0.8750, Cge 0.9792, Sge 0.4300
# 02317, T 869.0, Ltr 0.1890, Lge 0.1984, Ctr 0.9873, Str 0.8750, Cge 0.9767, Sge 0.3800
# 02371, T 889.2, Ltr 0.1652, Lge 0.1936, Ctr 0.9887, Str 0.8125, Cge 0.9784, Sge 0.3900
# 02426, T 909.4, Ltr 0.2751, Lge 0.1523, Ctr 0.9707, Str 0.6875, Cge 0.9823, Sge 0.4100
# 02481, T 929.4, Ltr 0.1775, Lge 0.1617, Ctr 0.9867, Str 0.8750, Cge 0.9788, Sge 0.3800
# 02536, T 949.6, Ltr 0.2007, Lge 0.1207, Ctr 0.9880, Str 0.8438, Cge 0.9857, Sge 0.4900
# 02591, T 969.8, Ltr 0.1514, Lge 0.1489, Ctr 0.9934, Str 0.9062, Cge 0.9813, Sge 0.3600
# 02646, T 989.9, Ltr 0.2410, Lge 0.1100, Ctr 0.9862, Str 0.8125, Cge 0.9854, Sge 0.4000
# 02702, T 1010.0, Ltr 0.1991, Lge 0.1578, Ctr 0.9827, Str 0.8125, Cge 0.9813, Sge 0.3800
# 02756, T 1030.1, Ltr 0.1464, Lge 0.1388, Ctr 0.9893, Str 0.8750, Cge 0.9814, Sge 0.4000
# 02811, T 1050.1, Ltr 0.1931, Lge 0.1588, Ctr 0.9833, Str 0.8750, Cge 0.9799, Sge 0.3900
# 02866, T 1070.1, Ltr 0.1570, Lge 0.1189, Ctr 0.9858, Str 0.8750, Cge 0.9858, Sge 0.5500
# 02920, T 1090.5, Ltr 0.1420, Lge 0.1113, Ctr 0.9922, Str 0.8750, Cge 0.9855, Sge 0.4500
# 02974, T 1110.8, Ltr 0.1550, Lge 0.1640, Ctr 0.9911, Str 0.8438, Cge 0.9809, Sge 0.3200
# 03029, T 1130.8, Ltr 0.1681, Lge 0.1297, Ctr 0.9936, Str 0.8750, Cge 0.9873, Sge 0.4900
# 03084, T 1151.5, Ltr 0.1810, Lge 0.1909, Ctr 0.9921, Str 0.8750, Cge 0.9785, Sge 0.2300
# 03139, T 1171.1, Ltr 0.2063, Lge 0.1209, Ctr 0.9818, Str 0.8125, Cge 0.9861, Sge 0.4400
# 03195, T 1191.2, Ltr 0.1340, Lge 0.1583, Ctr 0.9945, Str 0.8750, Cge 0.9789, Sge 0.2100
# 03251, T 1211.7, Ltr 0.1461, Lge 0.1520, Ctr 0.9943, Str 0.9062, Cge 0.9856, Sge 0.4800
# 03303, T 1231.8, Ltr 0.1694, Lge 0.1235, Ctr 0.9854, Str 0.7812, Cge 0.9852, Sge 0.4900
# 03359, T 1252.0, Ltr 0.1738, Lge 0.1222, Ctr 0.9852, Str 0.7812, Cge 0.9846, Sge 0.4400
# 03414, T 1272.2, Ltr 0.1498, Lge 0.1101, Ctr 0.9897, Str 0.8438, Cge 0.9867, Sge 0.4900
# 03468, T 1292.5, Ltr 0.1638, Lge 0.1573, Ctr 0.9894, Str 0.8438, Cge 0.9836, Sge 0.4500
# 03523, T 1312.6, Ltr 0.2194, Lge 0.1516, Ctr 0.9761, Str 0.7188, Cge 0.9846, Sge 0.4600
# 03578, T 1332.6, Ltr 0.1490, Lge 0.1425, Ctr 0.9874, Str 0.9375, Cge 0.9846, Sge 0.5300
# 03633, T 1353.0, Ltr 0.1951, Lge 0.0889, Ctr 0.9860, Str 0.8750, Cge 0.9883, Sge 0.5600
# 03686, T 1373.0, Ltr 0.1586, Lge 0.1016, Ctr 0.9900, Str 0.9062, Cge 0.9875, Sge 0.4900
# 03741, T 1393.4, Ltr 0.1404, Lge 0.1356, Ctr 0.9911, Str 0.8750, Cge 0.9855, Sge 0.5800
# 03794, T 1413.3, Ltr 0.1938, Lge 0.1298, Ctr 0.9852, Str 0.7812, Cge 0.9828, Sge 0.4300
# 03847, T 1433.7, Ltr 0.1412, Lge 0.1183, Ctr 0.9899, Str 0.8750, Cge 0.9870, Sge 0.5900
# 03901, T 1453.8, Ltr 0.1894, Lge 0.0941, Ctr 0.9842, Str 0.8438, Cge 0.9890, Sge 0.6000
# 03955, T 1473.8, Ltr 0.1605, Lge 0.0935, Ctr 0.9867, Str 0.8125, Cge 0.9876, Sge 0.5800
# 04008, T 1493.9, Ltr 0.1560, Lge 0.0700, Ctr 0.9886, Str 0.8438, Cge 0.9927, Sge 0.6900
# 04062, T 1514.0, Ltr 0.2174, Lge 0.1912, Ctr 0.9865, Str 0.7812, Cge 0.9780, Sge 0.3800
# 04115, T 1534.0, Ltr 0.1537, Lge 0.0797, Ctr 0.9852, Str 0.8438, Cge 0.9891, Sge 0.5900
# 04169, T 1554.1, Ltr 0.1586, Lge 0.1071, Ctr 0.9871, Str 0.8438, Cge 0.9864, Sge 0.6100
# 04223, T 1574.5, Ltr 0.1071, Lge 0.1316, Ctr 1.0000, Str 1.0000, Cge 0.9869, Sge 0.6200
# 04278, T 1594.7, Ltr 0.1270, Lge 0.1329, Ctr 0.9985, Str 0.9375, Cge 0.9850, Sge 0.5400
# 04333, T 1614.8, Ltr 0.1352, Lge 0.1023, Ctr 0.9929, Str 0.9375, Cge 0.9864, Sge 0.5300
# 04386, T 1635.2, Ltr 0.1423, Lge 0.0890, Ctr 0.9894, Str 0.8125, Cge 0.9875, Sge 0.4600
# 04440, T 1655.1, Ltr 0.1320, Lge 0.0963, Ctr 0.9994, Str 0.9688, Cge 0.9882, Sge 0.5900
# 04495, T 1675.5, Ltr 0.1603, Lge 0.1094, Ctr 0.9889, Str 0.8750, Cge 0.9876, Sge 0.4800
# 04548, T 1695.7, Ltr 0.1474, Lge 0.1107, Ctr 0.9949, Str 0.9375, Cge 0.9868, Sge 0.5600
# 04602, T 1715.5, Ltr 0.1608, Lge 0.1791, Ctr 0.9960, Str 0.8438, Cge 0.9811, Sge 0.3700
# 04656, T 1735.6, Ltr 0.1416, Lge 0.1130, Ctr 0.9899, Str 0.8438, Cge 0.9865, Sge 0.5600
# 04710, T 1755.5, Ltr 0.1868, Lge 0.1135, Ctr 0.9944, Str 0.9375, Cge 0.9862, Sge 0.5500
# 04764, T 1775.7, Ltr 0.1466, Lge 0.0730, Ctr 0.9916, Str 0.9375, Cge 0.9901, Sge 0.6600
# 04819, T 1795.9, Ltr 0.1147, Lge 0.0881, Ctr 0.9966, Str 0.9688, Cge 0.9906, Sge 0.6900
# 04874, T 1816.0, Ltr 0.1130, Lge 0.1065, Ctr 0.9987, Str 0.9688, Cge 0.9868, Sge 0.5900
# 04928, T 1836.3, Ltr 0.1979, Lge 0.0953, Ctr 0.9909, Str 0.8750, Cge 0.9885, Sge 0.5200
# 04982, T 1856.2, Ltr 0.1319, Lge 0.1024, Ctr 0.9929, Str 0.9062, Cge 0.9875, Sge 0.6000
# 05036, T 1876.4, Ltr 0.1575, Lge 0.0744, Ctr 0.9910, Str 0.9062, Cge 0.9914, Sge 0.6300
# 05090, T 1896.6, Ltr 0.1387, Lge 0.1054, Ctr 0.9938, Str 0.9062, Cge 0.9877, Sge 0.5800
# 05144, T 1916.8, Ltr 0.1196, Lge 0.1088, Ctr 0.9929, Str 0.8750, Cge 0.9857, Sge 0.4800
# 05198, T 1936.6, Ltr 0.1441, Lge 0.0758, Ctr 0.9912, Str 0.9375, Cge 0.9902, Sge 0.6900
# 05252, T 1957.3, Ltr 0.1296, Lge 0.1036, Ctr 0.9957, Str 0.8750, Cge 0.9909, Sge 0.7000
# 05304, T 1976.9, Ltr 0.1311, Lge 0.1073, Ctr 0.9956, Str 0.9062, Cge 0.9883, Sge 0.6700
# 05359, T 1997.1, Ltr 0.0968, Lge 0.1633, Ctr 1.0000, Str 1.0000, Cge 0.9860, Sge 0.5500
# 05413, T 2017.0, Ltr 0.1550, Lge 0.0875, Ctr 0.9903, Str 0.9062, Cge 0.9896, Sge 0.6600
# 05466, T 2037.4, Ltr 0.1204, Lge 0.2264, Ctr 0.9966, Str 0.9062, Cge 0.9834, Sge 0.6400
# 05519, T 2057.5, Ltr 0.1255, Lge 0.1421, Ctr 0.9953, Str 0.9375, Cge 0.9868, Sge 0.6100
# 05573, T 2077.5, Ltr 0.1255, Lge 0.0956, Ctr 0.9941, Str 0.9062, Cge 0.9900, Sge 0.6800
# 05628, T 2098.2, Ltr 0.1427, Lge 0.0803, Ctr 0.9953, Str 0.9062, Cge 0.9893, Sge 0.6500
# 05683, T 2118.4, Ltr 0.1344, Lge 0.0857, Ctr 0.9934, Str 0.9062, Cge 0.9909, Sge 0.6000
# 05739, T 2138.5, Ltr 0.1634, Lge 0.1224, Ctr 0.9931, Str 0.9375, Cge 0.9875, Sge 0.5600
# 05793, T 2158.9, Ltr 0.1784, Lge 0.0720, Ctr 0.9853, Str 0.8750, Cge 0.9905, Sge 0.5800
# 05847, T 2179.2, Ltr 0.1259, Lge 0.0641, Ctr 0.9975, Str 0.9688, Cge 0.9921, Sge 0.6800
# 05902, T 2199.3, Ltr 0.0840, Lge 0.1022, Ctr 0.9974, Str 0.9688, Cge 0.9880, Sge 0.5600
# 05957, T 2219.4, Ltr 0.1161, Lge 0.0861, Ctr 0.9978, Str 0.9688, Cge 0.9906, Sge 0.6900
# 06010, T 2239.6, Ltr 0.1470, Lge 0.0660, Ctr 0.9894, Str 0.8125, Cge 0.9906, Sge 0.6700
# 06065, T 2259.7, Ltr 0.1664, Lge 0.1401, Ctr 0.9937, Str 0.9375, Cge 0.9831, Sge 0.5000
# 06120, T 2279.9, Ltr 0.1733, Lge 0.0901, Ctr 0.9829, Str 0.8438, Cge 0.9880, Sge 0.5600
# 06173, T 2300.1, Ltr 0.1269, Lge 0.0721, Ctr 0.9936, Str 0.9375, Cge 0.9926, Sge 0.7600
# 06228, T 2320.4, Ltr 0.1716, Lge 0.0702, Ctr 0.9926, Str 0.9375, Cge 0.9908, Sge 0.6500
# 06282, T 2340.4, Ltr 0.1386, Lge 0.0545, Ctr 0.9975, Str 0.9688, Cge 0.9926, Sge 0.7100
# 06337, T 2360.9, Ltr 0.1400, Lge 0.0512, Ctr 0.9960, Str 0.9062, Cge 0.9926, Sge 0.6100
# 06391, T 2380.5, Ltr 0.1468, Lge 0.0791, Ctr 0.9965, Str 0.8750, Cge 0.9894, Sge 0.6300
# 06446, T 2400.6, Ltr 0.1655, Lge 0.0847, Ctr 0.9900, Str 0.8750, Cge 0.9897, Sge 0.5800
# 06500, T 2420.6, Ltr 0.1530, Lge 0.0538, Ctr 0.9878, Str 0.7812, Cge 0.9925, Sge 0.6800
# 06553, T 2440.5, Ltr 0.1442, Lge 0.0634, Ctr 0.9969, Str 0.9375, Cge 0.9919, Sge 0.7500
# 06609, T 2460.8, Ltr 0.0933, Lge 0.0678, Ctr 0.9987, Str 0.9688, Cge 0.9912, Sge 0.6900
# 06663, T 2481.0, Ltr 0.1460, Lge 0.0936, Ctr 0.9953, Str 0.9375, Cge 0.9879, Sge 0.5400
# 06716, T 2501.4, Ltr 0.1505, Lge 0.0685, Ctr 0.9941, Str 0.9375, Cge 0.9914, Sge 0.7200
# 06769, T 2521.4, Ltr 0.1400, Lge 0.0530, Ctr 0.9955, Str 0.9062, Cge 0.9931, Sge 0.7400
# 06823, T 2541.4, Ltr 0.1310, Lge 0.0799, Ctr 0.9936, Str 0.8750, Cge 0.9896, Sge 0.6900
# 06878, T 2561.7, Ltr 0.1640, Lge 0.0848, Ctr 0.9885, Str 0.7812, Cge 0.9884, Sge 0.5900
# 06931, T 2581.7, Ltr 0.1395, Lge 0.0783, Ctr 0.9954, Str 0.9375, Cge 0.9904, Sge 0.6400
# 06986, T 2601.9, Ltr 0.1150, Lge 0.0546, Ctr 0.9969, Str 0.9375, Cge 0.9923, Sge 0.7400
# 07041, T 2622.1, Ltr 0.0829, Lge 0.0574, Ctr 1.0000, Str 1.0000, Cge 0.9923, Sge 0.6700
# 07095, T 2642.4, Ltr 0.1719, Lge 0.1901, Ctr 0.9907, Str 0.9375, Cge 0.9819, Sge 0.2800
# 07149, T 2662.5, Ltr 0.2284, Lge 0.0478, Ctr 0.9789, Str 0.8438, Cge 0.9934, Sge 0.7100
# 07204, T 2682.8, Ltr 0.1277, Lge 0.0614, Ctr 0.9923, Str 0.8750, Cge 0.9914, Sge 0.6000
# 07256, T 2703.0, Ltr 0.2056, Lge 0.0849, Ctr 0.9938, Str 0.9062, Cge 0.9910, Sge 0.6400
# 07311, T 2723.1, Ltr 0.1456, Lge 0.0573, Ctr 0.9967, Str 0.9688, Cge 0.9924, Sge 0.6700
# 07366, T 2743.3, Ltr 0.1366, Lge 0.0878, Ctr 0.9993, Str 0.9688, Cge 0.9898, Sge 0.7200
# 07420, T 2764.0, Ltr 0.1349, Lge 0.0462, Ctr 0.9953, Str 0.9375, Cge 0.9948, Sge 0.7700
# 07472, T 2783.6, Ltr 0.1244, Lge 0.0604, Ctr 0.9955, Str 0.9375, Cge 0.9929, Sge 0.7300
# 07528, T 2803.8, Ltr 0.1206, Lge 0.0890, Ctr 1.0000, Str 1.0000, Cge 0.9875, Sge 0.5300
# 07583, T 2824.1, Ltr 0.1248, Lge 0.0860, Ctr 0.9993, Str 0.9688, Cge 0.9910, Sge 0.7300
# 07636, T 2844.1, Ltr 0.1737, Lge 0.1036, Ctr 0.9909, Str 0.9062, Cge 0.9891, Sge 0.6600
# 07689, T 2864.1, Ltr 0.1297, Lge 0.0718, Ctr 0.9974, Str 0.9375, Cge 0.9933, Sge 0.7400
# 07744, T 2884.4, Ltr 0.1139, Lge 0.0905, Ctr 0.9962, Str 0.9375, Cge 0.9888, Sge 0.5700
# 07797, T 2904.6, Ltr 0.1405, Lge 0.0703, Ctr 0.9975, Str 0.9062, Cge 0.9905, Sge 0.6900
# 07852, T 2924.8, Ltr 0.1078, Lge 0.0566, Ctr 0.9986, Str 0.9375, Cge 0.9926, Sge 0.7600
# 07906, T 2944.9, Ltr 0.1498, Lge 0.0772, Ctr 0.9923, Str 0.8750, Cge 0.9902, Sge 0.6400
# 07959, T 2965.1, Ltr 0.1378, Lge 0.0657, Ctr 0.9919, Str 0.9375, Cge 0.9919, Sge 0.6900
# 08012, T 2985.2, Ltr 0.1468, Lge 0.0639, Ctr 0.9872, Str 0.8438, Cge 0.9919, Sge 0.6600
# 08067, T 3005.5, Ltr 0.1473, Lge 0.0555, Ctr 0.9967, Str 0.9688, Cge 0.9930, Sge 0.7400
# 08121, T 3025.6, Ltr 0.0928, Lge 0.0502, Ctr 0.9963, Str 0.9375, Cge 0.9930, Sge 0.6500
# 08174, T 3045.9, Ltr 0.1561, Lge 0.0637, Ctr 0.9906, Str 0.8750, Cge 0.9929, Sge 0.7300
# 08229, T 3066.2, Ltr 0.1539, Lge 0.1363, Ctr 0.9885, Str 0.8438, Cge 0.9887, Sge 0.6600
# 08283, T 3086.3, Ltr 0.1270, Lge 0.0807, Ctr 0.9930, Str 0.9375, Cge 0.9889, Sge 0.5900
# 08337, T 3107.0, Ltr 0.1001, Lge 0.0721, Ctr 0.9948, Str 0.9375, Cge 0.9909, Sge 0.6900
# 08391, T 3126.5, Ltr 0.1344, Lge 0.0732, Ctr 0.9944, Str 0.9375, Cge 0.9917, Sge 0.7000
# 08446, T 3146.8, Ltr 0.1127, Lge 0.0597, Ctr 0.9994, Str 0.9688, Cge 0.9907, Sge 0.6600
# 08502, T 3167.0, Ltr 0.1328, Lge 0.0496, Ctr 0.9959, Str 0.9375, Cge 0.9928, Sge 0.7600
# 08556, T 3187.1, Ltr 0.1424, Lge 0.0844, Ctr 0.9953, Str 0.9062, Cge 0.9901, Sge 0.6900
# 08612, T 3207.4, Ltr 0.1434, Lge 0.0986, Ctr 0.9949, Str 0.9062, Cge 0.9863, Sge 0.5700
# 08666, T 3227.6, Ltr 0.1772, Lge 0.0893, Ctr 0.9865, Str 0.8438, Cge 0.9894, Sge 0.6100
# 08720, T 3247.8, Ltr 0.1491, Lge 0.0896, Ctr 0.9906, Str 0.9375, Cge 0.9893, Sge 0.6800
# 08775, T 3268.2, Ltr 0.1873, Lge 0.0815, Ctr 0.9895, Str 0.8438, Cge 0.9905, Sge 0.6900
# 08830, T 3288.4, Ltr 0.1128, Lge 0.0790, Ctr 0.9988, Str 0.9688, Cge 0.9881, Sge 0.6400
# 08882, T 3308.5, Ltr 0.1164, Lge 0.1097, Ctr 0.9957, Str 0.9688, Cge 0.9896, Sge 0.7200
# 08935, T 3328.4, Ltr 0.1509, Lge 0.0733, Ctr 0.9891, Str 0.8438, Cge 0.9901, Sge 0.6100
# 08990, T 3348.7, Ltr 0.1394, Lge 0.0357, Ctr 0.9969, Str 0.9375, Cge 0.9954, Sge 0.7800
# 09045, T 3368.9, Ltr 0.1263, Lge 0.0905, Ctr 0.9933, Str 0.8750, Cge 0.9882, Sge 0.6400
# 09098, T 3388.9, Ltr 0.1421, Lge 0.0697, Ctr 0.9929, Str 0.9062, Cge 0.9935, Sge 0.7900
# 09152, T 3409.1, Ltr 0.1357, Lge 0.0765, Ctr 0.9904, Str 0.7812, Cge 0.9904, Sge 0.6100
# 09207, T 3429.5, Ltr 0.1691, Lge 0.0696, Ctr 0.9950, Str 0.9688, Cge 0.9917, Sge 0.6700
# 09260, T 3449.4, Ltr 0.1421, Lge 0.0924, Ctr 0.9928, Str 0.9062, Cge 0.9896, Sge 0.6700
# 09315, T 3469.8, Ltr 0.1280, Lge 0.0687, Ctr 0.9941, Str 0.9375, Cge 0.9906, Sge 0.6900
# 09370, T 3490.1, Ltr 0.1428, Lge 0.0758, Ctr 0.9968, Str 0.9062, Cge 0.9920, Sge 0.7100
# 09423, T 3510.1, Ltr 0.1391, Lge 0.0665, Ctr 0.9956, Str 0.9062, Cge 0.9915, Sge 0.7100
# 09479, T 3530.1, Ltr 0.1644, Lge 0.1169, Ctr 0.9915, Str 0.9062, Cge 0.9883, Sge 0.6300
# 09535, T 3550.4, Ltr 0.1296, Lge 0.0786, Ctr 0.9898, Str 0.9062, Cge 0.9901, Sge 0.6600
# 09590, T 3570.6, Ltr 0.1611, Lge 0.0532, Ctr 0.9871, Str 0.8750, Cge 0.9927, Sge 0.7000
# 09644, T 3590.7, Ltr 0.1599, Lge 0.0585, Ctr 0.9943, Str 0.9375, Cge 0.9918, Sge 0.6900
# 09698, T 3610.8, Ltr 0.1409, Lge 0.0598, Ctr 0.9944, Str 0.8750, Cge 0.9926, Sge 0.7200
# 09754, T 3631.0, Ltr 0.1637, Lge 0.0523, Ctr 0.9900, Str 0.9375, Cge 0.9926, Sge 0.7300
# 09808, T 3651.2, Ltr 0.1112, Lge 0.0894, Ctr 0.9931, Str 0.9375, Cge 0.9924, Sge 0.7200
# 09863, T 3671.4, Ltr 0.1136, Lge 0.0938, Ctr 0.9979, Str 0.9688, Cge 0.9877, Sge 0.5300
# 09918, T 3691.5, Ltr 0.1333, Lge 0.0761, Ctr 0.9890, Str 0.8750, Cge 0.9905, Sge 0.6200
# 09972, T 3711.8, Ltr 0.1063, Lge 0.0603, Ctr 0.9975, Str 0.9375, Cge 0.9933, Sge 0.7100

 

結果を可視化する

# This cell visualizes the results of training. You can visualize the
# intermediate results by interrupting execution of the cell above, and running
# this cell. You can then resume training by simply executing the above cell
# again.

def softmax_prob_last_dim(x):  # pylint: disable=redefined-outer-name
  e = np.exp(x)
  return e[:, -1] / np.sum(e, axis=-1)


# Plot results curves.
fig = plt.figure(1, figsize=(18, 3))
fig.clf()
x = np.array(logged_iterations)
# Loss.
y_tr = losses_tr
y_ge = losses_ge
ax = fig.add_subplot(1, 3, 1)
ax.plot(x, y_tr, "k", label="Training")
ax.plot(x, y_ge, "k--", label="Test/generalization")
ax.set_title("Loss across training")
ax.set_xlabel("Training iteration")
ax.set_ylabel("Loss (binary cross-entropy)")
ax.legend()
# Correct.
y_tr = corrects_tr
y_ge = corrects_ge
ax = fig.add_subplot(1, 3, 2)
ax.plot(x, y_tr, "k", label="Training")
ax.plot(x, y_ge, "k--", label="Test/generalization")
ax.set_title("Fraction correct across training")
ax.set_xlabel("Training iteration")
ax.set_ylabel("Fraction nodes/edges correct")
# Solved.
y_tr = solveds_tr
y_ge = solveds_ge
ax = fig.add_subplot(1, 3, 3)
ax.plot(x, y_tr, "k", label="Training")
ax.plot(x, y_ge, "k--", label="Test/generalization")
ax.set_title("Fraction solved across training")
ax.set_xlabel("Training iteration")
ax.set_ylabel("Fraction examples solved")

# Plot graphs and results after each processing step.
# The white node is the start, and the black is the end. Other nodes are colored
# from red to purple to blue, where red means the model is confident the node is
# off the shortest path, blue means the model is confident the node is on the
# shortest path, and purplish colors mean the model isn't sure.
max_graphs_to_plot = 6
num_steps_to_plot = 4
node_size = 120
min_c = 0.3
num_graphs = len(raw_graphs)
targets = utils_np.graphs_tuple_to_data_dicts(test_values["target"])
step_indices = np.floor(
    np.linspace(0, num_processing_steps_ge - 1,
                num_steps_to_plot)).astype(int).tolist()
outputs = list(
    zip(*(utils_np.graphs_tuple_to_data_dicts(test_values["outputs"][i])
          for i in step_indices)))
h = min(num_graphs, max_graphs_to_plot)
w = num_steps_to_plot + 1
fig = plt.figure(101, figsize=(18, h * 3))
fig.clf()
ncs = []
for j, (graph, target, output) in enumerate(zip(raw_graphs, targets, outputs)):
  if j >= h:
    break
  pos = get_node_dict(graph, "pos")
  ground_truth = target["nodes"][:, -1]
  # Ground truth.
  iax = j * (1 + num_steps_to_plot) + 1
  ax = fig.add_subplot(h, w, iax)
  plotter = GraphPlotter(ax, graph, pos)
  color = {}
  for i, n in enumerate(plotter.nodes):
    color[n] = np.array([1.0 - ground_truth[i], 0.0, ground_truth[i], 1.0
                        ]) * (1.0 - min_c) + min_c
  plotter.draw_graph_with_solution(node_size=node_size, node_color=color)
  ax.set_axis_on()
  ax.set_xticks([])
  ax.set_yticks([])
  try:
    ax.set_facecolor([0.9] * 3 + [1.0])
  except AttributeError:
    ax.set_axis_bgcolor([0.9] * 3 + [1.0])
  ax.grid(None)
  ax.set_title("Ground truth\nSolution length: {}".format(
      plotter.solution_length))
  # Prediction.
  for k, outp in enumerate(output):
    iax = j * (1 + num_steps_to_plot) + 2 + k
    ax = fig.add_subplot(h, w, iax)
    plotter = GraphPlotter(ax, graph, pos)
    color = {}
    prob = softmax_prob_last_dim(outp["nodes"])
    for i, n in enumerate(plotter.nodes):
      color[n] = np.array([1.0 - prob[n], 0.0, prob[n], 1.0
                          ]) * (1.0 - min_c) + min_c
    plotter.draw_graph_with_solution(node_size=node_size, node_color=color)
    ax.set_title("Model-predicted\nStep {:02d} / {:02d}".format(
        step_indices[k] + 1, step_indices[-1] + 1))

 

以上






Sonnet : モジュール指向 TensorFlow 高位ライブラリ

Sonnet : モジュール指向 TensorFlow 高位ライブラリ
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 04/13/2017

* 本ページは、github 上の sonnet の README.md の前半を動作確認・翻訳した上で適宜、補足説明したものです:
https://github.com/deepmind/sonnet/blob/master/README.md
* DeepMind 社の Sonnet ページは :
Open sourcing Sonnet – a new library for constructing neural networks
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

Sonnet は TensorFlow 上に構築された、複雑なニューラルネットワークを構成するための ライブラリです。

 

インストール説明

Sonnet をインストールするためには、TensorFlow ヘッダファイルに対して bazel を使用してライブラリをコンパイルする必要があります。TensorFlow installation instructions に従って TensorFlow はインストールされているものとします。

このインストールは Linux/Mac OS X と Python 2.7 と互換性があります。インストールされた TensorFlow のバージョンは少なくとも 1.0.1 でなければなりません。Sonnet のインストールは TensorFlow の native pip install に加えて、virtualenv installation mode もサポートします。

bazel をインストールする

bazel の最近のバージョン (>= 0.4.5 ) を持っていることを確かめてください。そうでないならば、これらの指示 に従ってください。

(virtualenv TensorFlow インストール) virtualenv を activate する

virtualenv を使用しているならば、残りのインストールのために virtualenv を activate してください、そうでないならばこのステップはスキップしてください :

$ source $VIRTUALENV_PATH/bin/activate # bash, sh, ksh, or zsh
$ source $VIRTUALENV_PATH/bin/activate.csh  # csh or tcsh

TensorFlow ヘッダを cofigure する

最初に Sonnet ソースコードを TensorFlow をサブモジュールとして clone します :

$ git clone --recursive https://github.com/deepmind/sonnet

そして configure を呼び出します :

$ cd sonnet/tensorflow
$ ./configure
$ cd ../

TensorFlow configuration の間は提示されたデフォルトを選択できます。
【注意】これは TensorFlow の既存のインストールを変更するものではありません。このステップは Sonnet が TensorFlow ヘッダに対してビルドできるようにするために必要です。

インストーラをビルドして実行する

一時ディレクトリに wheel ファイルを作成するためにインストール・スクリプトを実行します :

$ mkdir /tmp/sonnet
$ bazel build --config=opt :install
$ ./bazel-bin/install /tmp/sonnet

生成された wheel ファイルを pip install します :

$ pip install /tmp/sonnet/*.whl

Sonnet が既にインストールされているならば、wheel ファイルで pip install を呼び出す前に uninstall します :

$ pip uninstall sonnet

例えば、 resampler op を試してみることによって Sonnet が正しくインストールされたか検証できます :

$ cd ~/
$ python
>>> import sonnet as snt
>>> import tensorflow as tf
>>> snt.resampler(tf.constant([0.]), tf.constant([0.]))

期待される出力は以下のようなものです :


けれども、もし ImportError が上がるならば C++ コンポーネントが見つかっていません。clone されたソースコードをインポートしていないこと (i.e. clone されたレポジトリの外で python を呼び出すこと) そして wheel ファイルをインストールする前に Sonnet を uninstall したことを確認してください。

 

使用例

次のコードは線形モデルを構築してそれを複数の入力に接続します。変数 (i..e 線形変換の重みとバイアス) は自動的に共有されます。

import sonnet as snt

train_data = get_training_data()
test_data = get_test_data()

# モジュールを構築して、必要な構成を提供します。
linear_regression_module = snt.Linear(output_size=FLAGS.output_size)

# モジュールを幾つかの入力に接続します、何回でも。
train_predictions = linear_regression_module(train_data)
test_predictions = linear_regression_module(test_data)

更なる使用例は ここ で見つかります。

 

一般的な原理

Sonnet の主要な原理は最初にニューラルネットワークのある部分を表す Python オブジェクトを構築してこれらのオブジェクトを TensorFlow 計算グラフに別々に接続することです。オブジェクトは sonnet.AbstractModule のサブクラスでそれらはそういうものとしてモジュールとして参照されます。

モジュールはグラフに複数回接続されるかもしれません、そしてそのモジュールで宣言された任意の変数は続く接続呼び出しで自動的に共有されます。変数スコープ名を指定したり reuse= フラグを使用したりすることを含む、変数共有を制御する TensorFlow の低位な局面はユーザからは抽象化され離されます。

構成と接続の分離は高水準なモジュール、i.e. 他のモジュールをラップするモジュール 、の簡単な構築を可能にします。例えば、BatchApply モジュールはテンソルの leading dimension の数を single dimension にマージし、提供されたモジュールを接続し、そして入力に適合するように結果の leading dimension を分割します。コンストラクト時には、内部モジュールは BatchApply コンストラクタに引数として渡されます。実行時には、モジュールは最初に入力上の reshape 操作を実行して、コンストラクタに渡されたモジュールを適用し、そして reshape 操作を逆に実行します。

Python オブジェクトでモジュールを表す更なる優位点はそれは必要なところで定義される追加の方法を可能にすることです。これの例は、重み共有を保持する一方で、構築後に様々な方法で接続されるかもしれないモジュールです。例えば、生成モデルのケースにおいて、モデルからサンプリングしたり与えられた観測の log 確率を計算することを望むかもしれません。両者の接続を同時に持つことは重み共有を要求し、 そしてこれらの方法は同じ変数に依存します。変数はオブジェクトに概念的に所有され、モジュールの異なるメソッドで使用されます。

 

Sonnet を import する

Sonnet を import するために推奨される方法は snt と命名した変数にエイリアスすることです :

import sonnet as snt

すると全てのモジュールは名前空間 snt の下にアクセス可能となり、この文書の残りでは簡潔さのために snt を使用します。

次のコードは他のモジュールから成るモジュールを構築します :

import sonnet as snt

# データは複数の入力経由で由来し、それぞれに同じモデルを適用するためには
# 変数共有を使用することが必要です。
train_data = get_training_data()
test_data = get_test_data()

#  多層パーセプトロン (Multi Layer Perceptron) を形成するために、2つの線形モデルを作成します。
# TensorBoard / 他のツールで解釈可能な変数名を提供するために、
# デフォルト名をオーバーライドします。 (それは 'linear', 'linear_1' という結果になります)
lin_to_hidden = snt.Linear(output_size=FLAGS.hidden_size, name='inp_to_hidden')
hidden_to_out = snt.Linear(output_size=FLAGS.output_size, name='hidden_to_out')

# Sequential は提供されるデータに対して連続的に多くの内部モジュールや ops を適用する
# モジュールです。tanh のような生の TF ops は構築されたモジュールと互換的に利用可能です、
# それらは変数を含みませんので。
mlp = snt.Sequential([lin_to_hidden, tf.sigmoid, hidden_to_out])

# sequential をグラフに接続します、何回でも。
train_predictions = mlp(train_data)
test_predictions = mlp(test_data)

次のコードは Linear モジュールに初期化と正則化を追加しています :

import sonnet as snt

train_data = get_training_data()
test_data = get_test_data()

# 重みとバイアスへの初期化と正則化。
initializers={"w": tf.truncated_normal_initializer(stddev=1.0),
              "b": tf.truncated_normal_initializer(stddev=1.0)}
regularizers = {"w": tf.contrib.layers.l1_regularizer(scale=0.1),
                "b": tf.contrib.layers.l2_regularizer(scale=0.1)}

linear_regression_module = snt.Linear(output_size=FLAGS.output_size,
                                      initializers=initializers,
                                      regularizers=regularizers)

# Connect the module to some inputs, any number of times.
train_predictions = linear_regression_module(train_data)
test_predictions = linear_regression_module(test_data)

# ...

# 正則化ロスを取得してそれらを一緒に追加します。
graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
total_regularization_loss = tf.reduce_sum(graph_regularizers)

# ...

# 損失を最小化する時、正則化損失もまた最小化します。
train_op = optimizer.minimize(loss + total_regularizer_loss)

 

貴方自身のモジュールを定義する

snt.AbstractModule から継承する

モジュールを定義するためには、snt.AbstractModule から継承した新しいクラスを作成します。貴方のクラスのコンストラクタは、そのモジュールの演算を定義する任意の構成を受け取って、private であることを示す、アンダースコアで prefix されたメンバー変数で保持します。

スーパークラス・コンストラクタを呼び出す

コンストラクタが行なう最初のことはスーパークラス・コンストラクタを呼び出し、そのモジュールの名前を渡すことです。

  • これを行なうことを忘れた場合、変数共有は壊れます。name kwarg はリストの最後の一つとして、クラス名のスネークケース版であるデフォルト値とともに常に提供されるべきです。
class MyMLP(snt.AbstractModule):
  """Docstring for MyMLP."""
  def __init__(self, hidden_size, output_size,
               nonlinearity=tf.tanh, name="my_mlp"):
    """Docstring explaining __init__ args, including types and defaults."""
    super(MyMLP, self).__init__(name)
    self._hidden_size = hidden_size
    self._output_size = output_size
    self._nonlinearity = nonlinearity

_build() メソッドを実装する

提供されなければならない唯一の他のメソッド実装は _build() です。これはモジュールが tf.Graph に接続される時にいつでも呼び出されます。それは幾つかの入力を受け取り、それは空か、単一のテンソルか、あるいは複数のテンソルを含む幾つかの任意の構造です。複数のテンソルはタプルか namedtuple で提供され、それの要素は順番にテンソルか更なるタプル / namedtuple かです。多くの入力テンソルはバッチ次元を必要とし、もしテンソルがカラー・チャネルを持つならばそれは最後の次元でなければなりません。多くの場合でライブラリは明示的に妨げはしませんが、リストと辞書の使用はサポートされません、何故ならばこれらの構造の mutability (変更可能な性質) は微妙なバグにつながるからです。

  # Following on from code snippet above..
  def _build(self, inputs):
    """Compute output Tensor from input Tensor."""
    lin_x_to_h = snt.Linear(output_size=self._hidden_size, name="x_to_h")
    lin_h_to_o = snt.Linear(output_size=self._output_size, name="h_to_o")
    return lin_h_to_o(self._nonlinearity(lin_x_to_h(inputs)))

_build メソッドは次のプロセスの任意の一つあるいは全てを含むでしょう :

  • 内部モジュールの構築と利用
  • 既に存在し、コンストラクタに渡されたモジュールの利用
  • 変数を直接作成する。

変数を貴方自身で作成する場合、tf.get_variable でそれらを作成することは重要です。tf.Variable コンストラクタを直接呼び出すことは最初にモジュールが接続された時だけは動作しますが、2回目の呼び出しではエラーメッセージ “Trainable variable created when calling a template after the first time” を受け取るでしょう。

上のサンプルのモジュールは別々に作成され、様々な構成を渡して、最後の行がそれら全てをグラフに接続します。return 行は右から左へ読まれるべきです – 入力テンソルは最初の Linear, lin_x_to_h,に渡され、その出力はコンストラクタで保持される非線形が何であれ、その出力はその結果を生成するために他の Linear を通り抜けます。内部 Linear インスタンスに短い意味のある名前を与えていることに注意してください。

上の非線形は生の TF op、eg tf.tanh あるいは tf.sigmoid、あるいは Sonnet モジュールのインスタンスでもかまいません。Python 標準との調和を保つために、これを明示的にはチェックしないかもしれません、そしてそのために _build が呼ばれた時にエラーを受け取るかもしれません。また __init__ 内で制約とサニティーチェックを追加することも許容します。

上のコードでは、_build() が呼び出されるたびに snt.Linear の新しいインスタンスが生成されることに注意してください、そしてこれは異なる、非共有な変数を作成すると考えるかもしれません。そのようなことはありません – MLP インスタンスがグラフに何度接続されようとも、4 変数 (各 Linear に 2) だけが作成されます。これがどのように動作するかは低位 TF 詳細で、変更に従います – 詳細は [tf.variable_op_scope] を見てください。

サブモジュールはどこで宣言されるべきか?

モジュール eg Sequential etc. は、外部的に既にコンストラクトされた他のモジュールを受け取り使用するかもしれないことに注意してください。このセクションで議論するサブ・モジュールは他のモジュールのコード内部でコンストラクトされる任意のモジュールで、これは Parent Module として参照しましょう。例として LSTM があり、そこでは多くの実装は、重みを含む、内部的に1つまたはそれ以上の Linear モジュールをコンストラクトします。

サブモジュールは _build() で作成されることが奨励されます。このようにそれを行なうことは変数スコープの正しいネストを得ることを意味します、例えば :

class ParentModule(snt.AbstractModule):
  def __init__(self, hidden_size, name="parent_module"):
    super(ParentModule, self).__init__(name=name)
    self._hidden_size = hidden_size

  def _build(self, inputs):
    lin_mod = snt.Linear(self._hidden_size)  # Construct submodule...
    return tf.relu(lin_mod(inputs))          # then connect it.

 

前半の翻訳はここまです。Recurrent Modules 以後の後半の翻訳は次回。

 

以上

AI導入支援 #2 ウェビナー

スモールスタートを可能としたAI導入支援   Vol.2
[無料 WEB セミナー] [詳細]
「画像認識 AI PoC スターターパック」の紹介
既に AI 技術を実ビジネスで活用し、成果を上げている日本企業も多く存在しており、競争優位なビジネスを展開しております。
しかしながら AI を導入したくとも PoC (概念実証) だけでも高額な費用がかかり取組めていない企業も少なくないようです。A I導入時には欠かせない PoC を手軽にしかも短期間で認知度を確認可能とするサービの紹介と共に、AI 技術の特性と具体的な導入プロセスに加え運用時のポイントについても解説いたします。
日時:2021年10月13日(水)
会場:WEBセミナー
共催:クラスキャット、日本FLOW(株)
後援:働き方改革推進コンソーシアム
参加費: 無料 (事前登録制)
人工知能開発支援
◆ クラスキャットは 人工知能研究開発支援 サービスを提供しています :
  • テクニカルコンサルティングサービス
  • 実証実験 (プロトタイプ構築)
  • アプリケーションへの実装
  • 人工知能研修サービス
◆ お問合せ先 ◆
(株)クラスキャット
セールス・インフォメーション
E-Mail:sales-info@classcat.com