TensorFlow : Graph Nets : グラフの最短経路を見つける (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 06/07/2019
* 本ページは、deepmind/graph_nets の README.md 及び “Find the shortest path in a graph” を翻訳した上で適宜、補足説明したものです:
- deepmind/graph_nets/blob/master/README.md
- deepmind/graph_nets/blob/master/graph_nets/demos/shortest_path.ipynb
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
Graph Nets : README.md
Graph Nets は TensorFlow と Sonnet でグラフネットワークを構築するための DeepMind のライブラリです。
グラフネットワークとは何でしょう?
グラフネットワークはグラフを入力として受け取り、グラフを出力として返します。入力グラフはエッジレベル (E)、ノードレベル (V)、グローバルレベル (u) の属性を持ちます。出力グラフは入力と同じ構造を持ちますが、属性は更新されています。グラフネットワークは「グラフニューラルネットワーク」というより広いファミリーの一部です (Scarselli et al., 2009)。
グラフネットワークについてさらに詳しく知るには、私たちの arXiv 論文: Relational inductive biases, deep learning, and graph networks を参照してください。

使用方法サンプル
次のコードは単純なグラフネット・モジュールを構築してそれをデータに接続します。
import graph_nets as gn
import sonnet as snt

# Provide your own functions to generate graph-structured data.
input_graphs = get_graphs()


def make_mlp():
    """Return a new 2-layer Sonnet MLP with 32 units per layer (one per call)."""
    return snt.nets.MLP([32, 32])


# Create the graph network; edge, node and global models each come from a
# separate call to the factory above.
graph_net_module = gn.modules.GraphNetwork(
    edge_model_fn=make_mlp,
    node_model_fn=make_mlp,
    global_model_fn=make_mlp)

# Pass the input graphs to the graph network, and return the output graphs.
output_graphs = graph_net_module(input_graphs)
Graph Nets : グラフの最短経路を見つける
「最短経路デモ」はランダムグラフを作成し、任意の 2 つのノード間の最短経路上にあるノードとエッジをラベル付けするようにグラフネットワークを訓練します。一連のメッセージパッシング・ステップを通して (各ステップのプロットに示されるように)、モデルは最短経路の予測を次第に改良していきます。
このノートブックと付属のコードは、グラフ内の 2 つのノード間の最短経路を予測するよう学習するために、Graph Nets ライブラリをどのように使用するかを示します。
開始ノードと終了ノードが与えられたとき、ネットワークは最短経路上のノードとエッジをラベル付けするよう訓練されます。
訓練後、ネットワークの出力を真の最短経路と比較することで、その予測能力を示します。次に、類似しているがより大きなグラフの最短経路を予測させることで、ネットワークの汎化能力をテストします。
インポート
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import itertools
import time

from graph_nets import graphs
from graph_nets import utils_np
from graph_nets import utils_tf
from graph_nets.demos import models
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from scipy import spatial
import tensorflow as tf

# Seed NumPy's global RNG and TensorFlow's graph-level seed for reproducibility.
SEED = 1
np.random.seed(SEED)
tf.set_random_seed(SEED)
ヘルパー関数
# Key of the edge attribute that stores the distance between two nodes.
DISTANCE_WEIGHT_NAME = "distance"
def pairwise(iterable):
    """Return consecutive overlapping pairs: s -> (s0, s1), (s1, s2), ...

    The result is a lazy `zip` iterator. Inputs with fewer than two items
    produce no pairs.
    """
    leading, trailing = itertools.tee(iterable)
    # Shift the second copy forward by one element so the two are offset.
    next(trailing, None)
    return zip(leading, trailing)
def set_diff(seq0, seq1):
    """Return the items of `seq0` that do not occur in `seq1`, as a list.

    Duplicates are collapsed, and the output order follows Python set
    iteration order (i.e. it is unspecified).
    """
    kept = set(seq0)
    removed = set(seq1)
    return list(kept - removed)
def to_one_hot(indices, max_value, axis=-1):
    """Encode integer `indices` as one-hot vectors of length `max_value`.

    The one-hot dimension is produced as the last axis. If `axis` is anything
    other than -1 or the output's rank, that dimension is then moved to
    position `axis`.
    """
    encoded = np.eye(max_value)[indices]
    keep_last = axis in (-1, encoded.ndim)
    return encoded if keep_last else np.moveaxis(encoded, -1, axis)
def get_node_dict(graph, attr):
    """Return a `dict` of node:attribute pairs from a graph.

    Reads node data through `graph.nodes(data=True)`, which works on both
    networkx 1.x and 2.x (`graph.node` was removed in networkx 2.4).

    Args:
      graph: A networkx graph.
      attr: Name of the node attribute to collect.

    Returns:
      A `dict` mapping each node to its `attr` value.

    Raises:
      KeyError: if some node has no `attr` attribute.
    """
    return {node: data[attr] for node, data in graph.nodes(data=True)}
def generate_graph(rand,
                   num_nodes_min_max,
                   dimensions=2,
                   theta=1000.0,
                   rate=1.0):
    """Creates a connected graph.

    The graphs are geographic threshold graphs, but with added edges via a
    minimum spanning tree algorithm, to ensure all nodes are connected.

    Args:
      rand: A `np.random.RandomState` used for all sampling (required).
      num_nodes_min_max: A sequence [lower, upper) number of nodes per graph.
      dimensions: (optional) An `int` number of dimensions for the positions.
        Default= 2.
      theta: (optional) A `float` threshold parameters for the geographic
        threshold graph's threshold. Large values (1000+) make mostly trees. Try
        20-60 for good non-trees. Default=1000.0.
      rate: (optional) A rate parameter for the node weight exponential sampling
        distribution. Default= 1.0.

    Returns:
      A 3-tuple `(combined_graph, mst_graph, geo_graph)`:
        combined_graph: the composition of the MST and the geographic
          threshold graph; every edge carries a DISTANCE_WEIGHT_NAME attribute.
        mst_graph: the Euclidean minimum spanning tree over all nodes, with
          geo_graph's node attributes copied in.
        geo_graph: the raw geographic threshold graph.
    """
    # Sample num_nodes (upper bound exclusive, as with RandomState.randint).
    num_nodes = rand.randint(*num_nodes_min_max)

    # Create geographic threshold graph.
    pos_array = rand.uniform(size=(num_nodes, dimensions))
    pos = dict(enumerate(pos_array))
    weight = dict(enumerate(rand.exponential(rate, size=num_nodes)))
    geo_graph = nx.geographical_threshold_graph(
        num_nodes, theta, pos=pos, weight=weight)

    # Create minimum spanning tree across geo_graph's nodes, using the full
    # pairwise Euclidean distance matrix as edge weights.
    distances = spatial.distance.squareform(spatial.distance.pdist(pos_array))
    i_, j_ = np.meshgrid(range(num_nodes), range(num_nodes), indexing="ij")
    weighted_edges = list(zip(i_.ravel(), j_.ravel(), distances.ravel()))
    mst_graph = nx.Graph()
    mst_graph.add_weighted_edges_from(weighted_edges, weight=DISTANCE_WEIGHT_NAME)
    mst_graph = nx.minimum_spanning_tree(mst_graph, weight=DISTANCE_WEIGHT_NAME)

    # Put geo_graph's node attributes into the mst_graph. add_nodes_from with
    # (node, attr_dict) pairs updates existing nodes' attributes and works on
    # networkx 1.x and 2.x alike (`graph.node[...]` was removed in 2.4).
    mst_graph.add_nodes_from(geo_graph.nodes(data=True))

    # Compose the graphs.
    combined_graph = nx.compose_all((mst_graph, geo_graph.copy()))

    # Put all distance weights into edge attributes (MST edges already have
    # them; geographic-threshold edges get them here).
    for i, j in combined_graph.edges():
        combined_graph.get_edge_data(i, j).setdefault(DISTANCE_WEIGHT_NAME,
                                                      distances[i, j])
    return combined_graph, mst_graph, geo_graph
def add_shortest_path(rand, graph, min_length=1):
    """Samples a shortest path from A to B and adds attributes to indicate it.

    The (start, end) pair is sampled so that every hop-count length occurring
    among the eligible pairs is equally likely. The marked path is then the
    shortest path by DISTANCE_WEIGHT_NAME between that pair.

    Args:
      rand: A `np.random.RandomState` used to sample the (start, end) pair.
      graph: A `nx.Graph`.
      min_length: (optional) An `int` minimum number of edges in the shortest
        path. Default= 1.

    Returns:
      The `nx.DiGraph` with the shortest path added. Every node has boolean
      "start", "end" and "solution" attributes. Every edge has a boolean
      "solution" attribute, which is True only for path edges directed from
      start towards end.

    Raises:
      ValueError: All shortest paths are below the minimum length
    """
    # Map from node pairs to the (unweighted, hop-count) length of their
    # shortest path, keeping only pairs at or above min_length.
    pair_to_length_dict = {}
    try:
        # This is for compatibility with older networkx.
        lengths = nx.all_pairs_shortest_path_length(graph).items()
    except AttributeError:
        # This is for compatibility with newer networkx.
        lengths = list(nx.all_pairs_shortest_path_length(graph))
    for x, yy in lengths:
        for y, l in yy.items():
            if l >= min_length:
                pair_to_length_dict[x, y] = l
    # Only qualifying pairs were stored, so an empty dict is the failure case.
    # (Calling max() on it would raise a generic "empty sequence" error.)
    if not pair_to_length_dict:
        raise ValueError("All shortest paths are below the minimum length")
    # The node pairs which exceed the minimum length.
    node_pairs = list(pair_to_length_dict)

    # Computes probabilities per pair, to enforce uniform sampling of each
    # shortest path lengths.
    # The counts of pairs per length.
    counts = collections.Counter(pair_to_length_dict.values())
    prob_per_length = 1.0 / len(counts)
    probabilities = [
        prob_per_length / counts[pair_to_length_dict[x]] for x in node_pairs
    ]

    # Choose the start and end points.
    i = rand.choice(len(node_pairs), p=probabilities)
    start, end = node_pairs[i]
    # NOTE: the pair was chosen by hop count, but the path is the
    # distance-weighted shortest path, so its edge count may differ.
    path = nx.shortest_path(
        graph, source=start, target=end, weight=DISTANCE_WEIGHT_NAME)

    # Creates a directed graph, to store the directed path from start to end.
    digraph = graph.to_directed()

    # Add the "start", "end", and "solution" attributes to the nodes and edges.
    digraph.add_node(start, start=True)
    digraph.add_node(end, end=True)
    digraph.add_nodes_from(set_diff(digraph.nodes(), [start]), start=False)
    digraph.add_nodes_from(set_diff(digraph.nodes(), [end]), end=False)
    digraph.add_nodes_from(set_diff(digraph.nodes(), path), solution=False)
    digraph.add_nodes_from(path, solution=True)
    path_edges = list(pairwise(path))
    digraph.add_edges_from(set_diff(digraph.edges(), path_edges), solution=False)
    digraph.add_edges_from(path_edges, solution=True)

    return digraph
def graph_to_input_target(graph):
    """Returns 2 graphs with input and target feature vectors for training.

    Input node features concatenate "pos", "weight", "start" and "end". Input
    edge features are "distance". Target node and edge features are
    2-element one-hot encodings of the boolean "solution" attribute. The input
    global feature is [0.0]. The target global feature is the fraction of
    nodes that lie on the solution path.

    Args:
      graph: An `nx.DiGraph` instance carrying the attributes above (e.g. the
        output of `add_shortest_path`).

    Returns:
      The input `nx.DiGraph` instance.
      The target `nx.DiGraph` instance.

    Raises:
      KeyError: if a node or edge lacks one of the required attributes.
    """

    def create_feature(attr, fields):
        # Concatenate the listed attributes into one flat float vector.
        return np.hstack([np.array(attr[field], dtype=float) for field in fields])

    input_node_fields = ("pos", "weight", "start", "end")
    input_edge_fields = ("distance",)
    target_node_fields = ("solution",)
    target_edge_fields = ("solution",)

    input_graph = graph.copy()
    target_graph = graph.copy()

    solution_length = 0
    for node_index, node_feature in graph.nodes(data=True):
        input_graph.add_node(
            node_index, features=create_feature(node_feature, input_node_fields))
        target_node = to_one_hot(
            create_feature(node_feature, target_node_fields).astype(int), 2)[0]
        target_graph.add_node(node_index, features=target_node)
        solution_length += int(node_feature["solution"])
    # Despite its name, this is the fraction of nodes on the solution path.
    solution_length /= graph.number_of_nodes()

    # networkx yields directed edges as (u, v) == (sender, receiver). Unpacking
    # them in that order keeps each edge's own features on that same edge. The
    # "solution" label is direction-dependent (only start->end path edges are
    # True), so swapping the two would label the reversed edges instead.
    for sender, receiver, features in graph.edges(data=True):
        input_graph.add_edge(
            sender, receiver, features=create_feature(features, input_edge_fields))
        target_edge = to_one_hot(
            create_feature(features, target_edge_fields).astype(int), 2)[0]
        target_graph.add_edge(sender, receiver, features=target_edge)

    input_graph.graph["features"] = np.array([0.0])
    target_graph.graph["features"] = np.array([solution_length], dtype=float)

    return input_graph, target_graph
def generate_networkx_graphs(rand, num_examples, num_nodes_min_max, theta):
    """Generate `num_examples` random graphs together with their training views.

    Each graph is sampled with `generate_graph`, annotated with a sampled
    shortest path by `add_shortest_path`, and then split by
    `graph_to_input_target` into an input graph and a target graph.

    Args:
      rand: A random seed (np.RandomState instance).
      num_examples: Total number of graphs to generate.
      num_nodes_min_max: A 2-tuple with the [lower, upper) number of nodes per
        graph. The number of nodes for a graph is uniformly sampled within this
        range.
      theta: A `float` threshold parameter for the geographic threshold graph
        (required; passed straight through to `generate_graph`).

    Returns:
      input_graphs: The list of input graphs.
      target_graphs: The list of target graphs.
      graphs: The list of generated, path-annotated graphs.
    """
    graphs = []
    for _ in range(num_examples):
        base_graph = generate_graph(rand, num_nodes_min_max, theta=theta)[0]
        graphs.append(add_shortest_path(rand, base_graph))
    # graph_to_input_target draws no random numbers, so converting after all
    # graphs are sampled consumes `rand` in exactly the same order.
    split = [graph_to_input_target(g) for g in graphs]
    input_graphs = [inp for inp, _ in split]
    target_graphs = [tgt for _, tgt in split]
    return input_graphs, target_graphs, graphs
def create_placeholders(rand, batch_size, num_nodes_min_max, theta):
    """Build input/target placeholders whose specs come from a sample batch.

    A throwaway batch of `batch_size` graphs is generated (advancing `rand`)
    only so that `utils_tf.placeholders_from_networkxs` can infer the feature
    sizes and dtypes.

    Args:
      rand: A random seed (np.RandomState instance).
      batch_size: Total number of graphs per batch.
      num_nodes_min_max: A 2-tuple with the [lower, upper) number of nodes per
        graph. The number of nodes for a graph is uniformly sampled within this
        range.
      theta: A `float` threshold parameter for the geographic threshold graph.

    Returns:
      input_ph: The input graph's placeholders, as a graph namedtuple.
      target_ph: The target graph's placeholders, as a graph namedtuple.
    """
    example_inputs, example_targets, _ = generate_networkx_graphs(
        rand, batch_size, num_nodes_min_max, theta)
    return (utils_tf.placeholders_from_networkxs(example_inputs),
            utils_tf.placeholders_from_networkxs(example_targets))
def create_feed_dict(rand, batch_size, num_nodes_min_max, theta, input_ph,
                     target_ph):
    """Generates a fresh batch of graphs and builds a feed dict for it.

    Args:
      rand: A random seed (np.RandomState instance).
      batch_size: Total number of graphs per batch.
      num_nodes_min_max: A 2-tuple with the [lower, upper) number of nodes per
        graph. The number of nodes for a graph is uniformly sampled within this
        range.
      theta: A `float` threshold parameter for the geographic threshold graph.
      input_ph: The input graph's placeholders, as a graph namedtuple.
      target_ph: The target graph's placeholders, as a graph namedtuple.

    Returns:
      feed_dict: A `dict` mapping `input_ph` and `target_ph` to the batch's
        input and target graphs, converted with
        `utils_np.networkxs_to_graphs_tuple`.
      raw_graphs: The `list` of path-annotated networkx graphs the batch was
        built from.
    """
    inputs, targets, raw_graphs = generate_networkx_graphs(
        rand, batch_size, num_nodes_min_max, theta)
    input_graphs = utils_np.networkxs_to_graphs_tuple(inputs)
    target_graphs = utils_np.networkxs_to_graphs_tuple(targets)
    feed_dict = {input_ph: input_graphs, target_ph: target_graphs}
    return feed_dict, raw_graphs
def compute_accuracy(target, output, use_nodes=True, use_edges=False):
    """Calculate model accuracy.

    Labels are compared by argmax over the last axis. Returns the fraction of
    correctly labeled nodes/edges, and the fraction of graphs that are
    completely solved (every selected node/edge correct).

    Args:
      target: A `graphs.GraphsTuple` that contains the target graph.
      output: A `graphs.GraphsTuple` that contains the output graph.
      use_nodes: A `bool` indicator of whether to compute node accuracy or not.
      use_edges: A `bool` indicator of whether to compute edge accuracy or not.

    Returns:
      correct: A `float` fraction of correctly labeled nodes/edges.
      solved: A `float` fraction of graphs that are completely correctly labeled.

    Raises:
      ValueError: Nodes or edges (or both) must be used
    """
    if not use_nodes and not use_edges:
        raise ValueError("Nodes or edges (or both) must be used")
    tdds = utils_np.graphs_tuple_to_data_dicts(target)
    odds = utils_np.graphs_tuple_to_data_dicts(output)
    cs = []
    ss = []
    for td, od in zip(tdds, odds):
        c = []
        # Only read the fields that were requested, so that graphs without
        # edge features can still be scored on nodes alone (and vice versa).
        if use_nodes:
            xn = np.argmax(td["nodes"], axis=-1)
            yn = np.argmax(od["nodes"], axis=-1)
            c.append(xn == yn)
        if use_edges:
            xe = np.argmax(td["edges"], axis=-1)
            ye = np.argmax(od["edges"], axis=-1)
            c.append(xe == ye)
        c = np.concatenate(c, axis=0)
        cs.append(c)
        # A graph is "solved" only if every selected label is correct.
        ss.append(np.all(c))
    correct = np.mean(np.concatenate(cs, axis=0))
    solved = np.mean(np.stack(ss))
    return correct, solved
def create_loss_ops(target_op, output_ops):
    """Return one loss tensor per processing step.

    Each entry is the softmax cross-entropy of that step's node logits plus
    that of its edge logits, both measured against the labels in `target_op`.
    """

    def _step_loss(output_op):
        node_loss = tf.losses.softmax_cross_entropy(target_op.nodes,
                                                    output_op.nodes)
        edge_loss = tf.losses.softmax_cross_entropy(target_op.edges,
                                                    output_op.edges)
        return node_loss + edge_loss

    return [_step_loss(output_op) for output_op in output_ops]
def make_all_runnable_in_session(*args):
    """Make each TF graphs tuple in `args` fetchable from a session as NP graphs.

    Returns a list holding one converted graph per positional argument, in the
    same order.
    """
    runnable = []
    for graph in args:
        runnable.append(utils_tf.make_runnable_in_session(graph))
    return runnable
class GraphPlotter(object):
    """Draws a networkx graph and optionally highlights a shortest-path solution.

    Node roles are read from the node attributes "start", "end" and
    "solution". Edge roles are read from the edge attribute "solution". A
    missing attribute counts as False. Derived node/edge lists are computed on
    first access and cached.

    Args:
      ax: matplotlib axes to draw on (its axis is switched off).
      graph: The networkx graph to draw.
      pos: A `dict` mapping each node to its drawing position.
    """

    def __init__(self, ax, graph, pos):
        self._ax = ax
        self._graph = graph
        self._pos = pos
        self._base_draw_kwargs = dict(G=self._graph, pos=self._pos, ax=self._ax)
        self._solution_length = None
        self._nodes = None
        self._edges = None
        self._node_attrs = None
        self._start_nodes = None
        self._end_nodes = None
        self._solution_nodes = None
        self._intermediate_solution_nodes = None
        self._solution_edges = None
        self._non_solution_nodes = None
        self._non_solution_edges = None
        self._ax.set_axis_off()

    def _node_attr(self, node, key):
        """Return attribute `key` of `node`, or False when it is absent.

        Uses `graph.nodes(data=True)`, which works on networkx 1.x and 2.x
        (`graph.node` was removed in networkx 2.4).
        """
        if self._node_attrs is None:
            self._node_attrs = dict(self._graph.nodes(data=True))
        return self._node_attrs[node].get(key, False)

    def _edge_attr(self, edge, key):
        """Return attribute `key` of `edge`, or False when it is absent."""
        return self._graph.get_edge_data(edge[0], edge[1]).get(key, False)

    @property
    def solution_length(self):
        """Number of solution edges."""
        if self._solution_length is None:
            # Go through the property so this works even if solution_edges
            # has not been computed yet.
            self._solution_length = len(self.solution_edges)
        return self._solution_length

    @property
    def nodes(self):
        if self._nodes is None:
            self._nodes = self._graph.nodes()
        return self._nodes

    @property
    def edges(self):
        if self._edges is None:
            self._edges = self._graph.edges()
        return self._edges

    @property
    def start_nodes(self):
        if self._start_nodes is None:
            self._start_nodes = [
                n for n in self.nodes if self._node_attr(n, "start")
            ]
        return self._start_nodes

    @property
    def end_nodes(self):
        if self._end_nodes is None:
            self._end_nodes = [
                n for n in self.nodes if self._node_attr(n, "end")
            ]
        return self._end_nodes

    @property
    def solution_nodes(self):
        if self._solution_nodes is None:
            self._solution_nodes = [
                n for n in self.nodes if self._node_attr(n, "solution")
            ]
        return self._solution_nodes

    @property
    def intermediate_solution_nodes(self):
        """Solution nodes that are neither the start nor the end node."""
        if self._intermediate_solution_nodes is None:
            self._intermediate_solution_nodes = [
                n for n in self.nodes
                if self._node_attr(n, "solution") and
                not self._node_attr(n, "start") and
                not self._node_attr(n, "end")
            ]
        return self._intermediate_solution_nodes

    @property
    def solution_edges(self):
        if self._solution_edges is None:
            self._solution_edges = [
                e for e in self.edges if self._edge_attr(e, "solution")
            ]
        return self._solution_edges

    @property
    def non_solution_nodes(self):
        if self._non_solution_nodes is None:
            self._non_solution_nodes = [
                n for n in self.nodes if not self._node_attr(n, "solution")
            ]
        return self._non_solution_nodes

    @property
    def non_solution_edges(self):
        if self._non_solution_edges is None:
            self._non_solution_edges = [
                e for e in self.edges if not self._edge_attr(e, "solution")
            ]
        return self._non_solution_edges

    def _make_draw_kwargs(self, **kwargs):
        # The graph/pos/ax bound at construction override any caller values.
        kwargs.update(self._base_draw_kwargs)
        return kwargs

    def _draw(self, draw_function, zorder=None, **kwargs):
        """Call a networkx draw function and optionally set the result's zorder."""
        draw_kwargs = self._make_draw_kwargs(**kwargs)
        collection = draw_function(**draw_kwargs)
        if collection is not None and zorder is not None:
            if hasattr(collection, "set_zorder"):
                # A single matplotlib collection.
                collection.set_zorder(zorder)
            else:
                # Newer networkx returns a list of arrow patches for directed
                # graphs; apply the zorder to every patch, not only the first.
                for artist in collection:
                    artist.set_zorder(zorder)
        return collection

    def draw_nodes(self, **kwargs):
        """Useful kwargs: nodelist, node_size, node_color, linewidths."""
        try:
            from collections.abc import Sequence
        except ImportError:  # Python 2.
            from collections import Sequence
        node_color = kwargs.get("node_color")
        # A single RGB/RGBA tuple is tiled to one color per node, so matplotlib
        # does not mistake it for a sequence of per-node scalar values.
        if (isinstance(node_color, Sequence) and
                len(node_color) in {3, 4} and
                not isinstance(node_color[0], (Sequence, np.ndarray))):
            num_nodes = len(kwargs.get("nodelist", self.nodes))
            kwargs["node_color"] = np.tile(
                np.array(node_color)[None], [num_nodes, 1])
        return self._draw(nx.draw_networkx_nodes, **kwargs)

    def draw_edges(self, **kwargs):
        """Useful kwargs: edgelist, width."""
        return self._draw(nx.draw_networkx_edges, **kwargs)

    def draw_graph(self,
                   node_size=200,
                   node_color=(0.4, 0.8, 0.4),
                   node_linewidth=1.0,
                   edge_width=1.0):
        """Draw all nodes and edges uniformly, without highlighting a solution."""
        # Plot nodes.
        self.draw_nodes(
            nodelist=self.nodes,
            node_size=node_size,
            node_color=node_color,
            linewidths=node_linewidth,
            zorder=20)
        # Plot edges.
        self.draw_edges(edgelist=self.edges, width=edge_width, zorder=10)

    def draw_graph_with_solution(self,
                                 node_size=200,
                                 node_color=(0.4, 0.8, 0.4),
                                 node_linewidth=1.0,
                                 edge_width=1.0,
                                 start_color="w",
                                 end_color="k",
                                 solution_node_linewidth=3.0,
                                 solution_edge_width=3.0):
        """Draw the graph with start/end nodes and the solution path highlighted.

        `node_color` may be a single color or a `dict` mapping nodes to colors.
        The axes title is set to the solution length.

        Returns:
          A `dict` of the drawn artists, keyed by role.
        """
        node_border_color = (0.0, 0.0, 0.0, 1.0)
        node_collections = {}
        # Plot start nodes.
        node_collections["start nodes"] = self.draw_nodes(
            nodelist=self.start_nodes,
            node_size=node_size,
            node_color=start_color,
            linewidths=solution_node_linewidth,
            edgecolors=node_border_color,
            zorder=100)
        # Plot end nodes.
        node_collections["end nodes"] = self.draw_nodes(
            nodelist=self.end_nodes,
            node_size=node_size,
            node_color=end_color,
            linewidths=solution_node_linewidth,
            edgecolors=node_border_color,
            zorder=90)
        # Plot intermediate solution nodes.
        if isinstance(node_color, dict):
            c = [node_color[n] for n in self.intermediate_solution_nodes]
        else:
            c = node_color
        node_collections["intermediate solution nodes"] = self.draw_nodes(
            nodelist=self.intermediate_solution_nodes,
            node_size=node_size,
            node_color=c,
            linewidths=solution_node_linewidth,
            edgecolors=node_border_color,
            zorder=80)
        # Plot solution edges.
        node_collections["solution edges"] = self.draw_edges(
            edgelist=self.solution_edges, width=solution_edge_width, zorder=70)
        # Plot non-solution nodes.
        if isinstance(node_color, dict):
            c = [node_color[n] for n in self.non_solution_nodes]
        else:
            c = node_color
        node_collections["non-solution nodes"] = self.draw_nodes(
            nodelist=self.non_solution_nodes,
            node_size=node_size,
            node_color=c,
            linewidths=node_linewidth,
            edgecolors=node_border_color,
            zorder=20)
        # Plot non-solution edges.
        node_collections["non-solution edges"] = self.draw_edges(
            edgelist=self.non_solution_edges, width=edge_width, zorder=10)
        # Set title as solution length.
        self._ax.set_title("Solution length: {}".format(self.solution_length))
        return node_collections
サンプルグラフを可視化する
seed = 1  #@param{type: 'integer'}
rand = np.random.RandomState(seed=seed)

num_examples = 15  #@param{type: 'integer'}
# Large values (1000+) make trees. Try 20-60 for good non-trees.
theta = 20  #@param{type: 'integer'}
num_nodes_min_max = (16, 17)

input_graphs, target_graphs, graphs = generate_networkx_graphs(
    rand, num_examples, num_nodes_min_max, theta)

# Plot at most 16 graphs on a grid with `w` columns.
num = min(num_examples, 16)
w = 3
h = int(np.ceil(num / w))
fig = plt.figure(40, figsize=(w * 4, h * 4))
fig.clf()
# Only draw the `num` graphs the h x w grid was sized for; iterating over all
# graphs would overflow the subplot grid when num_examples > 16.
for j, graph in enumerate(graphs[:num]):
    ax = fig.add_subplot(h, w, j + 1)
    pos = get_node_dict(graph, "pos")
    plotter = GraphPlotter(ax, graph, pos)
    plotter.draw_graph_with_solution()

モデル訓練と評価をセットアップする
# The model we explore includes three components:
# - An "Encoder" graph net, which independently encodes the edge, node, and
#   global attributes (does not compute relations etc.).
# - A "Core" graph net, which performs N rounds of processing (message-passing)
#   steps. The input to the Core is the concatenation of the Encoder's output
#   and the previous output of the Core (labeled "Hidden(t)" below, where "t" is
#   the processing step).
# - A "Decoder" graph net, which independently decodes the edge, node, and
#   global attributes (does not compute relations etc.), on each
#   message-passing step.
#
#                     Hidden(t)   Hidden(t+1)
#                        |            ^
#           *---------*  |  *------*  |  *---------*
#           |         |  |  |      |  |  |         |
# Input --->| Encoder |  *->| Core |--*->| Decoder |---> Output(t)
#           |         |---->|      |     |         |
#           *---------*     *------*     *---------*
#
# The model is trained by supervised learning. Input graphs are procedurally
# generated, and output graphs have the same structure with the nodes and edges
# of the shortest path labeled (using 2-element 1-hot vectors). We could have
# predicted the shortest path only by labeling either the nodes or edges, and
# that does work, but we decided to predict both to demonstrate the flexibility
# of graph nets' outputs.
#
# The training loss is computed on the output of each processing step. The
# reason for this is to encourage the model to try to solve the problem in as
# few steps as possible. It also helps make the output of intermediate steps
# more interpretable.
#
# There's no need for a separate evaluate dataset because the inputs are
# never repeated, so the training loss is the measure of performance on graphs
# from the input distribution.
#
# We also evaluate how well the models generalize to graphs which are up to
# twice as large as those on which it was trained. The loss is computed only
# on the final processing step.
#
# Variables with the suffix _tr are training parameters, and variables with the
# suffix _ge are test/generalization parameters.
#
# After around 2000-5000 training iterations the model reaches near-perfect
# performance on graphs with between 8-16 nodes.
#
# NOTE: statement order matters here. Placeholder creation draws from `rand`
# and variables are created in this order, so keep the sequence as-is.

tf.reset_default_graph()

# This RandomState is shared by placeholder creation here and by the training
# and test batches drawn in the training cell.
seed = 2
rand = np.random.RandomState(seed=seed)

# Model parameters.
# Number of processing (message-passing) steps.
num_processing_steps_tr = 10
num_processing_steps_ge = 10

# Data / training parameters.
num_training_iterations = 10000
theta = 20 # Large values (1000+) make trees. Try 20-60 for good non-trees.
batch_size_tr = 32
batch_size_ge = 100
# Number of nodes per graph sampled uniformly from this range.
# The upper bound is exclusive (RandomState.randint): 8-16 nodes for
# training, 16-32 for generalization.
num_nodes_min_max_tr = (8, 17)
num_nodes_min_max_ge = (16, 33)

# Data.
# Input and target placeholders.
# Their specs are inferred from a sample training-size batch. The same
# placeholders are later fed the larger generalization graphs, so presumably
# they leave the node/edge counts unconstrained -- confirm in utils_tf.
input_ph, target_ph = create_placeholders(rand, batch_size_tr,
                                          num_nodes_min_max_tr, theta)

# Connect the data to the model.
# Instantiate the model.
# 2 output units per node/edge: logits for the 2-element one-hot labels.
model = models.EncodeProcessDecode(edge_output_size=2, node_output_size=2)
# A list of outputs, one per processing step.
# The same `model` object is connected twice; presumably the two calls share
# weights (Sonnet module reuse) -- confirm in models.EncodeProcessDecode.
output_ops_tr = model(input_ph, num_processing_steps_tr)
output_ops_ge = model(input_ph, num_processing_steps_ge)

# Training loss.
loss_ops_tr = create_loss_ops(target_ph, output_ops_tr)
# Loss across processing steps.
# Mean over steps of the per-step (node + edge) cross-entropy.
loss_op_tr = sum(loss_ops_tr) / num_processing_steps_tr
# Test/generalization loss.
loss_ops_ge = create_loss_ops(target_ph, output_ops_ge)
loss_op_ge = loss_ops_ge[-1] # Loss from final processing step.

# Optimizer.
learning_rate = 1e-3
optimizer = tf.train.AdamOptimizer(learning_rate)
step_op = optimizer.minimize(loss_op_tr)

# Lets an iterable of TF graphs be output from a session as NP graphs.
# input_ph / target_ph are rebound here; the training cell uses these rebound
# names both as feed_dict keys and as fetches.
input_ph, target_ph = make_all_runnable_in_session(input_ph, target_ph)
セッションをリセットする
# This cell resets the Tensorflow session, but keeps the same computational
# graph.

try:
    sess.close()
except NameError:
    # No session exists yet on the first run of this cell.
    pass

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Training bookkeeping, appended to by the training cell at each log line.
last_iteration = 0
logged_iterations = []
losses_tr = []
corrects_tr = []
solveds_tr = []
losses_ge = []
corrects_ge = []
solveds_ge = []
訓練を実行する
# You can interrupt this cell's training loop at any time, and visualize the
# intermediate results by running the next cell (below). You can then resume
# training by simply executing this cell again.

# Minimum number of seconds between two evaluation/log lines.
log_every_seconds = 20

print("# (iteration number), T (elapsed seconds), "
      "Ltr (training loss), Lge (test/generalization loss), "
      "Ctr (training fraction nodes/edges labeled correctly), "
      "Str (training fraction examples solved correctly), "
      "Cge (test/generalization fraction nodes/edges labeled correctly), "
      "Sge (test/generalization fraction examples solved correctly)")

start_time = time.time()
last_log_time = start_time
for iteration in range(last_iteration, num_training_iterations):
    # Record progress first, so that re-running the cell resumes from here.
    last_iteration = iteration

    # One optimization step on a freshly generated training batch.
    feed_dict, _ = create_feed_dict(rand, batch_size_tr, num_nodes_min_max_tr,
                                    theta, input_ph, target_ph)
    train_fetches = {
        "step": step_op,
        "target": target_ph,
        "loss": loss_op_tr,
        "outputs": output_ops_tr,
    }
    train_values = sess.run(train_fetches, feed_dict=feed_dict)

    the_time = time.time()
    elapsed_since_last_log = the_time - last_log_time
    if elapsed_since_last_log <= log_every_seconds:
        continue

    # Time to log: evaluate on a freshly generated generalization batch.
    last_log_time = the_time
    feed_dict, raw_graphs = create_feed_dict(
        rand, batch_size_ge, num_nodes_min_max_ge, theta, input_ph, target_ph)
    test_fetches = {
        "target": target_ph,
        "loss": loss_op_ge,
        "outputs": output_ops_ge,
    }
    test_values = sess.run(test_fetches, feed_dict=feed_dict)

    # Accuracy uses the final processing step, scoring both nodes and edges.
    correct_tr, solved_tr = compute_accuracy(
        train_values["target"], train_values["outputs"][-1], use_edges=True)
    correct_ge, solved_ge = compute_accuracy(
        test_values["target"], test_values["outputs"][-1], use_edges=True)
    elapsed = time.time() - start_time

    losses_tr.append(train_values["loss"])
    corrects_tr.append(correct_tr)
    solveds_tr.append(solved_tr)
    losses_ge.append(test_values["loss"])
    corrects_ge.append(correct_ge)
    solveds_ge.append(solved_ge)
    logged_iterations.append(iteration)

    print("# {:05d}, T {:.1f}, Ltr {:.4f}, Lge {:.4f}, Ctr {:.4f}, Str"
          " {:.4f}, Cge {:.4f}, Sge {:.4f}".format(
              iteration, elapsed, train_values["loss"], test_values["loss"],
              correct_tr, solved_tr, correct_ge, solved_ge))
# (iteration number), T (elapsed seconds), Ltr (training loss), Lge (test/generalization loss), Ctr (training fraction nodes/edges labeled correctly), Str (training fraction examples solved correctly), Cge (test/generalization fraction nodes/edges labeled correctly), Sge (test/generalization fraction examples solved correctly) # 00029, T 23.8, Ltr 0.8731, Lge 0.6658, Ctr 0.8596, Str 0.0000, Cge 0.9481, Sge 0.0000 # 00078, T 42.1, Ltr 0.6341, Lge 0.4758, Ctr 0.9056, Str 0.0000, Cge 0.9549, Sge 0.0000 # 00133, T 62.5, Ltr 0.5034, Lge 0.3845, Ctr 0.9172, Str 0.0000, Cge 0.9625, Sge 0.0200 # 00189, T 82.7, Ltr 0.5162, Lge 0.3417, Ctr 0.9166, Str 0.1250, Cge 0.9664, Sge 0.0200 # 00244, T 103.0, Ltr 0.4486, Lge 0.3383, Ctr 0.9343, Str 0.1250, Cge 0.9685, Sge 0.1600 # 00299, T 123.1, Ltr 0.4963, Lge 0.3507, Ctr 0.9184, Str 0.2500, Cge 0.9637, Sge 0.1100 # 00354, T 143.4, Ltr 0.3223, Lge 0.2883, Ctr 0.9614, Str 0.4062, Cge 0.9721, Sge 0.3300 # 00407, T 163.5, Ltr 0.4604, Lge 0.3853, Ctr 0.9270, Str 0.2500, Cge 0.9585, Sge 0.0600 # 00461, T 183.7, Ltr 0.2822, Lge 0.2933, Ctr 0.9670, Str 0.5625, Cge 0.9702, Sge 0.3200 # 00517, T 203.8, Ltr 0.3703, Lge 0.2784, Ctr 0.9480, Str 0.4375, Cge 0.9698, Sge 0.2600 # 00571, T 224.1, Ltr 0.4301, Lge 0.2783, Ctr 0.9308, Str 0.3125, Cge 0.9723, Sge 0.2000 # 00626, T 244.1, Ltr 0.3287, Lge 0.2833, Ctr 0.9533, Str 0.4062, Cge 0.9687, Sge 0.2700 # 00682, T 264.4, Ltr 0.2802, Lge 0.2913, Ctr 0.9617, Str 0.5000, Cge 0.9703, Sge 0.3000 # 00736, T 284.7, Ltr 0.3474, Lge 0.2775, Ctr 0.9531, Str 0.5625, Cge 0.9704, Sge 0.1900 # 00790, T 305.1, Ltr 0.3098, Lge 0.3607, Ctr 0.9488, Str 0.4062, Cge 0.9690, Sge 0.1700 # 00844, T 324.9, Ltr 0.3092, Lge 0.2941, Ctr 0.9566, Str 0.4375, Cge 0.9702, Sge 0.2500 # 00899, T 345.1, Ltr 0.3805, Lge 0.2202, Ctr 0.9440, Str 0.2812, Cge 0.9770, Sge 0.3200 # 00953, T 365.0, Ltr 0.2927, Lge 0.2637, Ctr 0.9609, Str 0.5938, Cge 0.9707, Sge 0.1900 # 01008, T 385.0, Ltr 0.3164, Lge 0.2093, Ctr 0.9568, Str 0.4688, Cge 
0.9749, Sge 0.3100 # 01063, T 405.1, Ltr 0.2704, Lge 0.2455, Ctr 0.9749, Str 0.5000, Cge 0.9719, Sge 0.2000 # 01117, T 425.2, Ltr 0.2696, Lge 0.2713, Ctr 0.9600, Str 0.6250, Cge 0.9719, Sge 0.2300 # 01172, T 445.3, Ltr 0.4089, Lge 0.2442, Ctr 0.9489, Str 0.5312, Cge 0.9727, Sge 0.2100 # 01227, T 465.3, Ltr 0.3053, Lge 0.2616, Ctr 0.9620, Str 0.5312, Cge 0.9733, Sge 0.2400 # 01280, T 485.2, Ltr 0.2292, Lge 0.2433, Ctr 0.9742, Str 0.6250, Cge 0.9703, Sge 0.3100 # 01335, T 505.6, Ltr 0.3238, Lge 0.2267, Ctr 0.9554, Str 0.5312, Cge 0.9748, Sge 0.3000 # 01390, T 525.7, Ltr 0.3662, Lge 0.2706, Ctr 0.9602, Str 0.5938, Cge 0.9720, Sge 0.2500 # 01445, T 545.8, Ltr 0.2444, Lge 0.2530, Ctr 0.9755, Str 0.6562, Cge 0.9732, Sge 0.2800 # 01498, T 565.8, Ltr 0.3119, Lge 0.3036, Ctr 0.9565, Str 0.5938, Cge 0.9708, Sge 0.2300 # 01552, T 585.9, Ltr 0.3058, Lge 0.2633, Ctr 0.9553, Str 0.4688, Cge 0.9717, Sge 0.2400 # 01606, T 606.1, Ltr 0.2392, Lge 0.2462, Ctr 0.9782, Str 0.6562, Cge 0.9726, Sge 0.3000 # 01661, T 626.4, Ltr 0.2917, Lge 0.2522, Ctr 0.9611, Str 0.5625, Cge 0.9725, Sge 0.3000 # 01716, T 646.7, Ltr 0.3049, Lge 0.2254, Ctr 0.9566, Str 0.5938, Cge 0.9749, Sge 0.3300 # 01771, T 666.8, Ltr 0.2509, Lge 0.2393, Ctr 0.9667, Str 0.6250, Cge 0.9747, Sge 0.3000 # 01824, T 687.1, Ltr 0.1827, Lge 0.1870, Ctr 0.9879, Str 0.7812, Cge 0.9781, Sge 0.3300 # 01879, T 707.2, Ltr 0.3511, Lge 0.2048, Ctr 0.9574, Str 0.6562, Cge 0.9775, Sge 0.4200 # 01935, T 727.6, Ltr 0.2784, Lge 0.2044, Ctr 0.9699, Str 0.5625, Cge 0.9752, Sge 0.2000 # 01990, T 747.7, Ltr 0.3216, Lge 0.1943, Ctr 0.9639, Str 0.6562, Cge 0.9768, Sge 0.2800 # 02044, T 768.0, Ltr 0.1950, Lge 0.1579, Ctr 0.9892, Str 0.8750, Cge 0.9820, Sge 0.4100 # 02098, T 788.0, Ltr 0.2075, Lge 0.1729, Ctr 0.9882, Str 0.7500, Cge 0.9799, Sge 0.3700 # 02153, T 808.0, Ltr 0.2118, Lge 0.1775, Ctr 0.9783, Str 0.7500, Cge 0.9804, Sge 0.3700 # 02207, T 828.4, Ltr 0.2426, Lge 0.1862, Ctr 0.9669, Str 0.6562, Cge 0.9797, Sge 0.3800 # 02262, T 848.7, Ltr 
0.2076, Lge 0.1836, Ctr 0.9862, Str 0.8750, Cge 0.9792, Sge 0.4300 # 02317, T 869.0, Ltr 0.1890, Lge 0.1984, Ctr 0.9873, Str 0.8750, Cge 0.9767, Sge 0.3800 # 02371, T 889.2, Ltr 0.1652, Lge 0.1936, Ctr 0.9887, Str 0.8125, Cge 0.9784, Sge 0.3900 # 02426, T 909.4, Ltr 0.2751, Lge 0.1523, Ctr 0.9707, Str 0.6875, Cge 0.9823, Sge 0.4100 # 02481, T 929.4, Ltr 0.1775, Lge 0.1617, Ctr 0.9867, Str 0.8750, Cge 0.9788, Sge 0.3800 # 02536, T 949.6, Ltr 0.2007, Lge 0.1207, Ctr 0.9880, Str 0.8438, Cge 0.9857, Sge 0.4900 # 02591, T 969.8, Ltr 0.1514, Lge 0.1489, Ctr 0.9934, Str 0.9062, Cge 0.9813, Sge 0.3600 # 02646, T 989.9, Ltr 0.2410, Lge 0.1100, Ctr 0.9862, Str 0.8125, Cge 0.9854, Sge 0.4000 # 02702, T 1010.0, Ltr 0.1991, Lge 0.1578, Ctr 0.9827, Str 0.8125, Cge 0.9813, Sge 0.3800 # 02756, T 1030.1, Ltr 0.1464, Lge 0.1388, Ctr 0.9893, Str 0.8750, Cge 0.9814, Sge 0.4000 # 02811, T 1050.1, Ltr 0.1931, Lge 0.1588, Ctr 0.9833, Str 0.8750, Cge 0.9799, Sge 0.3900 # 02866, T 1070.1, Ltr 0.1570, Lge 0.1189, Ctr 0.9858, Str 0.8750, Cge 0.9858, Sge 0.5500 # 02920, T 1090.5, Ltr 0.1420, Lge 0.1113, Ctr 0.9922, Str 0.8750, Cge 0.9855, Sge 0.4500 # 02974, T 1110.8, Ltr 0.1550, Lge 0.1640, Ctr 0.9911, Str 0.8438, Cge 0.9809, Sge 0.3200 # 03029, T 1130.8, Ltr 0.1681, Lge 0.1297, Ctr 0.9936, Str 0.8750, Cge 0.9873, Sge 0.4900 # 03084, T 1151.5, Ltr 0.1810, Lge 0.1909, Ctr 0.9921, Str 0.8750, Cge 0.9785, Sge 0.2300 # 03139, T 1171.1, Ltr 0.2063, Lge 0.1209, Ctr 0.9818, Str 0.8125, Cge 0.9861, Sge 0.4400 # 03195, T 1191.2, Ltr 0.1340, Lge 0.1583, Ctr 0.9945, Str 0.8750, Cge 0.9789, Sge 0.2100 # 03251, T 1211.7, Ltr 0.1461, Lge 0.1520, Ctr 0.9943, Str 0.9062, Cge 0.9856, Sge 0.4800 # 03303, T 1231.8, Ltr 0.1694, Lge 0.1235, Ctr 0.9854, Str 0.7812, Cge 0.9852, Sge 0.4900 # 03359, T 1252.0, Ltr 0.1738, Lge 0.1222, Ctr 0.9852, Str 0.7812, Cge 0.9846, Sge 0.4400 # 03414, T 1272.2, Ltr 0.1498, Lge 0.1101, Ctr 0.9897, Str 0.8438, Cge 0.9867, Sge 0.4900 # 03468, T 1292.5, Ltr 0.1638, Lge 0.1573, Ctr 
0.9894, Str 0.8438, Cge 0.9836, Sge 0.4500 # 03523, T 1312.6, Ltr 0.2194, Lge 0.1516, Ctr 0.9761, Str 0.7188, Cge 0.9846, Sge 0.4600 # 03578, T 1332.6, Ltr 0.1490, Lge 0.1425, Ctr 0.9874, Str 0.9375, Cge 0.9846, Sge 0.5300 # 03633, T 1353.0, Ltr 0.1951, Lge 0.0889, Ctr 0.9860, Str 0.8750, Cge 0.9883, Sge 0.5600 # 03686, T 1373.0, Ltr 0.1586, Lge 0.1016, Ctr 0.9900, Str 0.9062, Cge 0.9875, Sge 0.4900 # 03741, T 1393.4, Ltr 0.1404, Lge 0.1356, Ctr 0.9911, Str 0.8750, Cge 0.9855, Sge 0.5800 # 03794, T 1413.3, Ltr 0.1938, Lge 0.1298, Ctr 0.9852, Str 0.7812, Cge 0.9828, Sge 0.4300 # 03847, T 1433.7, Ltr 0.1412, Lge 0.1183, Ctr 0.9899, Str 0.8750, Cge 0.9870, Sge 0.5900 # 03901, T 1453.8, Ltr 0.1894, Lge 0.0941, Ctr 0.9842, Str 0.8438, Cge 0.9890, Sge 0.6000 # 03955, T 1473.8, Ltr 0.1605, Lge 0.0935, Ctr 0.9867, Str 0.8125, Cge 0.9876, Sge 0.5800 # 04008, T 1493.9, Ltr 0.1560, Lge 0.0700, Ctr 0.9886, Str 0.8438, Cge 0.9927, Sge 0.6900 # 04062, T 1514.0, Ltr 0.2174, Lge 0.1912, Ctr 0.9865, Str 0.7812, Cge 0.9780, Sge 0.3800 # 04115, T 1534.0, Ltr 0.1537, Lge 0.0797, Ctr 0.9852, Str 0.8438, Cge 0.9891, Sge 0.5900 # 04169, T 1554.1, Ltr 0.1586, Lge 0.1071, Ctr 0.9871, Str 0.8438, Cge 0.9864, Sge 0.6100 # 04223, T 1574.5, Ltr 0.1071, Lge 0.1316, Ctr 1.0000, Str 1.0000, Cge 0.9869, Sge 0.6200 # 04278, T 1594.7, Ltr 0.1270, Lge 0.1329, Ctr 0.9985, Str 0.9375, Cge 0.9850, Sge 0.5400 # 04333, T 1614.8, Ltr 0.1352, Lge 0.1023, Ctr 0.9929, Str 0.9375, Cge 0.9864, Sge 0.5300 # 04386, T 1635.2, Ltr 0.1423, Lge 0.0890, Ctr 0.9894, Str 0.8125, Cge 0.9875, Sge 0.4600 # 04440, T 1655.1, Ltr 0.1320, Lge 0.0963, Ctr 0.9994, Str 0.9688, Cge 0.9882, Sge 0.5900 # 04495, T 1675.5, Ltr 0.1603, Lge 0.1094, Ctr 0.9889, Str 0.8750, Cge 0.9876, Sge 0.4800 # 04548, T 1695.7, Ltr 0.1474, Lge 0.1107, Ctr 0.9949, Str 0.9375, Cge 0.9868, Sge 0.5600 # 04602, T 1715.5, Ltr 0.1608, Lge 0.1791, Ctr 0.9960, Str 0.8438, Cge 0.9811, Sge 0.3700 # 04656, T 1735.6, Ltr 0.1416, Lge 0.1130, Ctr 0.9899, Str 0.8438, 
Cge 0.9865, Sge 0.5600 # 04710, T 1755.5, Ltr 0.1868, Lge 0.1135, Ctr 0.9944, Str 0.9375, Cge 0.9862, Sge 0.5500 # 04764, T 1775.7, Ltr 0.1466, Lge 0.0730, Ctr 0.9916, Str 0.9375, Cge 0.9901, Sge 0.6600 # 04819, T 1795.9, Ltr 0.1147, Lge 0.0881, Ctr 0.9966, Str 0.9688, Cge 0.9906, Sge 0.6900 # 04874, T 1816.0, Ltr 0.1130, Lge 0.1065, Ctr 0.9987, Str 0.9688, Cge 0.9868, Sge 0.5900 # 04928, T 1836.3, Ltr 0.1979, Lge 0.0953, Ctr 0.9909, Str 0.8750, Cge 0.9885, Sge 0.5200 # 04982, T 1856.2, Ltr 0.1319, Lge 0.1024, Ctr 0.9929, Str 0.9062, Cge 0.9875, Sge 0.6000 # 05036, T 1876.4, Ltr 0.1575, Lge 0.0744, Ctr 0.9910, Str 0.9062, Cge 0.9914, Sge 0.6300 # 05090, T 1896.6, Ltr 0.1387, Lge 0.1054, Ctr 0.9938, Str 0.9062, Cge 0.9877, Sge 0.5800 # 05144, T 1916.8, Ltr 0.1196, Lge 0.1088, Ctr 0.9929, Str 0.8750, Cge 0.9857, Sge 0.4800 # 05198, T 1936.6, Ltr 0.1441, Lge 0.0758, Ctr 0.9912, Str 0.9375, Cge 0.9902, Sge 0.6900 # 05252, T 1957.3, Ltr 0.1296, Lge 0.1036, Ctr 0.9957, Str 0.8750, Cge 0.9909, Sge 0.7000 # 05304, T 1976.9, Ltr 0.1311, Lge 0.1073, Ctr 0.9956, Str 0.9062, Cge 0.9883, Sge 0.6700 # 05359, T 1997.1, Ltr 0.0968, Lge 0.1633, Ctr 1.0000, Str 1.0000, Cge 0.9860, Sge 0.5500 # 05413, T 2017.0, Ltr 0.1550, Lge 0.0875, Ctr 0.9903, Str 0.9062, Cge 0.9896, Sge 0.6600 # 05466, T 2037.4, Ltr 0.1204, Lge 0.2264, Ctr 0.9966, Str 0.9062, Cge 0.9834, Sge 0.6400 # 05519, T 2057.5, Ltr 0.1255, Lge 0.1421, Ctr 0.9953, Str 0.9375, Cge 0.9868, Sge 0.6100 # 05573, T 2077.5, Ltr 0.1255, Lge 0.0956, Ctr 0.9941, Str 0.9062, Cge 0.9900, Sge 0.6800 # 05628, T 2098.2, Ltr 0.1427, Lge 0.0803, Ctr 0.9953, Str 0.9062, Cge 0.9893, Sge 0.6500 # 05683, T 2118.4, Ltr 0.1344, Lge 0.0857, Ctr 0.9934, Str 0.9062, Cge 0.9909, Sge 0.6000 # 05739, T 2138.5, Ltr 0.1634, Lge 0.1224, Ctr 0.9931, Str 0.9375, Cge 0.9875, Sge 0.5600 # 05793, T 2158.9, Ltr 0.1784, Lge 0.0720, Ctr 0.9853, Str 0.8750, Cge 0.9905, Sge 0.5800 # 05847, T 2179.2, Ltr 0.1259, Lge 0.0641, Ctr 0.9975, Str 0.9688, Cge 0.9921, Sge 
0.6800 # 05902, T 2199.3, Ltr 0.0840, Lge 0.1022, Ctr 0.9974, Str 0.9688, Cge 0.9880, Sge 0.5600 # 05957, T 2219.4, Ltr 0.1161, Lge 0.0861, Ctr 0.9978, Str 0.9688, Cge 0.9906, Sge 0.6900 # 06010, T 2239.6, Ltr 0.1470, Lge 0.0660, Ctr 0.9894, Str 0.8125, Cge 0.9906, Sge 0.6700 # 06065, T 2259.7, Ltr 0.1664, Lge 0.1401, Ctr 0.9937, Str 0.9375, Cge 0.9831, Sge 0.5000 # 06120, T 2279.9, Ltr 0.1733, Lge 0.0901, Ctr 0.9829, Str 0.8438, Cge 0.9880, Sge 0.5600 # 06173, T 2300.1, Ltr 0.1269, Lge 0.0721, Ctr 0.9936, Str 0.9375, Cge 0.9926, Sge 0.7600 # 06228, T 2320.4, Ltr 0.1716, Lge 0.0702, Ctr 0.9926, Str 0.9375, Cge 0.9908, Sge 0.6500 # 06282, T 2340.4, Ltr 0.1386, Lge 0.0545, Ctr 0.9975, Str 0.9688, Cge 0.9926, Sge 0.7100 # 06337, T 2360.9, Ltr 0.1400, Lge 0.0512, Ctr 0.9960, Str 0.9062, Cge 0.9926, Sge 0.6100 # 06391, T 2380.5, Ltr 0.1468, Lge 0.0791, Ctr 0.9965, Str 0.8750, Cge 0.9894, Sge 0.6300 # 06446, T 2400.6, Ltr 0.1655, Lge 0.0847, Ctr 0.9900, Str 0.8750, Cge 0.9897, Sge 0.5800 # 06500, T 2420.6, Ltr 0.1530, Lge 0.0538, Ctr 0.9878, Str 0.7812, Cge 0.9925, Sge 0.6800 # 06553, T 2440.5, Ltr 0.1442, Lge 0.0634, Ctr 0.9969, Str 0.9375, Cge 0.9919, Sge 0.7500 # 06609, T 2460.8, Ltr 0.0933, Lge 0.0678, Ctr 0.9987, Str 0.9688, Cge 0.9912, Sge 0.6900 # 06663, T 2481.0, Ltr 0.1460, Lge 0.0936, Ctr 0.9953, Str 0.9375, Cge 0.9879, Sge 0.5400 # 06716, T 2501.4, Ltr 0.1505, Lge 0.0685, Ctr 0.9941, Str 0.9375, Cge 0.9914, Sge 0.7200 # 06769, T 2521.4, Ltr 0.1400, Lge 0.0530, Ctr 0.9955, Str 0.9062, Cge 0.9931, Sge 0.7400 # 06823, T 2541.4, Ltr 0.1310, Lge 0.0799, Ctr 0.9936, Str 0.8750, Cge 0.9896, Sge 0.6900 # 06878, T 2561.7, Ltr 0.1640, Lge 0.0848, Ctr 0.9885, Str 0.7812, Cge 0.9884, Sge 0.5900 # 06931, T 2581.7, Ltr 0.1395, Lge 0.0783, Ctr 0.9954, Str 0.9375, Cge 0.9904, Sge 0.6400 # 06986, T 2601.9, Ltr 0.1150, Lge 0.0546, Ctr 0.9969, Str 0.9375, Cge 0.9923, Sge 0.7400 # 07041, T 2622.1, Ltr 0.0829, Lge 0.0574, Ctr 1.0000, Str 1.0000, Cge 0.9923, Sge 0.6700 # 07095, T 
2642.4, Ltr 0.1719, Lge 0.1901, Ctr 0.9907, Str 0.9375, Cge 0.9819, Sge 0.2800 # 07149, T 2662.5, Ltr 0.2284, Lge 0.0478, Ctr 0.9789, Str 0.8438, Cge 0.9934, Sge 0.7100 # 07204, T 2682.8, Ltr 0.1277, Lge 0.0614, Ctr 0.9923, Str 0.8750, Cge 0.9914, Sge 0.6000 # 07256, T 2703.0, Ltr 0.2056, Lge 0.0849, Ctr 0.9938, Str 0.9062, Cge 0.9910, Sge 0.6400 # 07311, T 2723.1, Ltr 0.1456, Lge 0.0573, Ctr 0.9967, Str 0.9688, Cge 0.9924, Sge 0.6700 # 07366, T 2743.3, Ltr 0.1366, Lge 0.0878, Ctr 0.9993, Str 0.9688, Cge 0.9898, Sge 0.7200 # 07420, T 2764.0, Ltr 0.1349, Lge 0.0462, Ctr 0.9953, Str 0.9375, Cge 0.9948, Sge 0.7700 # 07472, T 2783.6, Ltr 0.1244, Lge 0.0604, Ctr 0.9955, Str 0.9375, Cge 0.9929, Sge 0.7300 # 07528, T 2803.8, Ltr 0.1206, Lge 0.0890, Ctr 1.0000, Str 1.0000, Cge 0.9875, Sge 0.5300 # 07583, T 2824.1, Ltr 0.1248, Lge 0.0860, Ctr 0.9993, Str 0.9688, Cge 0.9910, Sge 0.7300 # 07636, T 2844.1, Ltr 0.1737, Lge 0.1036, Ctr 0.9909, Str 0.9062, Cge 0.9891, Sge 0.6600 # 07689, T 2864.1, Ltr 0.1297, Lge 0.0718, Ctr 0.9974, Str 0.9375, Cge 0.9933, Sge 0.7400 # 07744, T 2884.4, Ltr 0.1139, Lge 0.0905, Ctr 0.9962, Str 0.9375, Cge 0.9888, Sge 0.5700 # 07797, T 2904.6, Ltr 0.1405, Lge 0.0703, Ctr 0.9975, Str 0.9062, Cge 0.9905, Sge 0.6900 # 07852, T 2924.8, Ltr 0.1078, Lge 0.0566, Ctr 0.9986, Str 0.9375, Cge 0.9926, Sge 0.7600 # 07906, T 2944.9, Ltr 0.1498, Lge 0.0772, Ctr 0.9923, Str 0.8750, Cge 0.9902, Sge 0.6400 # 07959, T 2965.1, Ltr 0.1378, Lge 0.0657, Ctr 0.9919, Str 0.9375, Cge 0.9919, Sge 0.6900 # 08012, T 2985.2, Ltr 0.1468, Lge 0.0639, Ctr 0.9872, Str 0.8438, Cge 0.9919, Sge 0.6600 # 08067, T 3005.5, Ltr 0.1473, Lge 0.0555, Ctr 0.9967, Str 0.9688, Cge 0.9930, Sge 0.7400 # 08121, T 3025.6, Ltr 0.0928, Lge 0.0502, Ctr 0.9963, Str 0.9375, Cge 0.9930, Sge 0.6500 # 08174, T 3045.9, Ltr 0.1561, Lge 0.0637, Ctr 0.9906, Str 0.8750, Cge 0.9929, Sge 0.7300 # 08229, T 3066.2, Ltr 0.1539, Lge 0.1363, Ctr 0.9885, Str 0.8438, Cge 0.9887, Sge 0.6600 # 08283, T 3086.3, Ltr 0.1270, 
Lge 0.0807, Ctr 0.9930, Str 0.9375, Cge 0.9889, Sge 0.5900 # 08337, T 3107.0, Ltr 0.1001, Lge 0.0721, Ctr 0.9948, Str 0.9375, Cge 0.9909, Sge 0.6900 # 08391, T 3126.5, Ltr 0.1344, Lge 0.0732, Ctr 0.9944, Str 0.9375, Cge 0.9917, Sge 0.7000 # 08446, T 3146.8, Ltr 0.1127, Lge 0.0597, Ctr 0.9994, Str 0.9688, Cge 0.9907, Sge 0.6600 # 08502, T 3167.0, Ltr 0.1328, Lge 0.0496, Ctr 0.9959, Str 0.9375, Cge 0.9928, Sge 0.7600 # 08556, T 3187.1, Ltr 0.1424, Lge 0.0844, Ctr 0.9953, Str 0.9062, Cge 0.9901, Sge 0.6900 # 08612, T 3207.4, Ltr 0.1434, Lge 0.0986, Ctr 0.9949, Str 0.9062, Cge 0.9863, Sge 0.5700 # 08666, T 3227.6, Ltr 0.1772, Lge 0.0893, Ctr 0.9865, Str 0.8438, Cge 0.9894, Sge 0.6100 # 08720, T 3247.8, Ltr 0.1491, Lge 0.0896, Ctr 0.9906, Str 0.9375, Cge 0.9893, Sge 0.6800 # 08775, T 3268.2, Ltr 0.1873, Lge 0.0815, Ctr 0.9895, Str 0.8438, Cge 0.9905, Sge 0.6900 # 08830, T 3288.4, Ltr 0.1128, Lge 0.0790, Ctr 0.9988, Str 0.9688, Cge 0.9881, Sge 0.6400 # 08882, T 3308.5, Ltr 0.1164, Lge 0.1097, Ctr 0.9957, Str 0.9688, Cge 0.9896, Sge 0.7200 # 08935, T 3328.4, Ltr 0.1509, Lge 0.0733, Ctr 0.9891, Str 0.8438, Cge 0.9901, Sge 0.6100 # 08990, T 3348.7, Ltr 0.1394, Lge 0.0357, Ctr 0.9969, Str 0.9375, Cge 0.9954, Sge 0.7800 # 09045, T 3368.9, Ltr 0.1263, Lge 0.0905, Ctr 0.9933, Str 0.8750, Cge 0.9882, Sge 0.6400 # 09098, T 3388.9, Ltr 0.1421, Lge 0.0697, Ctr 0.9929, Str 0.9062, Cge 0.9935, Sge 0.7900 # 09152, T 3409.1, Ltr 0.1357, Lge 0.0765, Ctr 0.9904, Str 0.7812, Cge 0.9904, Sge 0.6100 # 09207, T 3429.5, Ltr 0.1691, Lge 0.0696, Ctr 0.9950, Str 0.9688, Cge 0.9917, Sge 0.6700 # 09260, T 3449.4, Ltr 0.1421, Lge 0.0924, Ctr 0.9928, Str 0.9062, Cge 0.9896, Sge 0.6700 # 09315, T 3469.8, Ltr 0.1280, Lge 0.0687, Ctr 0.9941, Str 0.9375, Cge 0.9906, Sge 0.6900 # 09370, T 3490.1, Ltr 0.1428, Lge 0.0758, Ctr 0.9968, Str 0.9062, Cge 0.9920, Sge 0.7100 # 09423, T 3510.1, Ltr 0.1391, Lge 0.0665, Ctr 0.9956, Str 0.9062, Cge 0.9915, Sge 0.7100 # 09479, T 3530.1, Ltr 0.1644, Lge 0.1169, Ctr 
0.9915, Str 0.9062, Cge 0.9883, Sge 0.6300 # 09535, T 3550.4, Ltr 0.1296, Lge 0.0786, Ctr 0.9898, Str 0.9062, Cge 0.9901, Sge 0.6600 # 09590, T 3570.6, Ltr 0.1611, Lge 0.0532, Ctr 0.9871, Str 0.8750, Cge 0.9927, Sge 0.7000 # 09644, T 3590.7, Ltr 0.1599, Lge 0.0585, Ctr 0.9943, Str 0.9375, Cge 0.9918, Sge 0.6900 # 09698, T 3610.8, Ltr 0.1409, Lge 0.0598, Ctr 0.9944, Str 0.8750, Cge 0.9926, Sge 0.7200 # 09754, T 3631.0, Ltr 0.1637, Lge 0.0523, Ctr 0.9900, Str 0.9375, Cge 0.9926, Sge 0.7300 # 09808, T 3651.2, Ltr 0.1112, Lge 0.0894, Ctr 0.9931, Str 0.9375, Cge 0.9924, Sge 0.7200 # 09863, T 3671.4, Ltr 0.1136, Lge 0.0938, Ctr 0.9979, Str 0.9688, Cge 0.9877, Sge 0.5300 # 09918, T 3691.5, Ltr 0.1333, Lge 0.0761, Ctr 0.9890, Str 0.8750, Cge 0.9905, Sge 0.6200 # 09972, T 3711.8, Ltr 0.1063, Lge 0.0603, Ctr 0.9975, Str 0.9375, Cge 0.9933, Sge 0.7100
結果を可視化する
# This cell visualizes the results of training. You can visualize the
# intermediate results by interrupting execution of the cell above, and running
# this cell. You can then resume training by simply executing the above cell
# again.
def softmax_prob_last_dim(x):  # pylint: disable=redefined-outer-name
  """Return the softmax probability of the last class along the last axis.

  Args:
    x: Array-like of logits whose last axis indexes classes. For the node
      outputs plotted in this cell it is presumably (num_nodes, num_classes);
      confirm against the model's output size.

  Returns:
    Array with the last axis removed, equal to softmax(x, axis=-1)[..., -1].
  """
  x = np.asarray(x)
  # Shift by the per-row maximum before exponentiating so large logits cannot
  # overflow to inf (inf / inf would yield nan). The shift cancels in the
  # ratio, so the probabilities are mathematically unchanged.
  e = np.exp(x - np.max(x, axis=-1, keepdims=True))
  return e[..., -1] / np.sum(e, axis=-1)
# Learning curves: one panel per tracked metric, each comparing the training
# series (solid line) against the test/generalization series (dashed line).
fig = plt.figure(1, figsize=(18, 3))
fig.clf()
x = np.array(logged_iterations)
_curve_specs = [
    # (training series, generalization series, panel title, y-axis label)
    (losses_tr, losses_ge, "Loss across training",
     "Loss (binary cross-entropy)"),
    (corrects_tr, corrects_ge, "Fraction correct across training",
     "Fraction nodes/edges correct"),
    (solveds_tr, solveds_ge, "Fraction solved across training",
     "Fraction examples solved"),
]
for panel, (y_tr, y_ge, panel_title, panel_ylabel) in enumerate(_curve_specs):
  ax = fig.add_subplot(1, len(_curve_specs), panel + 1)
  ax.plot(x, y_tr, "k", label="Training")
  ax.plot(x, y_ge, "k--", label="Test/generalization")
  ax.set_title(panel_title)
  ax.set_xlabel("Training iteration")
  ax.set_ylabel(panel_ylabel)
  if panel == 0:
    # Line styles are shared across panels, so only the first one gets a
    # legend.
    ax.legend()
# Plot graphs and results after each processing step.
# The white node is the start, and the black is the end. Other nodes are colored
# from red to purple to blue, where red means the model is confident the node is
# off the shortest path, blue means the model is confident the node is on the
# shortest path, and purplish colors mean the model isn't sure.
max_graphs_to_plot = 6
num_steps_to_plot = 4
node_size = 120
min_c = 0.3
num_graphs = len(raw_graphs)
targets = utils_np.graphs_tuple_to_data_dicts(test_values["target"])
# Pick `num_steps_to_plot` processing steps spread evenly from the first to the
# last step of the generalization run.
step_indices = np.floor(
    np.linspace(0, num_processing_steps_ge - 1,
                num_steps_to_plot)).astype(int).tolist()
# Regroup the per-step lists into per-graph tuples: outputs[j][k] is graph j's
# output dict at processing step step_indices[k].
outputs = list(
    zip(*(utils_np.graphs_tuple_to_data_dicts(test_values["outputs"][i])
          for i in step_indices)))
h = min(num_graphs, max_graphs_to_plot)
w = num_steps_to_plot + 1
fig = plt.figure(101, figsize=(18, h * 3))
fig.clf()
ncs = []


def _on_path_colors(nodes, on_path):
  """Build an RGBA color for every node from its on-path score.

  Args:
    nodes: Node labels, in the order the plotter enumerates them.
    on_path: Per-node scores in [0, 1], indexed by each node's *position* in
      `nodes`, not by its label. This assumes the plotter's node order matches
      the row order of the node features; the ground-truth coloring relied on
      this already.

  Returns:
    Dict mapping node label -> RGBA array. A score of 0 gives red and 1 gives
    blue. Each channel is compressed into [min_c, 1] so no node is drawn
    fully dark.
  """
  return {
      n: np.array([1.0 - on_path[i], 0.0, on_path[i], 1.0]) * (1.0 - min_c) +
      min_c for i, n in enumerate(nodes)
  }


for j, (graph, target, output) in enumerate(zip(raw_graphs, targets, outputs)):
  if j >= h:
    break
  pos = get_node_dict(graph, "pos")
  ground_truth = target["nodes"][:, -1]
  # Ground truth.
  iax = j * (1 + num_steps_to_plot) + 1
  ax = fig.add_subplot(h, w, iax)
  plotter = GraphPlotter(ax, graph, pos)
  color = _on_path_colors(plotter.nodes, ground_truth)
  plotter.draw_graph_with_solution(node_size=node_size, node_color=color)
  ax.set_axis_on()
  ax.set_xticks([])
  ax.set_yticks([])
  try:
    ax.set_facecolor([0.9] * 3 + [1.0])
  except AttributeError:
    # Older matplotlib releases only provide set_axis_bgcolor.
    ax.set_axis_bgcolor([0.9] * 3 + [1.0])
  ax.grid(None)
  ax.set_title("Ground truth\nSolution length: {}".format(
      plotter.solution_length))
  # Prediction.
  for k, outp in enumerate(output):
    iax = j * (1 + num_steps_to_plot) + 2 + k
    ax = fig.add_subplot(h, w, iax)
    plotter = GraphPlotter(ax, graph, pos)
    prob = softmax_prob_last_dim(outp["nodes"])
    # Index the probabilities by position, exactly like the ground truth above.
    # The previous `prob[n]` indexed by node label, which only matches when
    # the labels are exactly 0..N-1 in enumeration order.
    color = _on_path_colors(plotter.nodes, prob)
    plotter.draw_graph_with_solution(node_size=node_size, node_color=color)
    ax.set_title("Model-predicted\nStep {:02d} / {:02d}".format(
        step_indices[k] + 1, step_indices[-1] + 1))


以上