TensorFlow : Graph Nets : Finding the Shortest Path in a Graph (Translation/Commentary)

Translation: ClassCat Co., Ltd. Sales Information
Created: 06/07/2019

* This page is a translation of deepmind/graph_nets' README.md and the "Find the shortest path in a graph" demo, with supplementary explanations added where appropriate:

* The sample code has been verified to run, but it has been added to or modified where necessary.
* You are free to link to this page, but we would appreciate a note to sales-info@classcat.com.

 

Graph Nets : README.md

Graph Nets is DeepMind's library for building graph networks in TensorFlow and Sonnet.

What are graph networks?

A graph network takes a graph as input and returns a graph as output. The input graph has edge- (E), node- (V), and global-level (u) attributes; the output graph has the same structure but updated attributes. Graph networks are part of the broader family of "graph neural networks" (Scarselli et al., 2009).
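
For orientation, here is a minimal sketch (not part of the original README; the sizes and feature values are arbitrary) of the kind of graph-structured data such a module consumes, expressed as a data dict and converted to a GraphsTuple with the library's utils_np helper.

import numpy as np
from graph_nets import utils_np

# One directed graph with 3 nodes and 2 edges, written as a data dict.
# "nodes", "edges" and "globals" carry the V-, E- and u-level attributes.
data_dict = {
    "globals": np.array([0.0], dtype=np.float32),
    "nodes": np.zeros((3, 4), dtype=np.float32),    # 3 nodes, 4 features each
    "edges": np.zeros((2, 5), dtype=np.float32),    # 2 edges, 5 features each
    "senders": np.array([0, 1]),                    # edge i leaves senders[i] ...
    "receivers": np.array([1, 2]),                  # ... and enters receivers[i]
}
graphs_tuple = utils_np.data_dicts_to_graphs_tuple([data_dict])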

To learn more about graph networks, see our arXiv paper: Relational inductive biases, deep learning, and graph networks.

 

Usage example

The following code constructs a simple graph net module and connects it to data.

import graph_nets as gn
import sonnet as snt

# Provide your own functions to generate graph-structured data.
input_graphs = get_graphs()

# Create the graph network.
graph_net_module = gn.modules.GraphNetwork(
    edge_model_fn=lambda: snt.nets.MLP([32, 32]),
    node_model_fn=lambda: snt.nets.MLP([32, 32]),
    global_model_fn=lambda: snt.nets.MLP([32, 32]))

# Pass the input graphs to the graph network, and return the output graphs.
output_graphs = graph_net_module(input_graphs)

 
 

Graph Nets : Finding the Shortest Path in a Graph

The "shortest path demo" creates random graphs and trains a graph network to label the nodes and edges on the shortest path between any two nodes. Over a sequence of message-passing steps (as depicted in each step's plot), the model refines its prediction of the shortest path.

This notebook and the accompanying code show how to use the Graph Nets library to learn to predict the shortest path between two nodes in a graph.

Given the start and end nodes, the network is trained to label the nodes and edges of the shortest path.

After training, the network's predictive ability is demonstrated by comparing its output to the true shortest path. The network's ability to generalize is then tested by using it to predict the shortest path in similar but larger graphs.
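
As a concrete illustration of the ground truth the network is trained against, the following sketch (not part of the original notebook; the graph and the endpoints are made up) uses networkx directly to compute one shortest path and to collect the nodes and edges that would be labeled as the solution.

import networkx as nx

# A small weighted graph; the shortest path from node 0 to node 3 is 0-1-2-3.
g = nx.Graph()
g.add_weighted_edges_from(
    [(0, 1, 1.0), (1, 2, 1.0), (0, 2, 3.0), (2, 3, 1.0)], weight="distance")
path = nx.shortest_path(g, source=0, target=3, weight="distance")
solution_nodes = set(path)                      # {0, 1, 2, 3}
solution_edges = set(zip(path[:-1], path[1:]))  # {(0, 1), (1, 2), (2, 3)}
print(path, solution_edges)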

 

Imports

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import itertools
import time

from graph_nets import graphs
from graph_nets import utils_np
from graph_nets import utils_tf
from graph_nets.demos import models
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from scipy import spatial
import tensorflow as tf

SEED = 1
np.random.seed(SEED)
tf.set_random_seed(SEED)

 

Helper functions

DISTANCE_WEIGHT_NAME = "distance"  # The name for the distance edge attribute.


def pairwise(iterable):
  """s -> (s0,s1), (s1,s2), (s2, s3), ..."""
  a, b = itertools.tee(iterable)
  next(b, None)
  return zip(a, b)


def set_diff(seq0, seq1):
  """Return the set difference between 2 sequences as a list."""
  return list(set(seq0) - set(seq1))


def to_one_hot(indices, max_value, axis=-1):
  """Converts integer indices to one-hot vectors of length `max_value`."""
  one_hot = np.eye(max_value)[indices]
  if axis not in (-1, one_hot.ndim):
    one_hot = np.moveaxis(one_hot, -1, axis)
  return one_hot


def get_node_dict(graph, attr):
  """Return a `dict` of node:attribute pairs from a graph."""
  return {k: v[attr] for k, v in graph.node.items()}


def generate_graph(rand,
                   num_nodes_min_max,
                   dimensions=2,
                   theta=1000.0,
                   rate=1.0):
  """Creates a connected graph.

  The graphs are geographic threshold graphs, but with added edges via a
  minimum spanning tree algorithm, to ensure all nodes are connected.

  Args:
    rand: A random seed for the graph generator. Default= None.
    num_nodes_min_max: A sequence [lower, upper) number of nodes per graph.
    dimensions: (optional) An `int` number of dimensions for the positions.
      Default= 2.
    theta: (optional) A `float` threshold parameters for the geographic
      threshold graph's threshold. Large values (1000+) make mostly trees. Try
      20-60 for good non-trees. Default=1000.0.
    rate: (optional) A rate parameter for the node weight exponential sampling
      distribution. Default= 1.0.

  Returns:
    The graph.
  """
  # Sample num_nodes.
  num_nodes = rand.randint(*num_nodes_min_max)

  # Create geographic threshold graph.
  pos_array = rand.uniform(size=(num_nodes, dimensions))
  pos = dict(enumerate(pos_array))
  weight = dict(enumerate(rand.exponential(rate, size=num_nodes)))
  geo_graph = nx.geographical_threshold_graph(
      num_nodes, theta, pos=pos, weight=weight)

  # Create minimum spanning tree across geo_graph's nodes.
  distances = spatial.distance.squareform(spatial.distance.pdist(pos_array))
  i_, j_ = np.meshgrid(range(num_nodes), range(num_nodes), indexing="ij")
  weighted_edges = list(zip(i_.ravel(), j_.ravel(), distances.ravel()))
  mst_graph = nx.Graph()
  mst_graph.add_weighted_edges_from(weighted_edges, weight=DISTANCE_WEIGHT_NAME)
  mst_graph = nx.minimum_spanning_tree(mst_graph, weight=DISTANCE_WEIGHT_NAME)
  # Put geo_graph's node attributes into the mst_graph.
  for i in mst_graph.nodes():
    mst_graph.node[i].update(geo_graph.node[i])

  # Compose the graphs.
  combined_graph = nx.compose_all((mst_graph, geo_graph.copy()))
  # Put all distance weights into edge attributes.
  for i, j in combined_graph.edges():
    combined_graph.get_edge_data(i, j).setdefault(DISTANCE_WEIGHT_NAME,
                                                  distances[i, j])
  return combined_graph, mst_graph, geo_graph
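
# Illustrative usage (not part of the original demo; the values are arbitrary):
#   rand = np.random.RandomState(seed=1)
#   combined_graph, mst_graph, geo_graph = generate_graph(rand, (10, 11), theta=20)
# The combined graph is connected and carries "distance" edge attributes.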


def add_shortest_path(rand, graph, min_length=1):
  """Samples a shortest path from A to B and adds attributes to indicate it.

  Args:
    rand: A random seed for the graph generator. Default= None.
    graph: A `nx.Graph`.
    min_length: (optional) An `int` minimum number of edges in the shortest
      path. Default= 1.

  Returns:
    The `nx.DiGraph` with the shortest path added.

  Raises:
    ValueError: All shortest paths are below the minimum length
  """
  # Map from node pairs to the length of their shortest path.
  pair_to_length_dict = {}
  try:
    # This is for compatibility with older networkx.
    lengths = nx.all_pairs_shortest_path_length(graph).items()
  except AttributeError:
    # This is for compatibility with newer networkx.
    lengths = list(nx.all_pairs_shortest_path_length(graph))
  for x, yy in lengths:
    for y, l in yy.items():
      if l >= min_length:
        pair_to_length_dict[x, y] = l
  if max(pair_to_length_dict.values()) < min_length:
    raise ValueError("All shortest paths are below the minimum length")
  # The node pairs which exceed the minimum length.
  node_pairs = list(pair_to_length_dict)

  # Computes probabilities per pair, to enforce uniform sampling of each
  # shortest path lengths.
  # The counts of pairs per length.
  counts = collections.Counter(pair_to_length_dict.values())
  prob_per_length = 1.0 / len(counts)
  probabilities = [
      prob_per_length / counts[pair_to_length_dict[x]] for x in node_pairs
  ]

  # Choose the start and end points.
  i = rand.choice(len(node_pairs), p=probabilities)
  start, end = node_pairs[i]
  path = nx.shortest_path(
      graph, source=start, target=end, weight=DISTANCE_WEIGHT_NAME)

  # Creates a directed graph, to store the directed path from start to end.
  digraph = graph.to_directed()

  # Add the "start", "end", and "solution" attributes to the nodes and edges.
  digraph.add_node(start, start=True)
  digraph.add_node(end, end=True)
  digraph.add_nodes_from(set_diff(digraph.nodes(), [start]), start=False)
  digraph.add_nodes_from(set_diff(digraph.nodes(), [end]), end=False)
  digraph.add_nodes_from(set_diff(digraph.nodes(), path), solution=False)
  digraph.add_nodes_from(path, solution=True)
  path_edges = list(pairwise(path))
  digraph.add_edges_from(set_diff(digraph.edges(), path_edges), solution=False)
  digraph.add_edges_from(path_edges, solution=True)

  return digraph
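
# Illustrative note (not part of the original demo): the returned digraph carries
# boolean "start", "end" and "solution" attributes on every node and a boolean
# "solution" attribute on every edge, e.g.:
#   digraph = add_shortest_path(np.random.RandomState(seed=1), combined_graph)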


def graph_to_input_target(graph):
  """Returns 2 graphs with input and target feature vectors for training.

  Args:
    graph: An `nx.DiGraph` instance.

  Returns:
    The input `nx.DiGraph` instance.
    The target `nx.DiGraph` instance.

  Raises:
    ValueError: unknown node type
  """

  def create_feature(attr, fields):
    return np.hstack([np.array(attr[field], dtype=float) for field in fields])

  input_node_fields = ("pos", "weight", "start", "end")
  input_edge_fields = ("distance",)
  target_node_fields = ("solution",)
  target_edge_fields = ("solution",)

  input_graph = graph.copy()
  target_graph = graph.copy()

  solution_length = 0
  for node_index, node_feature in graph.nodes(data=True):
    input_graph.add_node(
        node_index, features=create_feature(node_feature, input_node_fields))
    target_node = to_one_hot(
        create_feature(node_feature, target_node_fields).astype(int), 2)[0]
    target_graph.add_node(node_index, features=target_node)
    solution_length += int(node_feature["solution"])
  solution_length /= graph.number_of_nodes()

  for receiver, sender, features in graph.edges(data=True):
    input_graph.add_edge(
        sender, receiver, features=create_feature(features, input_edge_fields))
    target_edge = to_one_hot(
        create_feature(features, target_edge_fields).astype(int), 2)[0]
    target_graph.add_edge(sender, receiver, features=target_edge)

  input_graph.graph["features"] = np.array([0.0])
  target_graph.graph["features"] = np.array([solution_length], dtype=float)

  return input_graph, target_graph
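
# Illustrative note (derived from the field lists above, not part of the
# original demo): input node features are [pos (2), weight (1), start (1),
# end (1)], i.e. 5 values; input edge features are [distance], i.e. 1 value;
# target node and edge features are 2-element one-hot "solution" labels.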


def generate_networkx_graphs(rand, num_examples, num_nodes_min_max, theta):
  """Generate graphs for training.

  Args:
    rand: A random seed (np.RandomState instance).
    num_examples: Total number of graphs to generate.
    num_nodes_min_max: A 2-tuple with the [lower, upper) number of nodes per
      graph. The number of nodes for a graph is uniformly sampled within this
      range.
    theta: (optional) A `float` threshold parameters for the geographic
      threshold graph's threshold. Default= the number of nodes.

  Returns:
    input_graphs: The list of input graphs.
    target_graphs: The list of output graphs.
    graphs: The list of generated graphs.
  """
  input_graphs = []
  target_graphs = []
  graphs = []
  for _ in range(num_examples):
    graph = generate_graph(rand, num_nodes_min_max, theta=theta)[0]
    graph = add_shortest_path(rand, graph)
    input_graph, target_graph = graph_to_input_target(graph)
    input_graphs.append(input_graph)
    target_graphs.append(target_graph)
    graphs.append(graph)
  return input_graphs, target_graphs, graphs


def create_placeholders(rand, batch_size, num_nodes_min_max, theta):
  """Creates placeholders for the model training and evaluation.

  Args:
    rand: A random seed (np.RandomState instance).
    batch_size: Total number of graphs per batch.
    num_nodes_min_max: A 2-tuple with the [lower, upper) number of nodes per
      graph. The number of nodes for a graph is uniformly sampled within this
      range.
    theta: A `float` threshold parameters for the geographic threshold graph's
      threshold. Default= the number of nodes.

  Returns:
    input_ph: The input graph's placeholders, as a graph namedtuple.
    target_ph: The target graph's placeholders, as a graph namedtuple.
  """
  # Create some example data for inspecting the vector sizes.
  input_graphs, target_graphs, _ = generate_networkx_graphs(
      rand, batch_size, num_nodes_min_max, theta)
  input_ph = utils_tf.placeholders_from_networkxs(input_graphs)
  target_ph = utils_tf.placeholders_from_networkxs(target_graphs)
  return input_ph, target_ph


def create_feed_dict(rand, batch_size, num_nodes_min_max, theta, input_ph,
                     target_ph):
  """Creates placeholders for the model training and evaluation.

  Args:
    rand: A random seed (np.RandomState instance).
    batch_size: Total number of graphs per batch.
    num_nodes_min_max: A 2-tuple with the [lower, upper) number of nodes per
      graph. The number of nodes for a graph is uniformly sampled within this
      range.
    theta: A `float` threshold parameters for the geographic threshold graph's
      threshold. Default= the number of nodes.
    input_ph: The input graph's placeholders, as a graph namedtuple.
    target_ph: The target graph's placeholders, as a graph namedtuple.

  Returns:
    feed_dict: The feed `dict` of input and target placeholders and data.
    raw_graphs: The `dict` of raw networkx graphs.
  """
  inputs, targets, raw_graphs = generate_networkx_graphs(
      rand, batch_size, num_nodes_min_max, theta)
  input_graphs = utils_np.networkxs_to_graphs_tuple(inputs)
  target_graphs = utils_np.networkxs_to_graphs_tuple(targets)
  feed_dict = {input_ph: input_graphs, target_ph: target_graphs}
  return feed_dict, raw_graphs


def compute_accuracy(target, output, use_nodes=True, use_edges=False):
  """Calculate model accuracy.

  Returns the number of correctly predicted shortest path nodes and the number
  of completely solved graphs (100% correct predictions).

  Args:
    target: A `graphs.GraphsTuple` that contains the target graph.
    output: A `graphs.GraphsTuple` that contains the output graph.
    use_nodes: A `bool` indicator of whether to compute node accuracy or not.
    use_edges: A `bool` indicator of whether to compute edge accuracy or not.

  Returns:
    correct: A `float` fraction of correctly labeled nodes/edges.
    solved: A `float` fraction of graphs that are completely correctly labeled.

  Raises:
    ValueError: Nodes or edges (or both) must be used
  """
  if not use_nodes and not use_edges:
    raise ValueError("Nodes or edges (or both) must be used")
  tdds = utils_np.graphs_tuple_to_data_dicts(target)
  odds = utils_np.graphs_tuple_to_data_dicts(output)
  cs = []
  ss = []
  for td, od in zip(tdds, odds):
    xn = np.argmax(td["nodes"], axis=-1)
    yn = np.argmax(od["nodes"], axis=-1)
    xe = np.argmax(td["edges"], axis=-1)
    ye = np.argmax(od["edges"], axis=-1)
    c = []
    if use_nodes:
      c.append(xn == yn)
    if use_edges:
      c.append(xe == ye)
    c = np.concatenate(c, axis=0)
    s = np.all(c)
    cs.append(c)
    ss.append(s)
  correct = np.mean(np.concatenate(cs, axis=0))
  solved = np.mean(np.stack(ss))
  return correct, solved
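
# Illustrative note (not part of the original demo): "correct" averages over all
# nodes/edges in the batch, while "solved" counts a graph only if every requested
# label within that graph is predicted correctly.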


def create_loss_ops(target_op, output_ops):
  """Returns a softmax cross-entropy loss (nodes + edges) per processing step."""
  loss_ops = [
      tf.losses.softmax_cross_entropy(target_op.nodes, output_op.nodes) +
      tf.losses.softmax_cross_entropy(target_op.edges, output_op.edges)
      for output_op in output_ops
  ]
  return loss_ops


def make_all_runnable_in_session(*args):
  """Lets an iterable of TF graphs be output from a session as NP graphs."""
  return [utils_tf.make_runnable_in_session(a) for a in args]


class GraphPlotter(object):

  def __init__(self, ax, graph, pos):
    self._ax = ax
    self._graph = graph
    self._pos = pos
    self._base_draw_kwargs = dict(G=self._graph, pos=self._pos, ax=self._ax)
    self._solution_length = None
    self._nodes = None
    self._edges = None
    self._start_nodes = None
    self._end_nodes = None
    self._solution_nodes = None
    self._intermediate_solution_nodes = None
    self._solution_edges = None
    self._non_solution_nodes = None
    self._non_solution_edges = None
    self._ax.set_axis_off()

  @property
  def solution_length(self):
    if self._solution_length is None:
      self._solution_length = len(self._solution_edges)
    return self._solution_length

  @property
  def nodes(self):
    if self._nodes is None:
      self._nodes = self._graph.nodes()
    return self._nodes

  @property
  def edges(self):
    if self._edges is None:
      self._edges = self._graph.edges()
    return self._edges

  @property
  def start_nodes(self):
    if self._start_nodes is None:
      self._start_nodes = [
          n for n in self.nodes if self._graph.node[n].get("start", False)
      ]
    return self._start_nodes

  @property
  def end_nodes(self):
    if self._end_nodes is None:
      self._end_nodes = [
          n for n in self.nodes if self._graph.node[n].get("end", False)
      ]
    return self._end_nodes

  @property
  def solution_nodes(self):
    if self._solution_nodes is None:
      self._solution_nodes = [
          n for n in self.nodes if self._graph.node[n].get("solution", False)
      ]
    return self._solution_nodes

  @property
  def intermediate_solution_nodes(self):
    if self._intermediate_solution_nodes is None:
      self._intermediate_solution_nodes = [
          n for n in self.nodes
          if self._graph.node[n].get("solution", False) and
          not self._graph.node[n].get("start", False) and
          not self._graph.node[n].get("end", False)
      ]
    return self._intermediate_solution_nodes

  @property
  def solution_edges(self):
    if self._solution_edges is None:
      self._solution_edges = [
          e for e in self.edges
          if self._graph.get_edge_data(e[0], e[1]).get("solution", False)
      ]
    return self._solution_edges

  @property
  def non_solution_nodes(self):
    if self._non_solution_nodes is None:
      self._non_solution_nodes = [
          n for n in self.nodes
          if not self._graph.node[n].get("solution", False)
      ]
    return self._non_solution_nodes

  @property
  def non_solution_edges(self):
    if self._non_solution_edges is None:
      self._non_solution_edges = [
          e for e in self.edges
          if not self._graph.get_edge_data(e[0], e[1]).get("solution", False)
      ]
    return self._non_solution_edges

  def _make_draw_kwargs(self, **kwargs):
    kwargs.update(self._base_draw_kwargs)
    return kwargs

  def _draw(self, draw_function, zorder=None, **kwargs):
    draw_kwargs = self._make_draw_kwargs(**kwargs)
    collection = draw_function(**draw_kwargs)
    if collection is not None and zorder is not None:
      try:
        # This is for compatibility with older matplotlib.
        collection.set_zorder(zorder)
      except AttributeError:
        # This is for compatibility with newer matplotlib.
        collection[0].set_zorder(zorder)
    return collection

  def draw_nodes(self, **kwargs):
    """Useful kwargs: nodelist, node_size, node_color, linewidths."""
    if ("node_color" in kwargs and
        isinstance(kwargs["node_color"], collections.Sequence) and
        len(kwargs["node_color"]) in {3, 4} and
        not isinstance(kwargs["node_color"][0],
                       (collections.Sequence, np.ndarray))):
      num_nodes = len(kwargs.get("nodelist", self.nodes))
      kwargs["node_color"] = np.tile(
          np.array(kwargs["node_color"])[None], [num_nodes, 1])
    return self._draw(nx.draw_networkx_nodes, **kwargs)

  def draw_edges(self, **kwargs):
    """Useful kwargs: edgelist, width."""
    return self._draw(nx.draw_networkx_edges, **kwargs)

  def draw_graph(self,
                 node_size=200,
                 node_color=(0.4, 0.8, 0.4),
                 node_linewidth=1.0,
                 edge_width=1.0):
    # Plot nodes.
    self.draw_nodes(
        nodelist=self.nodes,
        node_size=node_size,
        node_color=node_color,
        linewidths=node_linewidth,
        zorder=20)
    # Plot edges.
    self.draw_edges(edgelist=self.edges, width=edge_width, zorder=10)

  def draw_graph_with_solution(self,
                               node_size=200,
                               node_color=(0.4, 0.8, 0.4),
                               node_linewidth=1.0,
                               edge_width=1.0,
                               start_color="w",
                               end_color="k",
                               solution_node_linewidth=3.0,
                               solution_edge_width=3.0):
    node_border_color = (0.0, 0.0, 0.0, 1.0)
    node_collections = {}
    # Plot start nodes.
    node_collections["start nodes"] = self.draw_nodes(
        nodelist=self.start_nodes,
        node_size=node_size,
        node_color=start_color,
        linewidths=solution_node_linewidth,
        edgecolors=node_border_color,
        zorder=100)
    # Plot end nodes.
    node_collections["end nodes"] = self.draw_nodes(
        nodelist=self.end_nodes,
        node_size=node_size,
        node_color=end_color,
        linewidths=solution_node_linewidth,
        edgecolors=node_border_color,
        zorder=90)
    # Plot intermediate solution nodes.
    if isinstance(node_color, dict):
      c = [node_color[n] for n in self.intermediate_solution_nodes]
    else:
      c = node_color
    node_collections["intermediate solution nodes"] = self.draw_nodes(
        nodelist=self.intermediate_solution_nodes,
        node_size=node_size,
        node_color=c,
        linewidths=solution_node_linewidth,
        edgecolors=node_border_color,
        zorder=80)
    # Plot solution edges.
    node_collections["solution edges"] = self.draw_edges(
        edgelist=self.solution_edges, width=solution_edge_width, zorder=70)
    # Plot non-solution nodes.
    if isinstance(node_color, dict):
      c = [node_color[n] for n in self.non_solution_nodes]
    else:
      c = node_color
    node_collections["non-solution nodes"] = self.draw_nodes(
        nodelist=self.non_solution_nodes,
        node_size=node_size,
        node_color=c,
        linewidths=node_linewidth,
        edgecolors=node_border_color,
        zorder=20)
    # Plot non-solution edges.
    node_collections["non-solution edges"] = self.draw_edges(
        edgelist=self.non_solution_edges, width=edge_width, zorder=10)
    # Set title as solution length.
    self._ax.set_title("Solution length: {}".format(self.solution_length))
    return node_collections

 

Visualize example graphs

seed = 1  #@param{type: 'integer'}
rand = np.random.RandomState(seed=seed)

num_examples = 15  #@param{type: 'integer'}
# Large values (1000+) make trees. Try 20-60 for good non-trees.
theta = 20  #@param{type: 'integer'}
num_nodes_min_max = (16, 17)

input_graphs, target_graphs, graphs = generate_networkx_graphs(
    rand, num_examples, num_nodes_min_max, theta)

num = min(num_examples, 16)
w = 3
h = int(np.ceil(num / w))
fig = plt.figure(40, figsize=(w * 4, h * 4))
fig.clf()
for j, graph in enumerate(graphs):
  ax = fig.add_subplot(h, w, j + 1)
  pos = get_node_dict(graph, "pos")
  plotter = GraphPlotter(ax, graph, pos)
  plotter.draw_graph_with_solution()

 

Set up model training and evaluation

# The model we explore includes three components:
# - An "Encoder" graph net, which independently encodes the edge, node, and
#   global attributes (does not compute relations etc.).
# - A "Core" graph net, which performs N rounds of processing (message-passing)
#   steps. The input to the Core is the concatenation of the Encoder's output
#   and the previous output of the Core (labeled "Hidden(t)" below, where "t" is
#   the processing step).
# - A "Decoder" graph net, which independently decodes the edge, node, and
#   global attributes (does not compute relations etc.), on each
#   message-passing step.
#
#                     Hidden(t)   Hidden(t+1)
#                        |            ^
#           *---------*  |  *------*  |  *---------*
#           |         |  |  |      |  |  |         |
# Input --->| Encoder |  *->| Core |--*->| Decoder |---> Output(t)
#           |         |---->|      |     |         |
#           *---------*     *------*     *---------*
#
# The model is trained by supervised learning. Input graphs are procedurally
# generated, and output graphs have the same structure with the nodes and edges
# of the shortest path labeled (using 2-element 1-hot vectors). We could have
# predicted the shortest path only by labeling either the nodes or edges, and
# that does work, but we decided to predict both to demonstrate the flexibility
# of graph nets' outputs.
#
# The training loss is computed on the output of each processing step. The
# reason for this is to encourage the model to try to solve the problem in as
# few steps as possible. It also helps make the output of intermediate steps
# more interpretable.
#
# There's no need for a separate evaluate dataset because the inputs are
# never repeated, so the training loss is the measure of performance on graphs
# from the input distribution.
#
# We also evaluate how well the models generalize to graphs which are up to
# twice as large as those on which it was trained. The loss is computed only
# on the final processing step.
#
# Variables with the suffix _tr are training parameters, and variables with the
# suffix _ge are test/generalization parameters.
#
# After around 2000-5000 training iterations the model reaches near-perfect
# performance on graphs with between 8-16 nodes.

tf.reset_default_graph()

seed = 2
rand = np.random.RandomState(seed=seed)

# Model parameters.
# Number of processing (message-passing) steps.
num_processing_steps_tr = 10
num_processing_steps_ge = 10

# Data / training parameters.
num_training_iterations = 10000
theta = 20  # Large values (1000+) make trees. Try 20-60 for good non-trees.
batch_size_tr = 32
batch_size_ge = 100
# Number of nodes per graph sampled uniformly from this range.
num_nodes_min_max_tr = (8, 17)
num_nodes_min_max_ge = (16, 33)

# Data.
# Input and target placeholders.
input_ph, target_ph = create_placeholders(rand, batch_size_tr,
                                          num_nodes_min_max_tr, theta)

# Connect the data to the model.
# Instantiate the model.
model = models.EncodeProcessDecode(edge_output_size=2, node_output_size=2)
# A list of outputs, one per processing step.
output_ops_tr = model(input_ph, num_processing_steps_tr)
output_ops_ge = model(input_ph, num_processing_steps_ge)

# Training loss.
loss_ops_tr = create_loss_ops(target_ph, output_ops_tr)
# Loss across processing steps.
loss_op_tr = sum(loss_ops_tr) / num_processing_steps_tr
# Test/generalization loss.
loss_ops_ge = create_loss_ops(target_ph, output_ops_ge)
loss_op_ge = loss_ops_ge[-1]  # Loss from final processing step.

# Optimizer.
learning_rate = 1e-3
optimizer = tf.train.AdamOptimizer(learning_rate)
step_op = optimizer.minimize(loss_op_tr)

# Lets an iterable of TF graphs be output from a session as NP graphs.
input_ph, target_ph = make_all_runnable_in_session(input_ph, target_ph)

 

Reset the session

# This cell resets the Tensorflow session, but keeps the same computational
# graph.

try:
  sess.close()
except NameError:
  pass
sess = tf.Session()
sess.run(tf.global_variables_initializer())

last_iteration = 0
logged_iterations = []
losses_tr = []
corrects_tr = []
solveds_tr = []
losses_ge = []
corrects_ge = []
solveds_ge = []

 

Run training

# You can interrupt this cell's training loop at any time, and visualize the
# intermediate results by running the next cell (below). You can then resume
# training by simply executing this cell again.

# How much time between logging and printing the current results.
log_every_seconds = 20

print("# (iteration number), T (elapsed seconds), "
      "Ltr (training loss), Lge (test/generalization loss), "
      "Ctr (training fraction nodes/edges labeled correctly), "
      "Str (training fraction examples solved correctly), "
      "Cge (test/generalization fraction nodes/edges labeled correctly), "
      "Sge (test/generalization fraction examples solved correctly)")

start_time = time.time()
last_log_time = start_time
for iteration in range(last_iteration, num_training_iterations):
  last_iteration = iteration
  feed_dict, _ = create_feed_dict(rand, batch_size_tr, num_nodes_min_max_tr,
                                  theta, input_ph, target_ph)
  train_values = sess.run({
      "step": step_op,
      "target": target_ph,
      "loss": loss_op_tr,
      "outputs": output_ops_tr
  },
                          feed_dict=feed_dict)
  the_time = time.time()
  elapsed_since_last_log = the_time - last_log_time
  if elapsed_since_last_log > log_every_seconds:
    last_log_time = the_time
    feed_dict, raw_graphs = create_feed_dict(
        rand, batch_size_ge, num_nodes_min_max_ge, theta, input_ph, target_ph)
    test_values = sess.run({
        "target": target_ph,
        "loss": loss_op_ge,
        "outputs": output_ops_ge
    },
                           feed_dict=feed_dict)
    correct_tr, solved_tr = compute_accuracy(
        train_values["target"], train_values["outputs"][-1], use_edges=True)
    correct_ge, solved_ge = compute_accuracy(
        test_values["target"], test_values["outputs"][-1], use_edges=True)
    elapsed = time.time() - start_time
    losses_tr.append(train_values["loss"])
    corrects_tr.append(correct_tr)
    solveds_tr.append(solved_tr)
    losses_ge.append(test_values["loss"])
    corrects_ge.append(correct_ge)
    solveds_ge.append(solved_ge)
    logged_iterations.append(iteration)
    print("# {:05d}, T {:.1f}, Ltr {:.4f}, Lge {:.4f}, Ctr {:.4f}, Str"
          " {:.4f}, Cge {:.4f}, Sge {:.4f}".format(
              iteration, elapsed, train_values["loss"], test_values["loss"],
              correct_tr, solved_tr, correct_ge, solved_ge))
# (iteration number), T (elapsed seconds), Ltr (training loss), Lge (test/generalization loss), Ctr (training fraction nodes/edges labeled correctly), Str (training fraction examples solved correctly), Cge (test/generalization fraction nodes/edges labeled correctly), Sge (test/generalization fraction examples solved correctly)
# 00029, T 23.8, Ltr 0.8731, Lge 0.6658, Ctr 0.8596, Str 0.0000, Cge 0.9481, Sge 0.0000
# 00078, T 42.1, Ltr 0.6341, Lge 0.4758, Ctr 0.9056, Str 0.0000, Cge 0.9549, Sge 0.0000
# 00133, T 62.5, Ltr 0.5034, Lge 0.3845, Ctr 0.9172, Str 0.0000, Cge 0.9625, Sge 0.0200
# 00189, T 82.7, Ltr 0.5162, Lge 0.3417, Ctr 0.9166, Str 0.1250, Cge 0.9664, Sge 0.0200
# 00244, T 103.0, Ltr 0.4486, Lge 0.3383, Ctr 0.9343, Str 0.1250, Cge 0.9685, Sge 0.1600
# 00299, T 123.1, Ltr 0.4963, Lge 0.3507, Ctr 0.9184, Str 0.2500, Cge 0.9637, Sge 0.1100
# 00354, T 143.4, Ltr 0.3223, Lge 0.2883, Ctr 0.9614, Str 0.4062, Cge 0.9721, Sge 0.3300
# 00407, T 163.5, Ltr 0.4604, Lge 0.3853, Ctr 0.9270, Str 0.2500, Cge 0.9585, Sge 0.0600
# 00461, T 183.7, Ltr 0.2822, Lge 0.2933, Ctr 0.9670, Str 0.5625, Cge 0.9702, Sge 0.3200
# 00517, T 203.8, Ltr 0.3703, Lge 0.2784, Ctr 0.9480, Str 0.4375, Cge 0.9698, Sge 0.2600
# 00571, T 224.1, Ltr 0.4301, Lge 0.2783, Ctr 0.9308, Str 0.3125, Cge 0.9723, Sge 0.2000
# 00626, T 244.1, Ltr 0.3287, Lge 0.2833, Ctr 0.9533, Str 0.4062, Cge 0.9687, Sge 0.2700
# 00682, T 264.4, Ltr 0.2802, Lge 0.2913, Ctr 0.9617, Str 0.5000, Cge 0.9703, Sge 0.3000
# 00736, T 284.7, Ltr 0.3474, Lge 0.2775, Ctr 0.9531, Str 0.5625, Cge 0.9704, Sge 0.1900
# 00790, T 305.1, Ltr 0.3098, Lge 0.3607, Ctr 0.9488, Str 0.4062, Cge 0.9690, Sge 0.1700
# 00844, T 324.9, Ltr 0.3092, Lge 0.2941, Ctr 0.9566, Str 0.4375, Cge 0.9702, Sge 0.2500
# 00899, T 345.1, Ltr 0.3805, Lge 0.2202, Ctr 0.9440, Str 0.2812, Cge 0.9770, Sge 0.3200
# 00953, T 365.0, Ltr 0.2927, Lge 0.2637, Ctr 0.9609, Str 0.5938, Cge 0.9707, Sge 0.1900
# 01008, T 385.0, Ltr 0.3164, Lge 0.2093, Ctr 0.9568, Str 0.4688, Cge 0.9749, Sge 0.3100
# 01063, T 405.1, Ltr 0.2704, Lge 0.2455, Ctr 0.9749, Str 0.5000, Cge 0.9719, Sge 0.2000
# 01117, T 425.2, Ltr 0.2696, Lge 0.2713, Ctr 0.9600, Str 0.6250, Cge 0.9719, Sge 0.2300
# 01172, T 445.3, Ltr 0.4089, Lge 0.2442, Ctr 0.9489, Str 0.5312, Cge 0.9727, Sge 0.2100
# 01227, T 465.3, Ltr 0.3053, Lge 0.2616, Ctr 0.9620, Str 0.5312, Cge 0.9733, Sge 0.2400
# 01280, T 485.2, Ltr 0.2292, Lge 0.2433, Ctr 0.9742, Str 0.6250, Cge 0.9703, Sge 0.3100
# 01335, T 505.6, Ltr 0.3238, Lge 0.2267, Ctr 0.9554, Str 0.5312, Cge 0.9748, Sge 0.3000
# 01390, T 525.7, Ltr 0.3662, Lge 0.2706, Ctr 0.9602, Str 0.5938, Cge 0.9720, Sge 0.2500
# 01445, T 545.8, Ltr 0.2444, Lge 0.2530, Ctr 0.9755, Str 0.6562, Cge 0.9732, Sge 0.2800
# 01498, T 565.8, Ltr 0.3119, Lge 0.3036, Ctr 0.9565, Str 0.5938, Cge 0.9708, Sge 0.2300
# 01552, T 585.9, Ltr 0.3058, Lge 0.2633, Ctr 0.9553, Str 0.4688, Cge 0.9717, Sge 0.2400
# 01606, T 606.1, Ltr 0.2392, Lge 0.2462, Ctr 0.9782, Str 0.6562, Cge 0.9726, Sge 0.3000
# 01661, T 626.4, Ltr 0.2917, Lge 0.2522, Ctr 0.9611, Str 0.5625, Cge 0.9725, Sge 0.3000
# 01716, T 646.7, Ltr 0.3049, Lge 0.2254, Ctr 0.9566, Str 0.5938, Cge 0.9749, Sge 0.3300
# 01771, T 666.8, Ltr 0.2509, Lge 0.2393, Ctr 0.9667, Str 0.6250, Cge 0.9747, Sge 0.3000
# 01824, T 687.1, Ltr 0.1827, Lge 0.1870, Ctr 0.9879, Str 0.7812, Cge 0.9781, Sge 0.3300
# 01879, T 707.2, Ltr 0.3511, Lge 0.2048, Ctr 0.9574, Str 0.6562, Cge 0.9775, Sge 0.4200
# 01935, T 727.6, Ltr 0.2784, Lge 0.2044, Ctr 0.9699, Str 0.5625, Cge 0.9752, Sge 0.2000
# 01990, T 747.7, Ltr 0.3216, Lge 0.1943, Ctr 0.9639, Str 0.6562, Cge 0.9768, Sge 0.2800
# 02044, T 768.0, Ltr 0.1950, Lge 0.1579, Ctr 0.9892, Str 0.8750, Cge 0.9820, Sge 0.4100
# 02098, T 788.0, Ltr 0.2075, Lge 0.1729, Ctr 0.9882, Str 0.7500, Cge 0.9799, Sge 0.3700
# 02153, T 808.0, Ltr 0.2118, Lge 0.1775, Ctr 0.9783, Str 0.7500, Cge 0.9804, Sge 0.3700
# 02207, T 828.4, Ltr 0.2426, Lge 0.1862, Ctr 0.9669, Str 0.6562, Cge 0.9797, Sge 0.3800
# 02262, T 848.7, Ltr 0.2076, Lge 0.1836, Ctr 0.9862, Str 0.8750, Cge 0.9792, Sge 0.4300
# 02317, T 869.0, Ltr 0.1890, Lge 0.1984, Ctr 0.9873, Str 0.8750, Cge 0.9767, Sge 0.3800
# 02371, T 889.2, Ltr 0.1652, Lge 0.1936, Ctr 0.9887, Str 0.8125, Cge 0.9784, Sge 0.3900
# 02426, T 909.4, Ltr 0.2751, Lge 0.1523, Ctr 0.9707, Str 0.6875, Cge 0.9823, Sge 0.4100
# 02481, T 929.4, Ltr 0.1775, Lge 0.1617, Ctr 0.9867, Str 0.8750, Cge 0.9788, Sge 0.3800
# 02536, T 949.6, Ltr 0.2007, Lge 0.1207, Ctr 0.9880, Str 0.8438, Cge 0.9857, Sge 0.4900
# 02591, T 969.8, Ltr 0.1514, Lge 0.1489, Ctr 0.9934, Str 0.9062, Cge 0.9813, Sge 0.3600
# 02646, T 989.9, Ltr 0.2410, Lge 0.1100, Ctr 0.9862, Str 0.8125, Cge 0.9854, Sge 0.4000
# 02702, T 1010.0, Ltr 0.1991, Lge 0.1578, Ctr 0.9827, Str 0.8125, Cge 0.9813, Sge 0.3800
# 02756, T 1030.1, Ltr 0.1464, Lge 0.1388, Ctr 0.9893, Str 0.8750, Cge 0.9814, Sge 0.4000
# 02811, T 1050.1, Ltr 0.1931, Lge 0.1588, Ctr 0.9833, Str 0.8750, Cge 0.9799, Sge 0.3900
# 02866, T 1070.1, Ltr 0.1570, Lge 0.1189, Ctr 0.9858, Str 0.8750, Cge 0.9858, Sge 0.5500
# 02920, T 1090.5, Ltr 0.1420, Lge 0.1113, Ctr 0.9922, Str 0.8750, Cge 0.9855, Sge 0.4500
# 02974, T 1110.8, Ltr 0.1550, Lge 0.1640, Ctr 0.9911, Str 0.8438, Cge 0.9809, Sge 0.3200
# 03029, T 1130.8, Ltr 0.1681, Lge 0.1297, Ctr 0.9936, Str 0.8750, Cge 0.9873, Sge 0.4900
# 03084, T 1151.5, Ltr 0.1810, Lge 0.1909, Ctr 0.9921, Str 0.8750, Cge 0.9785, Sge 0.2300
# 03139, T 1171.1, Ltr 0.2063, Lge 0.1209, Ctr 0.9818, Str 0.8125, Cge 0.9861, Sge 0.4400
# 03195, T 1191.2, Ltr 0.1340, Lge 0.1583, Ctr 0.9945, Str 0.8750, Cge 0.9789, Sge 0.2100
# 03251, T 1211.7, Ltr 0.1461, Lge 0.1520, Ctr 0.9943, Str 0.9062, Cge 0.9856, Sge 0.4800
# 03303, T 1231.8, Ltr 0.1694, Lge 0.1235, Ctr 0.9854, Str 0.7812, Cge 0.9852, Sge 0.4900
# 03359, T 1252.0, Ltr 0.1738, Lge 0.1222, Ctr 0.9852, Str 0.7812, Cge 0.9846, Sge 0.4400
# 03414, T 1272.2, Ltr 0.1498, Lge 0.1101, Ctr 0.9897, Str 0.8438, Cge 0.9867, Sge 0.4900
# 03468, T 1292.5, Ltr 0.1638, Lge 0.1573, Ctr 0.9894, Str 0.8438, Cge 0.9836, Sge 0.4500
# 03523, T 1312.6, Ltr 0.2194, Lge 0.1516, Ctr 0.9761, Str 0.7188, Cge 0.9846, Sge 0.4600
# 03578, T 1332.6, Ltr 0.1490, Lge 0.1425, Ctr 0.9874, Str 0.9375, Cge 0.9846, Sge 0.5300
# 03633, T 1353.0, Ltr 0.1951, Lge 0.0889, Ctr 0.9860, Str 0.8750, Cge 0.9883, Sge 0.5600
# 03686, T 1373.0, Ltr 0.1586, Lge 0.1016, Ctr 0.9900, Str 0.9062, Cge 0.9875, Sge 0.4900
# 03741, T 1393.4, Ltr 0.1404, Lge 0.1356, Ctr 0.9911, Str 0.8750, Cge 0.9855, Sge 0.5800
# 03794, T 1413.3, Ltr 0.1938, Lge 0.1298, Ctr 0.9852, Str 0.7812, Cge 0.9828, Sge 0.4300
# 03847, T 1433.7, Ltr 0.1412, Lge 0.1183, Ctr 0.9899, Str 0.8750, Cge 0.9870, Sge 0.5900
# 03901, T 1453.8, Ltr 0.1894, Lge 0.0941, Ctr 0.9842, Str 0.8438, Cge 0.9890, Sge 0.6000
# 03955, T 1473.8, Ltr 0.1605, Lge 0.0935, Ctr 0.9867, Str 0.8125, Cge 0.9876, Sge 0.5800
# 04008, T 1493.9, Ltr 0.1560, Lge 0.0700, Ctr 0.9886, Str 0.8438, Cge 0.9927, Sge 0.6900
# 04062, T 1514.0, Ltr 0.2174, Lge 0.1912, Ctr 0.9865, Str 0.7812, Cge 0.9780, Sge 0.3800
# 04115, T 1534.0, Ltr 0.1537, Lge 0.0797, Ctr 0.9852, Str 0.8438, Cge 0.9891, Sge 0.5900
# 04169, T 1554.1, Ltr 0.1586, Lge 0.1071, Ctr 0.9871, Str 0.8438, Cge 0.9864, Sge 0.6100
# 04223, T 1574.5, Ltr 0.1071, Lge 0.1316, Ctr 1.0000, Str 1.0000, Cge 0.9869, Sge 0.6200
# 04278, T 1594.7, Ltr 0.1270, Lge 0.1329, Ctr 0.9985, Str 0.9375, Cge 0.9850, Sge 0.5400
# 04333, T 1614.8, Ltr 0.1352, Lge 0.1023, Ctr 0.9929, Str 0.9375, Cge 0.9864, Sge 0.5300
# 04386, T 1635.2, Ltr 0.1423, Lge 0.0890, Ctr 0.9894, Str 0.8125, Cge 0.9875, Sge 0.4600
# 04440, T 1655.1, Ltr 0.1320, Lge 0.0963, Ctr 0.9994, Str 0.9688, Cge 0.9882, Sge 0.5900
# 04495, T 1675.5, Ltr 0.1603, Lge 0.1094, Ctr 0.9889, Str 0.8750, Cge 0.9876, Sge 0.4800
# 04548, T 1695.7, Ltr 0.1474, Lge 0.1107, Ctr 0.9949, Str 0.9375, Cge 0.9868, Sge 0.5600
# 04602, T 1715.5, Ltr 0.1608, Lge 0.1791, Ctr 0.9960, Str 0.8438, Cge 0.9811, Sge 0.3700
# 04656, T 1735.6, Ltr 0.1416, Lge 0.1130, Ctr 0.9899, Str 0.8438, Cge 0.9865, Sge 0.5600
# 04710, T 1755.5, Ltr 0.1868, Lge 0.1135, Ctr 0.9944, Str 0.9375, Cge 0.9862, Sge 0.5500
# 04764, T 1775.7, Ltr 0.1466, Lge 0.0730, Ctr 0.9916, Str 0.9375, Cge 0.9901, Sge 0.6600
# 04819, T 1795.9, Ltr 0.1147, Lge 0.0881, Ctr 0.9966, Str 0.9688, Cge 0.9906, Sge 0.6900
# 04874, T 1816.0, Ltr 0.1130, Lge 0.1065, Ctr 0.9987, Str 0.9688, Cge 0.9868, Sge 0.5900
# 04928, T 1836.3, Ltr 0.1979, Lge 0.0953, Ctr 0.9909, Str 0.8750, Cge 0.9885, Sge 0.5200
# 04982, T 1856.2, Ltr 0.1319, Lge 0.1024, Ctr 0.9929, Str 0.9062, Cge 0.9875, Sge 0.6000
# 05036, T 1876.4, Ltr 0.1575, Lge 0.0744, Ctr 0.9910, Str 0.9062, Cge 0.9914, Sge 0.6300
# 05090, T 1896.6, Ltr 0.1387, Lge 0.1054, Ctr 0.9938, Str 0.9062, Cge 0.9877, Sge 0.5800
# 05144, T 1916.8, Ltr 0.1196, Lge 0.1088, Ctr 0.9929, Str 0.8750, Cge 0.9857, Sge 0.4800
# 05198, T 1936.6, Ltr 0.1441, Lge 0.0758, Ctr 0.9912, Str 0.9375, Cge 0.9902, Sge 0.6900
# 05252, T 1957.3, Ltr 0.1296, Lge 0.1036, Ctr 0.9957, Str 0.8750, Cge 0.9909, Sge 0.7000
# 05304, T 1976.9, Ltr 0.1311, Lge 0.1073, Ctr 0.9956, Str 0.9062, Cge 0.9883, Sge 0.6700
# 05359, T 1997.1, Ltr 0.0968, Lge 0.1633, Ctr 1.0000, Str 1.0000, Cge 0.9860, Sge 0.5500
# 05413, T 2017.0, Ltr 0.1550, Lge 0.0875, Ctr 0.9903, Str 0.9062, Cge 0.9896, Sge 0.6600
# 05466, T 2037.4, Ltr 0.1204, Lge 0.2264, Ctr 0.9966, Str 0.9062, Cge 0.9834, Sge 0.6400
# 05519, T 2057.5, Ltr 0.1255, Lge 0.1421, Ctr 0.9953, Str 0.9375, Cge 0.9868, Sge 0.6100
# 05573, T 2077.5, Ltr 0.1255, Lge 0.0956, Ctr 0.9941, Str 0.9062, Cge 0.9900, Sge 0.6800
# 05628, T 2098.2, Ltr 0.1427, Lge 0.0803, Ctr 0.9953, Str 0.9062, Cge 0.9893, Sge 0.6500
# 05683, T 2118.4, Ltr 0.1344, Lge 0.0857, Ctr 0.9934, Str 0.9062, Cge 0.9909, Sge 0.6000
# 05739, T 2138.5, Ltr 0.1634, Lge 0.1224, Ctr 0.9931, Str 0.9375, Cge 0.9875, Sge 0.5600
# 05793, T 2158.9, Ltr 0.1784, Lge 0.0720, Ctr 0.9853, Str 0.8750, Cge 0.9905, Sge 0.5800
# 05847, T 2179.2, Ltr 0.1259, Lge 0.0641, Ctr 0.9975, Str 0.9688, Cge 0.9921, Sge 0.6800
# 05902, T 2199.3, Ltr 0.0840, Lge 0.1022, Ctr 0.9974, Str 0.9688, Cge 0.9880, Sge 0.5600
# 05957, T 2219.4, Ltr 0.1161, Lge 0.0861, Ctr 0.9978, Str 0.9688, Cge 0.9906, Sge 0.6900
# 06010, T 2239.6, Ltr 0.1470, Lge 0.0660, Ctr 0.9894, Str 0.8125, Cge 0.9906, Sge 0.6700
# 06065, T 2259.7, Ltr 0.1664, Lge 0.1401, Ctr 0.9937, Str 0.9375, Cge 0.9831, Sge 0.5000
# 06120, T 2279.9, Ltr 0.1733, Lge 0.0901, Ctr 0.9829, Str 0.8438, Cge 0.9880, Sge 0.5600
# 06173, T 2300.1, Ltr 0.1269, Lge 0.0721, Ctr 0.9936, Str 0.9375, Cge 0.9926, Sge 0.7600
# 06228, T 2320.4, Ltr 0.1716, Lge 0.0702, Ctr 0.9926, Str 0.9375, Cge 0.9908, Sge 0.6500
# 06282, T 2340.4, Ltr 0.1386, Lge 0.0545, Ctr 0.9975, Str 0.9688, Cge 0.9926, Sge 0.7100
# 06337, T 2360.9, Ltr 0.1400, Lge 0.0512, Ctr 0.9960, Str 0.9062, Cge 0.9926, Sge 0.6100
# 06391, T 2380.5, Ltr 0.1468, Lge 0.0791, Ctr 0.9965, Str 0.8750, Cge 0.9894, Sge 0.6300
# 06446, T 2400.6, Ltr 0.1655, Lge 0.0847, Ctr 0.9900, Str 0.8750, Cge 0.9897, Sge 0.5800
# 06500, T 2420.6, Ltr 0.1530, Lge 0.0538, Ctr 0.9878, Str 0.7812, Cge 0.9925, Sge 0.6800
# 06553, T 2440.5, Ltr 0.1442, Lge 0.0634, Ctr 0.9969, Str 0.9375, Cge 0.9919, Sge 0.7500
# 06609, T 2460.8, Ltr 0.0933, Lge 0.0678, Ctr 0.9987, Str 0.9688, Cge 0.9912, Sge 0.6900
# 06663, T 2481.0, Ltr 0.1460, Lge 0.0936, Ctr 0.9953, Str 0.9375, Cge 0.9879, Sge 0.5400
# 06716, T 2501.4, Ltr 0.1505, Lge 0.0685, Ctr 0.9941, Str 0.9375, Cge 0.9914, Sge 0.7200
# 06769, T 2521.4, Ltr 0.1400, Lge 0.0530, Ctr 0.9955, Str 0.9062, Cge 0.9931, Sge 0.7400
# 06823, T 2541.4, Ltr 0.1310, Lge 0.0799, Ctr 0.9936, Str 0.8750, Cge 0.9896, Sge 0.6900
# 06878, T 2561.7, Ltr 0.1640, Lge 0.0848, Ctr 0.9885, Str 0.7812, Cge 0.9884, Sge 0.5900
# 06931, T 2581.7, Ltr 0.1395, Lge 0.0783, Ctr 0.9954, Str 0.9375, Cge 0.9904, Sge 0.6400
# 06986, T 2601.9, Ltr 0.1150, Lge 0.0546, Ctr 0.9969, Str 0.9375, Cge 0.9923, Sge 0.7400
# 07041, T 2622.1, Ltr 0.0829, Lge 0.0574, Ctr 1.0000, Str 1.0000, Cge 0.9923, Sge 0.6700
# 07095, T 2642.4, Ltr 0.1719, Lge 0.1901, Ctr 0.9907, Str 0.9375, Cge 0.9819, Sge 0.2800
# 07149, T 2662.5, Ltr 0.2284, Lge 0.0478, Ctr 0.9789, Str 0.8438, Cge 0.9934, Sge 0.7100
# 07204, T 2682.8, Ltr 0.1277, Lge 0.0614, Ctr 0.9923, Str 0.8750, Cge 0.9914, Sge 0.6000
# 07256, T 2703.0, Ltr 0.2056, Lge 0.0849, Ctr 0.9938, Str 0.9062, Cge 0.9910, Sge 0.6400
# 07311, T 2723.1, Ltr 0.1456, Lge 0.0573, Ctr 0.9967, Str 0.9688, Cge 0.9924, Sge 0.6700
# 07366, T 2743.3, Ltr 0.1366, Lge 0.0878, Ctr 0.9993, Str 0.9688, Cge 0.9898, Sge 0.7200
# 07420, T 2764.0, Ltr 0.1349, Lge 0.0462, Ctr 0.9953, Str 0.9375, Cge 0.9948, Sge 0.7700
# 07472, T 2783.6, Ltr 0.1244, Lge 0.0604, Ctr 0.9955, Str 0.9375, Cge 0.9929, Sge 0.7300
# 07528, T 2803.8, Ltr 0.1206, Lge 0.0890, Ctr 1.0000, Str 1.0000, Cge 0.9875, Sge 0.5300
# 07583, T 2824.1, Ltr 0.1248, Lge 0.0860, Ctr 0.9993, Str 0.9688, Cge 0.9910, Sge 0.7300
# 07636, T 2844.1, Ltr 0.1737, Lge 0.1036, Ctr 0.9909, Str 0.9062, Cge 0.9891, Sge 0.6600
# 07689, T 2864.1, Ltr 0.1297, Lge 0.0718, Ctr 0.9974, Str 0.9375, Cge 0.9933, Sge 0.7400
# 07744, T 2884.4, Ltr 0.1139, Lge 0.0905, Ctr 0.9962, Str 0.9375, Cge 0.9888, Sge 0.5700
# 07797, T 2904.6, Ltr 0.1405, Lge 0.0703, Ctr 0.9975, Str 0.9062, Cge 0.9905, Sge 0.6900
# 07852, T 2924.8, Ltr 0.1078, Lge 0.0566, Ctr 0.9986, Str 0.9375, Cge 0.9926, Sge 0.7600
# 07906, T 2944.9, Ltr 0.1498, Lge 0.0772, Ctr 0.9923, Str 0.8750, Cge 0.9902, Sge 0.6400
# 07959, T 2965.1, Ltr 0.1378, Lge 0.0657, Ctr 0.9919, Str 0.9375, Cge 0.9919, Sge 0.6900
# 08012, T 2985.2, Ltr 0.1468, Lge 0.0639, Ctr 0.9872, Str 0.8438, Cge 0.9919, Sge 0.6600
# 08067, T 3005.5, Ltr 0.1473, Lge 0.0555, Ctr 0.9967, Str 0.9688, Cge 0.9930, Sge 0.7400
# 08121, T 3025.6, Ltr 0.0928, Lge 0.0502, Ctr 0.9963, Str 0.9375, Cge 0.9930, Sge 0.6500
# 08174, T 3045.9, Ltr 0.1561, Lge 0.0637, Ctr 0.9906, Str 0.8750, Cge 0.9929, Sge 0.7300
# 08229, T 3066.2, Ltr 0.1539, Lge 0.1363, Ctr 0.9885, Str 0.8438, Cge 0.9887, Sge 0.6600
# 08283, T 3086.3, Ltr 0.1270, Lge 0.0807, Ctr 0.9930, Str 0.9375, Cge 0.9889, Sge 0.5900
# 08337, T 3107.0, Ltr 0.1001, Lge 0.0721, Ctr 0.9948, Str 0.9375, Cge 0.9909, Sge 0.6900
# 08391, T 3126.5, Ltr 0.1344, Lge 0.0732, Ctr 0.9944, Str 0.9375, Cge 0.9917, Sge 0.7000
# 08446, T 3146.8, Ltr 0.1127, Lge 0.0597, Ctr 0.9994, Str 0.9688, Cge 0.9907, Sge 0.6600
# 08502, T 3167.0, Ltr 0.1328, Lge 0.0496, Ctr 0.9959, Str 0.9375, Cge 0.9928, Sge 0.7600
# 08556, T 3187.1, Ltr 0.1424, Lge 0.0844, Ctr 0.9953, Str 0.9062, Cge 0.9901, Sge 0.6900
# 08612, T 3207.4, Ltr 0.1434, Lge 0.0986, Ctr 0.9949, Str 0.9062, Cge 0.9863, Sge 0.5700
# 08666, T 3227.6, Ltr 0.1772, Lge 0.0893, Ctr 0.9865, Str 0.8438, Cge 0.9894, Sge 0.6100
# 08720, T 3247.8, Ltr 0.1491, Lge 0.0896, Ctr 0.9906, Str 0.9375, Cge 0.9893, Sge 0.6800
# 08775, T 3268.2, Ltr 0.1873, Lge 0.0815, Ctr 0.9895, Str 0.8438, Cge 0.9905, Sge 0.6900
# 08830, T 3288.4, Ltr 0.1128, Lge 0.0790, Ctr 0.9988, Str 0.9688, Cge 0.9881, Sge 0.6400
# 08882, T 3308.5, Ltr 0.1164, Lge 0.1097, Ctr 0.9957, Str 0.9688, Cge 0.9896, Sge 0.7200
# 08935, T 3328.4, Ltr 0.1509, Lge 0.0733, Ctr 0.9891, Str 0.8438, Cge 0.9901, Sge 0.6100
# 08990, T 3348.7, Ltr 0.1394, Lge 0.0357, Ctr 0.9969, Str 0.9375, Cge 0.9954, Sge 0.7800
# 09045, T 3368.9, Ltr 0.1263, Lge 0.0905, Ctr 0.9933, Str 0.8750, Cge 0.9882, Sge 0.6400
# 09098, T 3388.9, Ltr 0.1421, Lge 0.0697, Ctr 0.9929, Str 0.9062, Cge 0.9935, Sge 0.7900
# 09152, T 3409.1, Ltr 0.1357, Lge 0.0765, Ctr 0.9904, Str 0.7812, Cge 0.9904, Sge 0.6100
# 09207, T 3429.5, Ltr 0.1691, Lge 0.0696, Ctr 0.9950, Str 0.9688, Cge 0.9917, Sge 0.6700
# 09260, T 3449.4, Ltr 0.1421, Lge 0.0924, Ctr 0.9928, Str 0.9062, Cge 0.9896, Sge 0.6700
# 09315, T 3469.8, Ltr 0.1280, Lge 0.0687, Ctr 0.9941, Str 0.9375, Cge 0.9906, Sge 0.6900
# 09370, T 3490.1, Ltr 0.1428, Lge 0.0758, Ctr 0.9968, Str 0.9062, Cge 0.9920, Sge 0.7100
# 09423, T 3510.1, Ltr 0.1391, Lge 0.0665, Ctr 0.9956, Str 0.9062, Cge 0.9915, Sge 0.7100
# 09479, T 3530.1, Ltr 0.1644, Lge 0.1169, Ctr 0.9915, Str 0.9062, Cge 0.9883, Sge 0.6300
# 09535, T 3550.4, Ltr 0.1296, Lge 0.0786, Ctr 0.9898, Str 0.9062, Cge 0.9901, Sge 0.6600
# 09590, T 3570.6, Ltr 0.1611, Lge 0.0532, Ctr 0.9871, Str 0.8750, Cge 0.9927, Sge 0.7000
# 09644, T 3590.7, Ltr 0.1599, Lge 0.0585, Ctr 0.9943, Str 0.9375, Cge 0.9918, Sge 0.6900
# 09698, T 3610.8, Ltr 0.1409, Lge 0.0598, Ctr 0.9944, Str 0.8750, Cge 0.9926, Sge 0.7200
# 09754, T 3631.0, Ltr 0.1637, Lge 0.0523, Ctr 0.9900, Str 0.9375, Cge 0.9926, Sge 0.7300
# 09808, T 3651.2, Ltr 0.1112, Lge 0.0894, Ctr 0.9931, Str 0.9375, Cge 0.9924, Sge 0.7200
# 09863, T 3671.4, Ltr 0.1136, Lge 0.0938, Ctr 0.9979, Str 0.9688, Cge 0.9877, Sge 0.5300
# 09918, T 3691.5, Ltr 0.1333, Lge 0.0761, Ctr 0.9890, Str 0.8750, Cge 0.9905, Sge 0.6200
# 09972, T 3711.8, Ltr 0.1063, Lge 0.0603, Ctr 0.9975, Str 0.9375, Cge 0.9933, Sge 0.7100

 

Visualize results

# This cell visualizes the results of training. You can visualize the
# intermediate results by interrupting execution of the cell above, and running
# this cell. You can then resume training by simply executing the above cell
# again.

def softmax_prob_last_dim(x):  # pylint: disable=redefined-outer-name
  """Returns the softmax probability of the last class along the last axis."""
  e = np.exp(x)
  return e[:, -1] / np.sum(e, axis=-1)


# Plot results curves.
fig = plt.figure(1, figsize=(18, 3))
fig.clf()
x = np.array(logged_iterations)
# Loss.
y_tr = losses_tr
y_ge = losses_ge
ax = fig.add_subplot(1, 3, 1)
ax.plot(x, y_tr, "k", label="Training")
ax.plot(x, y_ge, "k--", label="Test/generalization")
ax.set_title("Loss across training")
ax.set_xlabel("Training iteration")
ax.set_ylabel("Loss (binary cross-entropy)")
ax.legend()
# Correct.
y_tr = corrects_tr
y_ge = corrects_ge
ax = fig.add_subplot(1, 3, 2)
ax.plot(x, y_tr, "k", label="Training")
ax.plot(x, y_ge, "k--", label="Test/generalization")
ax.set_title("Fraction correct across training")
ax.set_xlabel("Training iteration")
ax.set_ylabel("Fraction nodes/edges correct")
# Solved.
y_tr = solveds_tr
y_ge = solveds_ge
ax = fig.add_subplot(1, 3, 3)
ax.plot(x, y_tr, "k", label="Training")
ax.plot(x, y_ge, "k--", label="Test/generalization")
ax.set_title("Fraction solved across training")
ax.set_xlabel("Training iteration")
ax.set_ylabel("Fraction examples solved")

# Plot graphs and results after each processing step.
# The white node is the start, and the black is the end. Other nodes are colored
# from red to purple to blue, where red means the model is confident the node is
# off the shortest path, blue means the model is confident the node is on the
# shortest path, and purplish colors mean the model isn't sure.
max_graphs_to_plot = 6
num_steps_to_plot = 4
node_size = 120
min_c = 0.3
num_graphs = len(raw_graphs)
targets = utils_np.graphs_tuple_to_data_dicts(test_values["target"])
step_indices = np.floor(
    np.linspace(0, num_processing_steps_ge - 1,
                num_steps_to_plot)).astype(int).tolist()
outputs = list(
    zip(*(utils_np.graphs_tuple_to_data_dicts(test_values["outputs"][i])
          for i in step_indices)))
h = min(num_graphs, max_graphs_to_plot)
w = num_steps_to_plot + 1
fig = plt.figure(101, figsize=(18, h * 3))
fig.clf()
ncs = []
for j, (graph, target, output) in enumerate(zip(raw_graphs, targets, outputs)):
  if j >= h:
    break
  pos = get_node_dict(graph, "pos")
  ground_truth = target["nodes"][:, -1]
  # Ground truth.
  iax = j * (1 + num_steps_to_plot) + 1
  ax = fig.add_subplot(h, w, iax)
  plotter = GraphPlotter(ax, graph, pos)
  color = {}
  for i, n in enumerate(plotter.nodes):
    color[n] = np.array([1.0 - ground_truth[i], 0.0, ground_truth[i], 1.0
                        ]) * (1.0 - min_c) + min_c
  plotter.draw_graph_with_solution(node_size=node_size, node_color=color)
  ax.set_axis_on()
  ax.set_xticks([])
  ax.set_yticks([])
  try:
    ax.set_facecolor([0.9] * 3 + [1.0])
  except AttributeError:
    ax.set_axis_bgcolor([0.9] * 3 + [1.0])
  ax.grid(None)
  ax.set_title("Ground truth\nSolution length: {}".format(
      plotter.solution_length))
  # Prediction.
  for k, outp in enumerate(output):
    iax = j * (1 + num_steps_to_plot) + 2 + k
    ax = fig.add_subplot(h, w, iax)
    plotter = GraphPlotter(ax, graph, pos)
    color = {}
    prob = softmax_prob_last_dim(outp["nodes"])
    for i, n in enumerate(plotter.nodes):
      color[n] = np.array([1.0 - prob[n], 0.0, prob[n], 1.0
                          ]) * (1.0 - min_c) + min_c
    plotter.draw_graph_with_solution(node_size=node_size, node_color=color)
    ax.set_title("Model-predicted\nStep {:02d} / {:02d}".format(
        step_indices[k] + 1, step_indices[-1] + 1))

 

That's all.





