
ClasCat® AI Research

ClassCat – Generative AI, AI Agents, MCP


TensorFlow : Graph Nets : Finding the shortest path in a graph

Posted on 06/07/2019 by Sales Information

TensorFlow : Graph Nets : Finding the shortest path in a graph (translation/commentary)

Translation: ClassCat Co., Ltd. Sales Information
Created: 06/07/2019

* This page is a translation of the deepmind/graph_nets README.md and the “Find the shortest path in a graph” notebook, with supplementary explanations added where appropriate:

  • deepmind/graph_nets/blob/master/README.md
  • deepmind/graph_nets/blob/master/graph_nets/demos/shortest_path.ipynb

* The sample code has been verified to run; where necessary it has been extended or modified.
* You are welcome to link to this page, but we would appreciate a note to sales-info@classcat.com.

 

Graph Nets : README.md

Graph Nets is DeepMind's library for building graph networks in TensorFlow and Sonnet.

What are graph networks?

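A graph network takes a graph as input and returns a graph as output. The input graph has edge- (E), node- (V), and global-level (u) attributes. The output graph has the same structure, but with updated attributes. Graph networks are part of a broader family of "graph neural networks" (Scarselli et al., 2009).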
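The sketch below makes the E/V/u description concrete with one graph network block in NumPy. It updates the edges first, then the nodes, then the globals, following the update order described in DeepMind's relational inductive biases paper. It is not the library's implementation, and it rests on two assumptions. First, each update function is a hypothetical random linear layer standing in for an MLP. Second, every aggregation is a plain sum. The connectivity (`senders`/`receivers`) passes through unchanged, so only the attributes are updated.

```python
import numpy as np


def gn_block(nodes, edges, globals_, senders, receivers, rng, hidden=4):
    """One graph network block: update edges, then nodes, then globals.

    nodes: [num_nodes, dn], edges: [num_edges, de], globals_: [dg]
    senders/receivers: [num_edges] node indices of each edge's endpoints.
    The update functions are random linear maps (hypothetical stand-ins).
    """
    num_nodes = nodes.shape[0]

    def linear(x):
        # A fresh random linear layer standing in for phi^e, phi^v or phi^u.
        w = rng.normal(size=(x.shape[-1], hidden))
        return x @ w

    # Edge update: each edge sees its attributes, both endpoints and u.
    g_per_edge = np.tile(globals_, (edges.shape[0], 1))
    new_edges = linear(np.concatenate(
        [edges, nodes[senders], nodes[receivers], g_per_edge], axis=-1))

    # Node update: sum the updated incoming edges per receiver node.
    incoming = np.zeros((num_nodes, hidden))
    np.add.at(incoming, receivers, new_edges)
    g_per_node = np.tile(globals_, (num_nodes, 1))
    new_nodes = linear(np.concatenate([nodes, incoming, g_per_node], axis=-1))

    # Global update: aggregate all updated edges and nodes together with u.
    new_globals = linear(np.concatenate(
        [new_edges.sum(axis=0), new_nodes.sum(axis=0), globals_])[None])[0]
    return new_nodes, new_edges, new_globals


rng = np.random.default_rng(0)
senders = np.array([0, 1, 2])
receivers = np.array([1, 2, 0])
nodes = rng.normal(size=(3, 5))
edges = rng.normal(size=(3, 1))
globals_ = np.zeros(1)
v, e, u = gn_block(nodes, edges, globals_, senders, receivers, rng)
print(v.shape, e.shape, u.shape)  # (3, 4) (3, 4) (4,)
```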

To learn more about graph networks, see our arXiv paper: Relational inductive biases, deep learning, and graph networks.

 

Usage example

The following code builds a simple graph net module and connects it to data.

import graph_nets as gn
import sonnet as snt

# Provide your own functions to generate graph-structured data.
input_graphs = get_graphs()

# Create the graph network.
graph_net_module = gn.modules.GraphNetwork(
    edge_model_fn=lambda: snt.nets.MLP([32, 32]),
    node_model_fn=lambda: snt.nets.MLP([32, 32]),
    global_model_fn=lambda: snt.nets.MLP([32, 32]))

# Pass the input graphs to the graph network, and return the output graphs.
output_graphs = graph_net_module(input_graphs)
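The snippet leaves `get_graphs()` to the user. In graph_nets, a batch of graphs travels as a single `GraphsTuple`: the nodes and edges of all graphs are concatenated, `n_node`/`n_edge` record the per-graph counts, and the `senders`/`receivers` indices are offset into the concatenated node array. The NumPy sketch below mimics that batching so the layout is visible. `batch_data_dicts` is a hypothetical helper. In practice you would use the library's own converters, for example `utils_np.data_dicts_to_graphs_tuple`.

```python
import numpy as np


def batch_data_dicts(data_dicts):
    """Concatenate per-graph dicts into one batched dict (GraphsTuple-like).

    Each input dict has "nodes" [n, dn], "edges" [m, de], "globals" [dg],
    "senders" [m] and "receivers" [m] (indices local to that graph).
    """
    n_node = np.array([len(d["nodes"]) for d in data_dicts])
    n_edge = np.array([len(d["edges"]) for d in data_dicts])
    # Offset of each graph's first node inside the concatenated node array.
    offsets = np.concatenate([[0], np.cumsum(n_node)[:-1]])
    return {
        "nodes": np.concatenate([d["nodes"] for d in data_dicts]),
        "edges": np.concatenate([d["edges"] for d in data_dicts]),
        "globals": np.stack([d["globals"] for d in data_dicts]),
        "senders": np.concatenate(
            [d["senders"] + off for d, off in zip(data_dicts, offsets)]),
        "receivers": np.concatenate(
            [d["receivers"] + off for d, off in zip(data_dicts, offsets)]),
        "n_node": n_node,
        "n_edge": n_edge,
    }


graph_a = {"nodes": np.zeros((3, 2)), "edges": np.zeros((2, 1)),
           "globals": np.zeros(1),
           "senders": np.array([0, 1]), "receivers": np.array([1, 2])}
graph_b = {"nodes": np.ones((2, 2)), "edges": np.ones((1, 1)),
           "globals": np.ones(1),
           "senders": np.array([0]), "receivers": np.array([1])}
batch = batch_data_dicts([graph_a, graph_b])
print(batch["senders"], batch["receivers"])  # [0 1 3] [1 2 4]
```

The second graph's edge `0 -> 1` becomes `3 -> 4` because its nodes start at row 3 of the concatenated node array.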

 
 

Graph Nets : Find the shortest path in a graph

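The "shortest path demo" creates random graphs and trains a graph network to label the nodes and edges on the shortest path between any two nodes. Over a sequence of message-passing steps (as depicted by each step's plot), the model refines its prediction of the shortest path.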

This notebook and the accompanying code demonstrate how to use the Graph Nets library to learn to predict the shortest path between two nodes in a graph.

Given a start node and an end node, the network is trained to label the nodes and edges on the shortest path.
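The training target is a binary label per node and per directed edge. The demo finds the path with `nx.shortest_path` (see `add_shortest_path` below). This stdlib-only sketch swaps in a hand-written Dijkstra search so it runs without networkx, then turns the path into labels. The function names and the toy graph are hypothetical. The graph is assumed to be connected, as the demo's generator guarantees.

```python
import heapq


def dijkstra_path(adj, start, end):
    """Shortest path from start to end; adj maps node -> {neighbor: distance}."""
    dist = {start: 0.0}
    prev = {}
    heap = [(0.0, start)]
    done = set()
    while heap:
        d, u = heapq.heappop(heap)
        if u in done:
            continue  # Stale heap entry; u was already finalized.
        done.add(u)
        if u == end:
            break
        for v, w in adj[u].items():
            nd = d + w
            if nd < dist.get(v, float("inf")):
                dist[v] = nd
                prev[v] = u
                heapq.heappush(heap, (nd, v))
    # Walk the predecessor links back from end to start.
    path = [end]
    while path[-1] != start:
        path.append(prev[path[-1]])
    return path[::-1]


def label_path(adj, path):
    """Label 1 for nodes / directed edges on the path, 0 for everything else."""
    on_path = set(zip(path, path[1:]))
    node_labels = {n: int(n in path) for n in adj}
    edge_labels = {(u, v): int((u, v) in on_path) for u in adj for v in adj[u]}
    return node_labels, edge_labels


# Toy undirected graph stored with both edge directions.
adj = {
    0: {1: 1.0, 2: 4.0},
    1: {0: 1.0, 2: 1.0, 3: 5.0},
    2: {0: 4.0, 1: 1.0, 3: 1.0},
    3: {1: 5.0, 2: 1.0},
}
path = dijkstra_path(adj, 0, 3)
node_labels, edge_labels = label_path(adj, path)
print(path)  # [0, 1, 2, 3]
```

As in the demo, only the edges pointing from start toward end get the label 1. The reverse direction of each path edge stays 0.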

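After training, the network's predictive ability is demonstrated by comparing its output with the true shortest path. The network's ability to generalize is then tested by using it to predict the shortest path in similar but larger graphs.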

 

Imports

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import collections.abc
import itertools
import time

from graph_nets import graphs
from graph_nets import utils_np
from graph_nets import utils_tf
from graph_nets.demos import models
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from scipy import spatial
import tensorflow as tf  # This demo targets the TensorFlow 1.x graph/session API.

SEED = 1
np.random.seed(SEED)
tf.set_random_seed(SEED)

 

Helper functions

DISTANCE_WEIGHT_NAME = "distance"  # The name for the distance edge attribute.


def pairwise(iterable):
  """s -> (s0,s1), (s1,s2), (s2, s3), ..."""
  a, b = itertools.tee(iterable)
  next(b, None)
  return zip(a, b)


def set_diff(seq0, seq1):
  """Return the set difference between 2 sequences as a list."""
  return list(set(seq0) - set(seq1))


def to_one_hot(indices, max_value, axis=-1):
  """Converts integer `indices` into one-hot vectors of length `max_value`."""
  one_hot = np.eye(max_value)[indices]
  if axis not in (-1, one_hot.ndim):
    one_hot = np.moveaxis(one_hot, -1, axis)
  return one_hot


def get_node_dict(graph, attr):
  """Return a `dict` of node:attribute pairs from a graph."""
  return {k: v[attr] for k, v in graph.nodes.items()}


def generate_graph(rand,
                   num_nodes_min_max,
                   dimensions=2,
                   theta=1000.0,
                   rate=1.0):
  """Creates a connected graph.

  The graphs are geographic threshold graphs, but with added edges via a
  minimum spanning tree algorithm, to ensure all nodes are connected.

  Args:
    rand: A `np.random.RandomState` instance used to sample the graph.
    num_nodes_min_max: A sequence [lower, upper) number of nodes per graph.
    dimensions: (optional) An `int` number of dimensions for the positions.
      Default= 2.
    theta: (optional) A `float` threshold parameter for the geographic
      threshold graph. Large values (1000+) make mostly trees. Try
      20-60 for good non-trees. Default=1000.0.
    rate: (optional) A rate parameter for the node weight exponential sampling
      distribution. Default= 1.0.

  Returns:
    combined_graph: The union of the MST and the geographic threshold graph.
    mst_graph: The minimum spanning tree graph.
    geo_graph: The geographic threshold graph.
  """
  # Sample num_nodes.
  num_nodes = rand.randint(*num_nodes_min_max)

  # Create geographic threshold graph.
  pos_array = rand.uniform(size=(num_nodes, dimensions))
  pos = dict(enumerate(pos_array))
  weight = dict(enumerate(rand.exponential(rate, size=num_nodes)))
  geo_graph = nx.geographical_threshold_graph(
      num_nodes, theta, pos=pos, weight=weight)

  # Create minimum spanning tree across geo_graph's nodes.
  distances = spatial.distance.squareform(spatial.distance.pdist(pos_array))
  i_, j_ = np.meshgrid(range(num_nodes), range(num_nodes), indexing="ij")
  weighted_edges = list(zip(i_.ravel(), j_.ravel(), distances.ravel()))
  mst_graph = nx.Graph()
  mst_graph.add_weighted_edges_from(weighted_edges, weight=DISTANCE_WEIGHT_NAME)
  mst_graph = nx.minimum_spanning_tree(mst_graph, weight=DISTANCE_WEIGHT_NAME)
  # Put geo_graph's node attributes into the mst_graph.
  for i in mst_graph.nodes():
    mst_graph.nodes[i].update(geo_graph.nodes[i])

  # Compose the graphs.
  combined_graph = nx.compose_all((mst_graph, geo_graph.copy()))
  # Put all distance weights into edge attributes.
  for i, j in combined_graph.edges():
    combined_graph.get_edge_data(i, j).setdefault(DISTANCE_WEIGHT_NAME,
                                                  distances[i, j])
  return combined_graph, mst_graph, geo_graph


def add_shortest_path(rand, graph, min_length=1):
  """Samples a shortest path from A to B and adds attributes to indicate it.

  Args:
    rand: A `np.random.RandomState` instance used to sample the start/end pair.
    graph: A `nx.Graph`.
    min_length: (optional) An `int` minimum number of edges in the shortest
      path. Default= 1.

  Returns:
    The `nx.DiGraph` with the shortest path added.

  Raises:
    ValueError: All shortest paths are below the minimum length
  """
  # Map from node pairs to the length of their shortest path.
  pair_to_length_dict = {}
  try:
    # This is for compatibility with older networkx.
    lengths = nx.all_pairs_shortest_path_length(graph).items()
  except AttributeError:
    # This is for compatibility with newer networkx.
    lengths = list(nx.all_pairs_shortest_path_length(graph))
  for x, yy in lengths:
    for y, l in yy.items():
      if l >= min_length:
        pair_to_length_dict[x, y] = l
  if max(pair_to_length_dict.values()) < min_length:
    raise ValueError("All shortest paths are below the minimum length")
  # The node pairs which exceed the minimum length.
  node_pairs = list(pair_to_length_dict)

  # Computes probabilities per pair, to enforce uniform sampling of each
  # shortest path lengths.
  # The counts of pairs per length.
  counts = collections.Counter(pair_to_length_dict.values())
  prob_per_length = 1.0 / len(counts)
  probabilities = [
      prob_per_length / counts[pair_to_length_dict[x]] for x in node_pairs
  ]

  # Choose the start and end points.
  i = rand.choice(len(node_pairs), p=probabilities)
  start, end = node_pairs[i]
  path = nx.shortest_path(
      graph, source=start, target=end, weight=DISTANCE_WEIGHT_NAME)

  # Creates a directed graph, to store the directed path from start to end.
  digraph = graph.to_directed()

  # Add the "start", "end", and "solution" attributes to the nodes and edges.
  digraph.add_node(start, start=True)
  digraph.add_node(end, end=True)
  digraph.add_nodes_from(set_diff(digraph.nodes(), [start]), start=False)
  digraph.add_nodes_from(set_diff(digraph.nodes(), [end]), end=False)
  digraph.add_nodes_from(set_diff(digraph.nodes(), path), solution=False)
  digraph.add_nodes_from(path, solution=True)
  path_edges = list(pairwise(path))
  digraph.add_edges_from(set_diff(digraph.edges(), path_edges), solution=False)
  digraph.add_edges_from(path_edges, solution=True)

  return digraph


def graph_to_input_target(graph):
  """Returns 2 graphs with input and target feature vectors for training.

  Args:
    graph: An `nx.DiGraph` instance.

  Returns:
    The input `nx.DiGraph` instance.
    The target `nx.DiGraph` instance.

  Raises:
    ValueError: unknown node type
  """

  def create_feature(attr, fields):
    return np.hstack([np.array(attr[field], dtype=float) for field in fields])

  input_node_fields = ("pos", "weight", "start", "end")
  input_edge_fields = ("distance",)
  target_node_fields = ("solution",)
  target_edge_fields = ("solution",)

  input_graph = graph.copy()
  target_graph = graph.copy()

  solution_length = 0
  for node_index, node_feature in graph.nodes(data=True):
    input_graph.add_node(
        node_index, features=create_feature(node_feature, input_node_fields))
    target_node = to_one_hot(
        create_feature(node_feature, target_node_fields).astype(int), 2)[0]
    target_graph.add_node(node_index, features=target_node)
    solution_length += int(node_feature["solution"])
  solution_length /= graph.number_of_nodes()

  for receiver, sender, features in graph.edges(data=True):
    input_graph.add_edge(
        sender, receiver, features=create_feature(features, input_edge_fields))
    target_edge = to_one_hot(
        create_feature(features, target_edge_fields).astype(int), 2)[0]
    target_graph.add_edge(sender, receiver, features=target_edge)

  input_graph.graph["features"] = np.array([0.0])
  target_graph.graph["features"] = np.array([solution_length], dtype=float)

  return input_graph, target_graph


def generate_networkx_graphs(rand, num_examples, num_nodes_min_max, theta):
  """Generate graphs for training.

  Args:
    rand: A random seed (np.RandomState instance).
    num_examples: Total number of graphs to generate.
    num_nodes_min_max: A 2-tuple with the [lower, upper) number of nodes per
      graph. The number of nodes for a graph is uniformly sampled within this
      range.
    theta: A `float` threshold parameter for the geographic threshold graph.

  Returns:
    input_graphs: The list of input graphs.
    target_graphs: The list of output graphs.
    graphs: The list of generated graphs.
  """
  input_graphs = []
  target_graphs = []
  graphs = []
  for _ in range(num_examples):
    graph = generate_graph(rand, num_nodes_min_max, theta=theta)[0]
    graph = add_shortest_path(rand, graph)
    input_graph, target_graph = graph_to_input_target(graph)
    input_graphs.append(input_graph)
    target_graphs.append(target_graph)
    graphs.append(graph)
  return input_graphs, target_graphs, graphs


def create_placeholders(rand, batch_size, num_nodes_min_max, theta):
  """Creates placeholders for the model training and evaluation.

  Args:
    rand: A random seed (np.RandomState instance).
    batch_size: Total number of graphs per batch.
    num_nodes_min_max: A 2-tuple with the [lower, upper) number of nodes per
      graph. The number of nodes for a graph is uniformly sampled within this
      range.
    theta: A `float` threshold parameter for the geographic threshold graph.

  Returns:
    input_ph: The input graph's placeholders, as a graph namedtuple.
    target_ph: The target graph's placeholders, as a graph namedtuple.
  """
  # Create some example data for inspecting the vector sizes.
  input_graphs, target_graphs, _ = generate_networkx_graphs(
      rand, batch_size, num_nodes_min_max, theta)
  input_ph = utils_tf.placeholders_from_networkxs(input_graphs)
  target_ph = utils_tf.placeholders_from_networkxs(target_graphs)
  return input_ph, target_ph


def create_feed_dict(rand, batch_size, num_nodes_min_max, theta, input_ph,
                     target_ph):
  """Creates a feed dict of freshly generated graphs for training/evaluation.

  Args:
    rand: A random seed (np.RandomState instance).
    batch_size: Total number of graphs per batch.
    num_nodes_min_max: A 2-tuple with the [lower, upper) number of nodes per
      graph. The number of nodes for a graph is uniformly sampled within this
      range.
    theta: A `float` threshold parameter for the geographic threshold graph.
    input_ph: The input graph's placeholders, as a graph namedtuple.
    target_ph: The target graph's placeholders, as a graph namedtuple.

  Returns:
    feed_dict: The feed `dict` of input and target placeholders and data.
    raw_graphs: The `list` of raw networkx graphs.
  """
  inputs, targets, raw_graphs = generate_networkx_graphs(
      rand, batch_size, num_nodes_min_max, theta)
  input_graphs = utils_np.networkxs_to_graphs_tuple(inputs)
  target_graphs = utils_np.networkxs_to_graphs_tuple(targets)
  feed_dict = {input_ph: input_graphs, target_ph: target_graphs}
  return feed_dict, raw_graphs


def compute_accuracy(target, output, use_nodes=True, use_edges=False):
  """Calculate model accuracy.

  Returns the fraction of correctly predicted shortest-path nodes (and/or
  edges) and the fraction of completely solved graphs (100% correct
  predictions).

  Args:
    target: A `graphs.GraphsTuple` that contains the target graph.
    output: A `graphs.GraphsTuple` that contains the output graph.
    use_nodes: A `bool` indicator of whether to compute node accuracy or not.
    use_edges: A `bool` indicator of whether to compute edge accuracy or not.

  Returns:
    correct: A `float` fraction of correctly labeled nodes/edges.
    solved: A `float` fraction of graphs that are completely correctly labeled.

  Raises:
    ValueError: Nodes or edges (or both) must be used
  """
  if not use_nodes and not use_edges:
    raise ValueError("Nodes or edges (or both) must be used")
  tdds = utils_np.graphs_tuple_to_data_dicts(target)
  odds = utils_np.graphs_tuple_to_data_dicts(output)
  cs = []
  ss = []
  for td, od in zip(tdds, odds):
    xn = np.argmax(td["nodes"], axis=-1)
    yn = np.argmax(od["nodes"], axis=-1)
    xe = np.argmax(td["edges"], axis=-1)
    ye = np.argmax(od["edges"], axis=-1)
    c = []
    if use_nodes:
      c.append(xn == yn)
    if use_edges:
      c.append(xe == ye)
    c = np.concatenate(c, axis=0)
    s = np.all(c)
    cs.append(c)
    ss.append(s)
  correct = np.mean(np.concatenate(cs, axis=0))
  solved = np.mean(np.stack(ss))
  return correct, solved


def create_loss_ops(target_op, output_ops):
  """Returns one loss op per processing step (node + edge cross-entropy)."""
  loss_ops = [
      tf.losses.softmax_cross_entropy(target_op.nodes, output_op.nodes) +
      tf.losses.softmax_cross_entropy(target_op.edges, output_op.edges)
      for output_op in output_ops
  ]
  return loss_ops


def make_all_runnable_in_session(*args):
  """Lets an iterable of TF graphs be output from a session as NP graphs."""
  return [utils_tf.make_runnable_in_session(a) for a in args]


class GraphPlotter(object):

  def __init__(self, ax, graph, pos):
    self._ax = ax
    self._graph = graph
    self._pos = pos
    self._base_draw_kwargs = dict(G=self._graph, pos=self._pos, ax=self._ax)
    self._solution_length = None
    self._nodes = None
    self._edges = None
    self._start_nodes = None
    self._end_nodes = None
    self._solution_nodes = None
    self._intermediate_solution_nodes = None
    self._solution_edges = None
    self._non_solution_nodes = None
    self._non_solution_edges = None
    self._ax.set_axis_off()

  @property
  def solution_length(self):
    if self._solution_length is None:
      self._solution_length = len(self._solution_edges)
    return self._solution_length

  @property
  def nodes(self):
    if self._nodes is None:
      self._nodes = self._graph.nodes()
    return self._nodes

  @property
  def edges(self):
    if self._edges is None:
      self._edges = self._graph.edges()
    return self._edges

  @property
  def start_nodes(self):
    if self._start_nodes is None:
      self._start_nodes = [
          n for n in self.nodes if self._graph.nodes[n].get("start", False)
      ]
    return self._start_nodes

  @property
  def end_nodes(self):
    if self._end_nodes is None:
      self._end_nodes = [
          n for n in self.nodes if self._graph.nodes[n].get("end", False)
      ]
    return self._end_nodes

  @property
  def solution_nodes(self):
    if self._solution_nodes is None:
      self._solution_nodes = [
          n for n in self.nodes if self._graph.nodes[n].get("solution", False)
      ]
    return self._solution_nodes

  @property
  def intermediate_solution_nodes(self):
    if self._intermediate_solution_nodes is None:
      self._intermediate_solution_nodes = [
          n for n in self.nodes
          if self._graph.nodes[n].get("solution", False) and
          not self._graph.nodes[n].get("start", False) and
          not self._graph.nodes[n].get("end", False)
      ]
    return self._intermediate_solution_nodes

  @property
  def solution_edges(self):
    if self._solution_edges is None:
      self._solution_edges = [
          e for e in self.edges
          if self._graph.get_edge_data(e[0], e[1]).get("solution", False)
      ]
    return self._solution_edges

  @property
  def non_solution_nodes(self):
    if self._non_solution_nodes is None:
      self._non_solution_nodes = [
          n for n in self.nodes
          if not self._graph.nodes[n].get("solution", False)
      ]
    return self._non_solution_nodes

  @property
  def non_solution_edges(self):
    if self._non_solution_edges is None:
      self._non_solution_edges = [
          e for e in self.edges
          if not self._graph.get_edge_data(e[0], e[1]).get("solution", False)
      ]
    return self._non_solution_edges

  def _make_draw_kwargs(self, **kwargs):
    kwargs.update(self._base_draw_kwargs)
    return kwargs

  def _draw(self, draw_function, zorder=None, **kwargs):
    draw_kwargs = self._make_draw_kwargs(**kwargs)
    collection = draw_function(**draw_kwargs)
    if collection is not None and zorder is not None:
      try:
        # This is for compatibility with older matplotlib.
        collection.set_zorder(zorder)
      except AttributeError:
        # This is for compatibility with newer matplotlib.
        collection[0].set_zorder(zorder)
    return collection

  def draw_nodes(self, **kwargs):
    """Useful kwargs: nodelist, node_size, node_color, linewidths."""
    if ("node_color" in kwargs and
        isinstance(kwargs["node_color"], collections.abc.Sequence) and
        len(kwargs["node_color"]) in {3, 4} and
        not isinstance(kwargs["node_color"][0],
                       (collections.abc.Sequence, np.ndarray))):
      num_nodes = len(kwargs.get("nodelist", self.nodes))
      kwargs["node_color"] = np.tile(
          np.array(kwargs["node_color"])[None], [num_nodes, 1])
    return self._draw(nx.draw_networkx_nodes, **kwargs)

  def draw_edges(self, **kwargs):
    """Useful kwargs: edgelist, width."""
    return self._draw(nx.draw_networkx_edges, **kwargs)

  def draw_graph(self,
                 node_size=200,
                 node_color=(0.4, 0.8, 0.4),
                 node_linewidth=1.0,
                 edge_width=1.0):
    # Plot nodes.
    self.draw_nodes(
        nodelist=self.nodes,
        node_size=node_size,
        node_color=node_color,
        linewidths=node_linewidth,
        zorder=20)
    # Plot edges.
    self.draw_edges(edgelist=self.edges, width=edge_width, zorder=10)

  def draw_graph_with_solution(self,
                               node_size=200,
                               node_color=(0.4, 0.8, 0.4),
                               node_linewidth=1.0,
                               edge_width=1.0,
                               start_color="w",
                               end_color="k",
                               solution_node_linewidth=3.0,
                               solution_edge_width=3.0):
    node_border_color = (0.0, 0.0, 0.0, 1.0)
    node_collections = {}
    # Plot start nodes.
    node_collections["start nodes"] = self.draw_nodes(
        nodelist=self.start_nodes,
        node_size=node_size,
        node_color=start_color,
        linewidths=solution_node_linewidth,
        edgecolors=node_border_color,
        zorder=100)
    # Plot end nodes.
    node_collections["end nodes"] = self.draw_nodes(
        nodelist=self.end_nodes,
        node_size=node_size,
        node_color=end_color,
        linewidths=solution_node_linewidth,
        edgecolors=node_border_color,
        zorder=90)
    # Plot intermediate solution nodes.
    if isinstance(node_color, dict):
      c = [node_color[n] for n in self.intermediate_solution_nodes]
    else:
      c = node_color
    node_collections["intermediate solution nodes"] = self.draw_nodes(
        nodelist=self.intermediate_solution_nodes,
        node_size=node_size,
        node_color=c,
        linewidths=solution_node_linewidth,
        edgecolors=node_border_color,
        zorder=80)
    # Plot solution edges.
    node_collections["solution edges"] = self.draw_edges(
        edgelist=self.solution_edges, width=solution_edge_width, zorder=70)
    # Plot non-solution nodes.
    if isinstance(node_color, dict):
      c = [node_color[n] for n in self.non_solution_nodes]
    else:
      c = node_color
    node_collections["non-solution nodes"] = self.draw_nodes(
        nodelist=self.non_solution_nodes,
        node_size=node_size,
        node_color=c,
        linewidths=node_linewidth,
        edgecolors=node_border_color,
        zorder=20)
    # Plot non-solution edges.
    node_collections["non-solution edges"] = self.draw_edges(
        edgelist=self.non_solution_edges, width=edge_width, zorder=10)
    # Set title as solution length.
    self._ax.set_title("Solution length: {}".format(self.solution_length))
    return node_collections
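`generate_graph` combines two edge sets. The geographic threshold graph alone can leave nodes isolated, so it is merged with a minimum spanning tree over all pairwise Euclidean distances, which guarantees that every node is reachable. The NumPy sketch below reproduces that idea without networkx, using a hand-written Prim's algorithm in place of `nx.minimum_spanning_tree`. The connection rule `(w_i + w_j) / r**2 >= theta` is an assumption modelled on networkx's default distance function for `geographical_threshold_graph`, so check it against the networkx documentation for your version. With a large `theta`, hardly any threshold edges survive and the result is essentially the tree, which matches the "Large values (1000+) make mostly trees" note.

```python
import numpy as np


def threshold_plus_mst(rng, num_nodes, theta, dimensions=2, rate=1.0):
    """Undirected edge set {(i, j), i < j} that always connects all nodes."""
    pos = rng.uniform(size=(num_nodes, dimensions))
    weight = rng.exponential(rate, size=num_nodes)
    diff = pos[:, None, :] - pos[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))

    edges = set()
    # Threshold edges under the ASSUMED rule (w_i + w_j) / r**2 >= theta.
    for i in range(num_nodes):
        for j in range(i + 1, num_nodes):
            if (weight[i] + weight[j]) / dist[i, j] ** 2 >= theta:
                edges.add((i, j))

    # Prim's algorithm on the complete Euclidean graph adds the MST edges.
    in_tree = np.zeros(num_nodes, dtype=bool)
    in_tree[0] = True
    best = dist[0].copy()                    # cheapest link from the tree
    parent = np.zeros(num_nodes, dtype=int)  # tree node providing that link
    for _ in range(num_nodes - 1):
        k = int(np.argmin(np.where(in_tree, np.inf, best)))
        p = int(parent[k])
        in_tree[k] = True
        edges.add((min(k, p), max(k, p)))
        closer = dist[k] < best
        best = np.where(closer, dist[k], best)
        parent = np.where(closer, k, parent)
    return edges


def is_connected(num_nodes, edges):
    """Depth-first search from node 0 over the undirected edge set."""
    adj = {i: set() for i in range(num_nodes)}
    for i, j in edges:
        adj[i].add(j)
        adj[j].add(i)
    seen, stack = {0}, [0]
    while stack:
        for v in adj[stack.pop()]:
            if v not in seen:
                seen.add(v)
                stack.append(v)
    return len(seen) == num_nodes


rng = np.random.default_rng(1)
edges = threshold_plus_mst(rng, num_nodes=16, theta=1000.0)
print(is_connected(16, edges))  # True
```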

 

Visualize example graphs

seed = 1  #@param{type: 'integer'}
rand = np.random.RandomState(seed=seed)

num_examples = 15  #@param{type: 'integer'}
# Large values (1000+) make trees. Try 20-60 for good non-trees.
theta = 20  #@param{type: 'integer'}
num_nodes_min_max = (16, 17)

input_graphs, target_graphs, graphs = generate_networkx_graphs(
    rand, num_examples, num_nodes_min_max, theta)

num = min(num_examples, 16)
w = 3
h = int(np.ceil(num / w))
fig = plt.figure(40, figsize=(w * 4, h * 4))
fig.clf()
for j, graph in enumerate(graphs):
  ax = fig.add_subplot(h, w, j + 1)
  pos = get_node_dict(graph, "pos")
  plotter = GraphPlotter(ax, graph, pos)
  plotter.draw_graph_with_solution()
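Each plot is titled with its solution length. `add_shortest_path` does not choose the start/end pair uniformly. It first splits the probability evenly across the distinct shortest-path lengths (in hops, from `all_pairs_shortest_path_length`), then evenly across the pairs that share each length, so the many short paths do not crowd out the rare long ones. The standalone sketch below reproduces just that weighting on a hypothetical, hand-made `pair_to_length` dict.

```python
import collections


def length_balanced_probabilities(pair_to_length):
    """P(pair) = (1 / #distinct lengths) / #pairs sharing that pair's length."""
    counts = collections.Counter(pair_to_length.values())
    prob_per_length = 1.0 / len(counts)
    pairs = list(pair_to_length)
    probs = [prob_per_length / counts[pair_to_length[p]] for p in pairs]
    return pairs, probs


# Hypothetical lengths: three pairs of length 1, one pair of length 3.
pair_to_length = {(0, 1): 1, (1, 2): 1, (2, 3): 1, (0, 3): 3}
pairs, probs = length_balanced_probabilities(pair_to_length)
for pair, p in zip(pairs, probs):
    print(pair, round(p, 4))
# (0, 1) 0.1667
# (1, 2) 0.1667
# (2, 3) 0.1667
# (0, 3) 0.5
```

The single length-3 pair is as likely to be picked as all three length-1 pairs together. The demo then draws an index with `rand.choice(len(node_pairs), p=probabilities)`.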

 

Set up model training and evaluation

# The model we explore includes three components:
# - An "Encoder" graph net, which independently encodes the edge, node, and
#   global attributes (does not compute relations etc.).
# - A "Core" graph net, which performs N rounds of processing (message-passing)
#   steps. The input to the Core is the concatenation of the Encoder's output
#   and the previous output of the Core (labeled "Hidden(t)" below, where "t" is
#   the processing step).
# - A "Decoder" graph net, which independently decodes the edge, node, and
#   global attributes (does not compute relations etc.), on each
#   message-passing step.
#
#                     Hidden(t)   Hidden(t+1)
#                        |            ^
#           *---------*  |  *------*  |  *---------*
#           |         |  |  |      |  |  |         |
# Input --->| Encoder |  *->| Core |--*->| Decoder |---> Output(t)
#           |         |---->|      |     |         |
#           *---------*     *------*     *---------*
#
# The model is trained by supervised learning. Input graphs are procedurally
# generated, and output graphs have the same structure with the nodes and edges
# of the shortest path labeled (using 2-element 1-hot vectors). We could have
# predicted the shortest path only by labeling either the nodes or edges, and
# that does work, but we decided to predict both to demonstrate the flexibility
# of graph nets' outputs.
#
# The training loss is computed on the output of each processing step. The
# reason for this is to encourage the model to try to solve the problem in as
# few steps as possible. It also helps make the output of intermediate steps
# more interpretable.
#
# There's no need for a separate evaluation dataset because the inputs are
# never repeated, so the training loss is a measure of performance on graphs
# from the input distribution.
#
# We also evaluate how well the model generalizes to graphs which are up to
# twice as large as those on which it was trained. The loss is computed only
# on the final processing step.
#
# Variables with the suffix _tr are training parameters, and variables with the
# suffix _ge are test/generalization parameters.
#
# After around 2000-5000 training iterations the model reaches near-perfect
# performance on graphs with between 8-16 nodes.

tf.reset_default_graph()

seed = 2
rand = np.random.RandomState(seed=seed)

# Model parameters.
# Number of processing (message-passing) steps.
num_processing_steps_tr = 10
num_processing_steps_ge = 10

# Data / training parameters.
num_training_iterations = 10000
theta = 20  # Large values (1000+) make trees. Try 20-60 for good non-trees.
batch_size_tr = 32
batch_size_ge = 100
# Number of nodes per graph sampled uniformly from this range.
num_nodes_min_max_tr = (8, 17)
num_nodes_min_max_ge = (16, 33)

# Data.
# Input and target placeholders.
input_ph, target_ph = create_placeholders(rand, batch_size_tr,
                                          num_nodes_min_max_tr, theta)

# Connect the data to the model.
# Instantiate the model.
model = models.EncodeProcessDecode(edge_output_size=2, node_output_size=2)
# A list of outputs, one per processing step.
output_ops_tr = model(input_ph, num_processing_steps_tr)
output_ops_ge = model(input_ph, num_processing_steps_ge)

# Training loss.
loss_ops_tr = create_loss_ops(target_ph, output_ops_tr)
# Loss across processing steps.
loss_op_tr = sum(loss_ops_tr) / num_processing_steps_tr
# Test/generalization loss.
loss_ops_ge = create_loss_ops(target_ph, output_ops_ge)
loss_op_ge = loss_ops_ge[-1]  # Loss from final processing step.

# Optimizer.
learning_rate = 1e-3
optimizer = tf.train.AdamOptimizer(learning_rate)
step_op = optimizer.minimize(loss_op_tr)

# Lets an iterable of TF graphs be output from a session as NP graphs.
input_ph, target_ph = make_all_runnable_in_session(input_ph, target_ph)
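The comment block at the top of this cell describes the Encode-Process-Decode loop and a loss computed at every processing step. The toy NumPy sketch below shows only that control flow. The input is encoded once. The core then runs `num_steps` times on the concatenation of the encoding and the previous hidden state, and the output is decoded after every step. Finally, the per-step softmax cross-entropies are averaged as `loss_op_tr` does, while the generalization loss keeps only the last step as `loss_op_ge` does. It is a stand-in, not `models.EncodeProcessDecode`, and it makes several assumptions:

- It handles nodes only (no edge or global outputs).
- Its layers are untrained, random, single linear maps.
- It aggregates incoming messages with a plain sum.
- Its cross-entropy is averaged over rows, which we assume matches the default reduction of `tf.losses.softmax_cross_entropy`.

```python
import numpy as np


def softmax_cross_entropy(labels, logits):
    """Mean over rows of -sum(labels * log_softmax(logits))."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return float(-(labels * log_probs).sum(axis=-1).mean())


def encode_process_decode(nodes, senders, receivers, num_steps, rng, hidden=8):
    """Toy node-only model; returns one [num_nodes, 2] logit array per step."""
    w_enc = rng.normal(scale=0.1, size=(nodes.shape[1], hidden))
    w_core = rng.normal(scale=0.1, size=(4 * hidden, hidden))
    w_dec = rng.normal(scale=0.1, size=(hidden, 2))

    encoded = np.tanh(nodes @ w_enc)  # Encoder: applied once to the input.
    latent = encoded                  # Hidden(0)
    outputs = []
    for _ in range(num_steps):
        # Core input: concatenation of the encoder output and Hidden(t).
        core_in = np.concatenate([encoded, latent], axis=-1)
        # One message-passing round: sum sender states into each receiver.
        messages = np.zeros_like(core_in)
        np.add.at(messages, receivers, core_in[senders])
        latent = np.tanh(np.concatenate([core_in, messages], axis=-1) @ w_core)
        outputs.append(latent @ w_dec)  # Decoder runs after every step.
    return outputs


rng = np.random.default_rng(0)
nodes = rng.normal(size=(4, 5))          # e.g. pos (2), weight, start, end
senders = np.array([0, 1, 1, 2, 2, 3])   # path graph 0-1-2-3, both directions
receivers = np.array([1, 0, 2, 1, 3, 2])
targets = np.eye(2)[[1, 1, 0, 0]]        # hypothetical one-hot node labels

outputs = encode_process_decode(nodes, senders, receivers, num_steps=10, rng=rng)
step_losses = [softmax_cross_entropy(targets, out) for out in outputs]
loss_tr = sum(step_losses) / len(outputs)  # like loss_op_tr: mean over steps
loss_ge = step_losses[-1]                  # like loss_op_ge: final step only
print(len(outputs), outputs[-1].shape)     # 10 (4, 2)
```

Averaging over every step penalizes wrong intermediate outputs too, which is the stated motivation for pushing the model to solve the problem in as few steps as possible.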

 

Reset the session

# This cell resets the Tensorflow session, but keeps the same computational
# graph.

try:
  sess.close()
except NameError:
  pass
sess = tf.Session()
sess.run(tf.global_variables_initializer())

last_iteration = 0
logged_iterations = []
losses_tr = []
corrects_tr = []
solveds_tr = []
losses_ge = []
corrects_ge = []
solveds_ge = []

 

Run training

# You can interrupt this cell's training loop at any time, and visualize the
# intermediate results by running the next cell (below). You can then resume
# training by simply executing this cell again.

# How much time between logging and printing the current results.
log_every_seconds = 20

print("# (iteration number), T (elapsed seconds), "
      "Ltr (training loss), Lge (test/generalization loss), "
      "Ctr (training fraction nodes/edges labeled correctly), "
      "Str (training fraction examples solved correctly), "
      "Cge (test/generalization fraction nodes/edges labeled correctly), "
      "Sge (test/generalization fraction examples solved correctly)")

start_time = time.time()
last_log_time = start_time
for iteration in range(last_iteration, num_training_iterations):
  last_iteration = iteration
  feed_dict, _ = create_feed_dict(rand, batch_size_tr, num_nodes_min_max_tr,
                                  theta, input_ph, target_ph)
  train_values = sess.run({
      "step": step_op,
      "target": target_ph,
      "loss": loss_op_tr,
      "outputs": output_ops_tr
  }, feed_dict=feed_dict)
  the_time = time.time()
  elapsed_since_last_log = the_time - last_log_time
  if elapsed_since_last_log > log_every_seconds:
    last_log_time = the_time
    feed_dict, raw_graphs = create_feed_dict(
        rand, batch_size_ge, num_nodes_min_max_ge, theta, input_ph, target_ph)
    test_values = sess.run({
        "target": target_ph,
        "loss": loss_op_ge,
        "outputs": output_ops_ge
    }, feed_dict=feed_dict)
    correct_tr, solved_tr = compute_accuracy(
        train_values["target"], train_values["outputs"][-1], use_edges=True)
    correct_ge, solved_ge = compute_accuracy(
        test_values["target"], test_values["outputs"][-1], use_edges=True)
    elapsed = time.time() - start_time
    losses_tr.append(train_values["loss"])
    corrects_tr.append(correct_tr)
    solveds_tr.append(solved_tr)
    losses_ge.append(test_values["loss"])
    corrects_ge.append(correct_ge)
    solveds_ge.append(solved_ge)
    logged_iterations.append(iteration)
    print("# {:05d}, T {:.1f}, Ltr {:.4f}, Lge {:.4f}, Ctr {:.4f}, Str"
          " {:.4f}, Cge {:.4f}, Sge {:.4f}".format(
              iteration, elapsed, train_values["loss"], test_values["loss"],
              correct_tr, solved_tr, correct_ge, solved_ge))
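The loop reports two accuracy measures per split, both computed by `compute_accuracy`. That function is defined earlier in the tutorial, and here it is called with `use_edges=True`, so edge labels count as well.

- "Correct" is the fraction of individual nodes/edges labeled correctly.
- "Solved" counts a graph only when every node and edge in it is right.

This is why Str/Sge sit well below Ctr/Cge in the log that follows. The sketch below illustrates the distinction with a hypothetical helper, `correct_and_solved`. It is not the tutorial's `compute_accuracy`, and it assumes the labels have already been reduced to one integer class per node/edge.

```python
import numpy as np


def correct_and_solved(targets, predictions):
    """Hypothetical helper illustrating "correct" vs "solved".

    Not the tutorial's compute_accuracy. `targets` and `predictions` are lists
    with one 1-D integer array of class labels per graph (for example, node
    labels followed by edge labels).
    """
    # Element-wise agreement, one boolean array per graph.
    per_graph = [np.asarray(t) == np.asarray(p)
                 for t, p in zip(targets, predictions)]
    # "Correct": pooled over every node/edge in the batch.
    correct = np.mean(np.concatenate(per_graph))
    # "Solved": a graph counts only if all of its nodes/edges are right.
    solved = np.mean([np.all(g) for g in per_graph])
    return float(correct), float(solved)


# Two toy graphs: the first has one wrong label, the second is fully right.
toy_targets = [np.array([0, 1, 1]), np.array([1, 0])]
toy_predictions = [np.array([0, 1, 0]), np.array([1, 0])]
print(correct_and_solved(toy_targets, toy_predictions))  # (0.8, 0.5)
```

A single wrong label costs only 1/5 of "correct" but half of "solved". On larger graphs the gap widens, because any one mistake disqualifies the whole example.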
# (iteration number), T (elapsed seconds), Ltr (training loss), Lge (test/generalization loss), Ctr (training fraction nodes/edges labeled correctly), Str (training fraction examples solved correctly), Cge (test/generalization fraction nodes/edges labeled correctly), Sge (test/generalization fraction examples solved correctly)
# 00029, T 23.8, Ltr 0.8731, Lge 0.6658, Ctr 0.8596, Str 0.0000, Cge 0.9481, Sge 0.0000
# 00078, T 42.1, Ltr 0.6341, Lge 0.4758, Ctr 0.9056, Str 0.0000, Cge 0.9549, Sge 0.0000
# 00133, T 62.5, Ltr 0.5034, Lge 0.3845, Ctr 0.9172, Str 0.0000, Cge 0.9625, Sge 0.0200
# 00189, T 82.7, Ltr 0.5162, Lge 0.3417, Ctr 0.9166, Str 0.1250, Cge 0.9664, Sge 0.0200
# 00244, T 103.0, Ltr 0.4486, Lge 0.3383, Ctr 0.9343, Str 0.1250, Cge 0.9685, Sge 0.1600
# 00299, T 123.1, Ltr 0.4963, Lge 0.3507, Ctr 0.9184, Str 0.2500, Cge 0.9637, Sge 0.1100
# 00354, T 143.4, Ltr 0.3223, Lge 0.2883, Ctr 0.9614, Str 0.4062, Cge 0.9721, Sge 0.3300
# 00407, T 163.5, Ltr 0.4604, Lge 0.3853, Ctr 0.9270, Str 0.2500, Cge 0.9585, Sge 0.0600
# 00461, T 183.7, Ltr 0.2822, Lge 0.2933, Ctr 0.9670, Str 0.5625, Cge 0.9702, Sge 0.3200
# 00517, T 203.8, Ltr 0.3703, Lge 0.2784, Ctr 0.9480, Str 0.4375, Cge 0.9698, Sge 0.2600
# 00571, T 224.1, Ltr 0.4301, Lge 0.2783, Ctr 0.9308, Str 0.3125, Cge 0.9723, Sge 0.2000
# 00626, T 244.1, Ltr 0.3287, Lge 0.2833, Ctr 0.9533, Str 0.4062, Cge 0.9687, Sge 0.2700
# 00682, T 264.4, Ltr 0.2802, Lge 0.2913, Ctr 0.9617, Str 0.5000, Cge 0.9703, Sge 0.3000
# 00736, T 284.7, Ltr 0.3474, Lge 0.2775, Ctr 0.9531, Str 0.5625, Cge 0.9704, Sge 0.1900
# 00790, T 305.1, Ltr 0.3098, Lge 0.3607, Ctr 0.9488, Str 0.4062, Cge 0.9690, Sge 0.1700
# 00844, T 324.9, Ltr 0.3092, Lge 0.2941, Ctr 0.9566, Str 0.4375, Cge 0.9702, Sge 0.2500
# 00899, T 345.1, Ltr 0.3805, Lge 0.2202, Ctr 0.9440, Str 0.2812, Cge 0.9770, Sge 0.3200
# 00953, T 365.0, Ltr 0.2927, Lge 0.2637, Ctr 0.9609, Str 0.5938, Cge 0.9707, Sge 0.1900
# 01008, T 385.0, Ltr 0.3164, Lge 0.2093, Ctr 0.9568, Str 0.4688, Cge 0.9749, Sge 0.3100
# 01063, T 405.1, Ltr 0.2704, Lge 0.2455, Ctr 0.9749, Str 0.5000, Cge 0.9719, Sge 0.2000
# 01117, T 425.2, Ltr 0.2696, Lge 0.2713, Ctr 0.9600, Str 0.6250, Cge 0.9719, Sge 0.2300
# 01172, T 445.3, Ltr 0.4089, Lge 0.2442, Ctr 0.9489, Str 0.5312, Cge 0.9727, Sge 0.2100
# 01227, T 465.3, Ltr 0.3053, Lge 0.2616, Ctr 0.9620, Str 0.5312, Cge 0.9733, Sge 0.2400
# 01280, T 485.2, Ltr 0.2292, Lge 0.2433, Ctr 0.9742, Str 0.6250, Cge 0.9703, Sge 0.3100
# 01335, T 505.6, Ltr 0.3238, Lge 0.2267, Ctr 0.9554, Str 0.5312, Cge 0.9748, Sge 0.3000
# 01390, T 525.7, Ltr 0.3662, Lge 0.2706, Ctr 0.9602, Str 0.5938, Cge 0.9720, Sge 0.2500
# 01445, T 545.8, Ltr 0.2444, Lge 0.2530, Ctr 0.9755, Str 0.6562, Cge 0.9732, Sge 0.2800
# 01498, T 565.8, Ltr 0.3119, Lge 0.3036, Ctr 0.9565, Str 0.5938, Cge 0.9708, Sge 0.2300
# 01552, T 585.9, Ltr 0.3058, Lge 0.2633, Ctr 0.9553, Str 0.4688, Cge 0.9717, Sge 0.2400
# 01606, T 606.1, Ltr 0.2392, Lge 0.2462, Ctr 0.9782, Str 0.6562, Cge 0.9726, Sge 0.3000
# 01661, T 626.4, Ltr 0.2917, Lge 0.2522, Ctr 0.9611, Str 0.5625, Cge 0.9725, Sge 0.3000
# 01716, T 646.7, Ltr 0.3049, Lge 0.2254, Ctr 0.9566, Str 0.5938, Cge 0.9749, Sge 0.3300
# 01771, T 666.8, Ltr 0.2509, Lge 0.2393, Ctr 0.9667, Str 0.6250, Cge 0.9747, Sge 0.3000
# 01824, T 687.1, Ltr 0.1827, Lge 0.1870, Ctr 0.9879, Str 0.7812, Cge 0.9781, Sge 0.3300
# 01879, T 707.2, Ltr 0.3511, Lge 0.2048, Ctr 0.9574, Str 0.6562, Cge 0.9775, Sge 0.4200
# 01935, T 727.6, Ltr 0.2784, Lge 0.2044, Ctr 0.9699, Str 0.5625, Cge 0.9752, Sge 0.2000
# 01990, T 747.7, Ltr 0.3216, Lge 0.1943, Ctr 0.9639, Str 0.6562, Cge 0.9768, Sge 0.2800
# 02044, T 768.0, Ltr 0.1950, Lge 0.1579, Ctr 0.9892, Str 0.8750, Cge 0.9820, Sge 0.4100
# 02098, T 788.0, Ltr 0.2075, Lge 0.1729, Ctr 0.9882, Str 0.7500, Cge 0.9799, Sge 0.3700
# 02153, T 808.0, Ltr 0.2118, Lge 0.1775, Ctr 0.9783, Str 0.7500, Cge 0.9804, Sge 0.3700
# 02207, T 828.4, Ltr 0.2426, Lge 0.1862, Ctr 0.9669, Str 0.6562, Cge 0.9797, Sge 0.3800
# 02262, T 848.7, Ltr 0.2076, Lge 0.1836, Ctr 0.9862, Str 0.8750, Cge 0.9792, Sge 0.4300
# 02317, T 869.0, Ltr 0.1890, Lge 0.1984, Ctr 0.9873, Str 0.8750, Cge 0.9767, Sge 0.3800
# 02371, T 889.2, Ltr 0.1652, Lge 0.1936, Ctr 0.9887, Str 0.8125, Cge 0.9784, Sge 0.3900
# 02426, T 909.4, Ltr 0.2751, Lge 0.1523, Ctr 0.9707, Str 0.6875, Cge 0.9823, Sge 0.4100
# 02481, T 929.4, Ltr 0.1775, Lge 0.1617, Ctr 0.9867, Str 0.8750, Cge 0.9788, Sge 0.3800
# 02536, T 949.6, Ltr 0.2007, Lge 0.1207, Ctr 0.9880, Str 0.8438, Cge 0.9857, Sge 0.4900
# 02591, T 969.8, Ltr 0.1514, Lge 0.1489, Ctr 0.9934, Str 0.9062, Cge 0.9813, Sge 0.3600
# 02646, T 989.9, Ltr 0.2410, Lge 0.1100, Ctr 0.9862, Str 0.8125, Cge 0.9854, Sge 0.4000
# 02702, T 1010.0, Ltr 0.1991, Lge 0.1578, Ctr 0.9827, Str 0.8125, Cge 0.9813, Sge 0.3800
# 02756, T 1030.1, Ltr 0.1464, Lge 0.1388, Ctr 0.9893, Str 0.8750, Cge 0.9814, Sge 0.4000
# 02811, T 1050.1, Ltr 0.1931, Lge 0.1588, Ctr 0.9833, Str 0.8750, Cge 0.9799, Sge 0.3900
# 02866, T 1070.1, Ltr 0.1570, Lge 0.1189, Ctr 0.9858, Str 0.8750, Cge 0.9858, Sge 0.5500
# 02920, T 1090.5, Ltr 0.1420, Lge 0.1113, Ctr 0.9922, Str 0.8750, Cge 0.9855, Sge 0.4500
# 02974, T 1110.8, Ltr 0.1550, Lge 0.1640, Ctr 0.9911, Str 0.8438, Cge 0.9809, Sge 0.3200
# 03029, T 1130.8, Ltr 0.1681, Lge 0.1297, Ctr 0.9936, Str 0.8750, Cge 0.9873, Sge 0.4900
# 03084, T 1151.5, Ltr 0.1810, Lge 0.1909, Ctr 0.9921, Str 0.8750, Cge 0.9785, Sge 0.2300
# 03139, T 1171.1, Ltr 0.2063, Lge 0.1209, Ctr 0.9818, Str 0.8125, Cge 0.9861, Sge 0.4400
# 03195, T 1191.2, Ltr 0.1340, Lge 0.1583, Ctr 0.9945, Str 0.8750, Cge 0.9789, Sge 0.2100
# 03251, T 1211.7, Ltr 0.1461, Lge 0.1520, Ctr 0.9943, Str 0.9062, Cge 0.9856, Sge 0.4800
# 03303, T 1231.8, Ltr 0.1694, Lge 0.1235, Ctr 0.9854, Str 0.7812, Cge 0.9852, Sge 0.4900
# 03359, T 1252.0, Ltr 0.1738, Lge 0.1222, Ctr 0.9852, Str 0.7812, Cge 0.9846, Sge 0.4400
# 03414, T 1272.2, Ltr 0.1498, Lge 0.1101, Ctr 0.9897, Str 0.8438, Cge 0.9867, Sge 0.4900
# 03468, T 1292.5, Ltr 0.1638, Lge 0.1573, Ctr 0.9894, Str 0.8438, Cge 0.9836, Sge 0.4500
# 03523, T 1312.6, Ltr 0.2194, Lge 0.1516, Ctr 0.9761, Str 0.7188, Cge 0.9846, Sge 0.4600
# 03578, T 1332.6, Ltr 0.1490, Lge 0.1425, Ctr 0.9874, Str 0.9375, Cge 0.9846, Sge 0.5300
# 03633, T 1353.0, Ltr 0.1951, Lge 0.0889, Ctr 0.9860, Str 0.8750, Cge 0.9883, Sge 0.5600
# 03686, T 1373.0, Ltr 0.1586, Lge 0.1016, Ctr 0.9900, Str 0.9062, Cge 0.9875, Sge 0.4900
# 03741, T 1393.4, Ltr 0.1404, Lge 0.1356, Ctr 0.9911, Str 0.8750, Cge 0.9855, Sge 0.5800
# 03794, T 1413.3, Ltr 0.1938, Lge 0.1298, Ctr 0.9852, Str 0.7812, Cge 0.9828, Sge 0.4300
# 03847, T 1433.7, Ltr 0.1412, Lge 0.1183, Ctr 0.9899, Str 0.8750, Cge 0.9870, Sge 0.5900
# 03901, T 1453.8, Ltr 0.1894, Lge 0.0941, Ctr 0.9842, Str 0.8438, Cge 0.9890, Sge 0.6000
# 03955, T 1473.8, Ltr 0.1605, Lge 0.0935, Ctr 0.9867, Str 0.8125, Cge 0.9876, Sge 0.5800
# 04008, T 1493.9, Ltr 0.1560, Lge 0.0700, Ctr 0.9886, Str 0.8438, Cge 0.9927, Sge 0.6900
# 04062, T 1514.0, Ltr 0.2174, Lge 0.1912, Ctr 0.9865, Str 0.7812, Cge 0.9780, Sge 0.3800
# 04115, T 1534.0, Ltr 0.1537, Lge 0.0797, Ctr 0.9852, Str 0.8438, Cge 0.9891, Sge 0.5900
# 04169, T 1554.1, Ltr 0.1586, Lge 0.1071, Ctr 0.9871, Str 0.8438, Cge 0.9864, Sge 0.6100
# 04223, T 1574.5, Ltr 0.1071, Lge 0.1316, Ctr 1.0000, Str 1.0000, Cge 0.9869, Sge 0.6200
# 04278, T 1594.7, Ltr 0.1270, Lge 0.1329, Ctr 0.9985, Str 0.9375, Cge 0.9850, Sge 0.5400
# 04333, T 1614.8, Ltr 0.1352, Lge 0.1023, Ctr 0.9929, Str 0.9375, Cge 0.9864, Sge 0.5300
# 04386, T 1635.2, Ltr 0.1423, Lge 0.0890, Ctr 0.9894, Str 0.8125, Cge 0.9875, Sge 0.4600
# 04440, T 1655.1, Ltr 0.1320, Lge 0.0963, Ctr 0.9994, Str 0.9688, Cge 0.9882, Sge 0.5900
# 04495, T 1675.5, Ltr 0.1603, Lge 0.1094, Ctr 0.9889, Str 0.8750, Cge 0.9876, Sge 0.4800
# 04548, T 1695.7, Ltr 0.1474, Lge 0.1107, Ctr 0.9949, Str 0.9375, Cge 0.9868, Sge 0.5600
# 04602, T 1715.5, Ltr 0.1608, Lge 0.1791, Ctr 0.9960, Str 0.8438, Cge 0.9811, Sge 0.3700
# 04656, T 1735.6, Ltr 0.1416, Lge 0.1130, Ctr 0.9899, Str 0.8438, Cge 0.9865, Sge 0.5600
# 04710, T 1755.5, Ltr 0.1868, Lge 0.1135, Ctr 0.9944, Str 0.9375, Cge 0.9862, Sge 0.5500
# 04764, T 1775.7, Ltr 0.1466, Lge 0.0730, Ctr 0.9916, Str 0.9375, Cge 0.9901, Sge 0.6600
# 04819, T 1795.9, Ltr 0.1147, Lge 0.0881, Ctr 0.9966, Str 0.9688, Cge 0.9906, Sge 0.6900
# 04874, T 1816.0, Ltr 0.1130, Lge 0.1065, Ctr 0.9987, Str 0.9688, Cge 0.9868, Sge 0.5900
# 04928, T 1836.3, Ltr 0.1979, Lge 0.0953, Ctr 0.9909, Str 0.8750, Cge 0.9885, Sge 0.5200
# 04982, T 1856.2, Ltr 0.1319, Lge 0.1024, Ctr 0.9929, Str 0.9062, Cge 0.9875, Sge 0.6000
# 05036, T 1876.4, Ltr 0.1575, Lge 0.0744, Ctr 0.9910, Str 0.9062, Cge 0.9914, Sge 0.6300
# 05090, T 1896.6, Ltr 0.1387, Lge 0.1054, Ctr 0.9938, Str 0.9062, Cge 0.9877, Sge 0.5800
# 05144, T 1916.8, Ltr 0.1196, Lge 0.1088, Ctr 0.9929, Str 0.8750, Cge 0.9857, Sge 0.4800
# 05198, T 1936.6, Ltr 0.1441, Lge 0.0758, Ctr 0.9912, Str 0.9375, Cge 0.9902, Sge 0.6900
# 05252, T 1957.3, Ltr 0.1296, Lge 0.1036, Ctr 0.9957, Str 0.8750, Cge 0.9909, Sge 0.7000
# 05304, T 1976.9, Ltr 0.1311, Lge 0.1073, Ctr 0.9956, Str 0.9062, Cge 0.9883, Sge 0.6700
# 05359, T 1997.1, Ltr 0.0968, Lge 0.1633, Ctr 1.0000, Str 1.0000, Cge 0.9860, Sge 0.5500
# 05413, T 2017.0, Ltr 0.1550, Lge 0.0875, Ctr 0.9903, Str 0.9062, Cge 0.9896, Sge 0.6600
# 05466, T 2037.4, Ltr 0.1204, Lge 0.2264, Ctr 0.9966, Str 0.9062, Cge 0.9834, Sge 0.6400
# 05519, T 2057.5, Ltr 0.1255, Lge 0.1421, Ctr 0.9953, Str 0.9375, Cge 0.9868, Sge 0.6100
# 05573, T 2077.5, Ltr 0.1255, Lge 0.0956, Ctr 0.9941, Str 0.9062, Cge 0.9900, Sge 0.6800
# 05628, T 2098.2, Ltr 0.1427, Lge 0.0803, Ctr 0.9953, Str 0.9062, Cge 0.9893, Sge 0.6500
# 05683, T 2118.4, Ltr 0.1344, Lge 0.0857, Ctr 0.9934, Str 0.9062, Cge 0.9909, Sge 0.6000
# 05739, T 2138.5, Ltr 0.1634, Lge 0.1224, Ctr 0.9931, Str 0.9375, Cge 0.9875, Sge 0.5600
# 05793, T 2158.9, Ltr 0.1784, Lge 0.0720, Ctr 0.9853, Str 0.8750, Cge 0.9905, Sge 0.5800
# 05847, T 2179.2, Ltr 0.1259, Lge 0.0641, Ctr 0.9975, Str 0.9688, Cge 0.9921, Sge 0.6800
# 05902, T 2199.3, Ltr 0.0840, Lge 0.1022, Ctr 0.9974, Str 0.9688, Cge 0.9880, Sge 0.5600
# 05957, T 2219.4, Ltr 0.1161, Lge 0.0861, Ctr 0.9978, Str 0.9688, Cge 0.9906, Sge 0.6900
# 06010, T 2239.6, Ltr 0.1470, Lge 0.0660, Ctr 0.9894, Str 0.8125, Cge 0.9906, Sge 0.6700
# 06065, T 2259.7, Ltr 0.1664, Lge 0.1401, Ctr 0.9937, Str 0.9375, Cge 0.9831, Sge 0.5000
# 06120, T 2279.9, Ltr 0.1733, Lge 0.0901, Ctr 0.9829, Str 0.8438, Cge 0.9880, Sge 0.5600
# 06173, T 2300.1, Ltr 0.1269, Lge 0.0721, Ctr 0.9936, Str 0.9375, Cge 0.9926, Sge 0.7600
# 06228, T 2320.4, Ltr 0.1716, Lge 0.0702, Ctr 0.9926, Str 0.9375, Cge 0.9908, Sge 0.6500
# 06282, T 2340.4, Ltr 0.1386, Lge 0.0545, Ctr 0.9975, Str 0.9688, Cge 0.9926, Sge 0.7100
# 06337, T 2360.9, Ltr 0.1400, Lge 0.0512, Ctr 0.9960, Str 0.9062, Cge 0.9926, Sge 0.6100
# 06391, T 2380.5, Ltr 0.1468, Lge 0.0791, Ctr 0.9965, Str 0.8750, Cge 0.9894, Sge 0.6300
# 06446, T 2400.6, Ltr 0.1655, Lge 0.0847, Ctr 0.9900, Str 0.8750, Cge 0.9897, Sge 0.5800
# 06500, T 2420.6, Ltr 0.1530, Lge 0.0538, Ctr 0.9878, Str 0.7812, Cge 0.9925, Sge 0.6800
# 06553, T 2440.5, Ltr 0.1442, Lge 0.0634, Ctr 0.9969, Str 0.9375, Cge 0.9919, Sge 0.7500
# 06609, T 2460.8, Ltr 0.0933, Lge 0.0678, Ctr 0.9987, Str 0.9688, Cge 0.9912, Sge 0.6900
# 06663, T 2481.0, Ltr 0.1460, Lge 0.0936, Ctr 0.9953, Str 0.9375, Cge 0.9879, Sge 0.5400
# 06716, T 2501.4, Ltr 0.1505, Lge 0.0685, Ctr 0.9941, Str 0.9375, Cge 0.9914, Sge 0.7200
# 06769, T 2521.4, Ltr 0.1400, Lge 0.0530, Ctr 0.9955, Str 0.9062, Cge 0.9931, Sge 0.7400
# 06823, T 2541.4, Ltr 0.1310, Lge 0.0799, Ctr 0.9936, Str 0.8750, Cge 0.9896, Sge 0.6900
# 06878, T 2561.7, Ltr 0.1640, Lge 0.0848, Ctr 0.9885, Str 0.7812, Cge 0.9884, Sge 0.5900
# 06931, T 2581.7, Ltr 0.1395, Lge 0.0783, Ctr 0.9954, Str 0.9375, Cge 0.9904, Sge 0.6400
# 06986, T 2601.9, Ltr 0.1150, Lge 0.0546, Ctr 0.9969, Str 0.9375, Cge 0.9923, Sge 0.7400
# 07041, T 2622.1, Ltr 0.0829, Lge 0.0574, Ctr 1.0000, Str 1.0000, Cge 0.9923, Sge 0.6700
# 07095, T 2642.4, Ltr 0.1719, Lge 0.1901, Ctr 0.9907, Str 0.9375, Cge 0.9819, Sge 0.2800
# 07149, T 2662.5, Ltr 0.2284, Lge 0.0478, Ctr 0.9789, Str 0.8438, Cge 0.9934, Sge 0.7100
# 07204, T 2682.8, Ltr 0.1277, Lge 0.0614, Ctr 0.9923, Str 0.8750, Cge 0.9914, Sge 0.6000
# 07256, T 2703.0, Ltr 0.2056, Lge 0.0849, Ctr 0.9938, Str 0.9062, Cge 0.9910, Sge 0.6400
# 07311, T 2723.1, Ltr 0.1456, Lge 0.0573, Ctr 0.9967, Str 0.9688, Cge 0.9924, Sge 0.6700
# 07366, T 2743.3, Ltr 0.1366, Lge 0.0878, Ctr 0.9993, Str 0.9688, Cge 0.9898, Sge 0.7200
# 07420, T 2764.0, Ltr 0.1349, Lge 0.0462, Ctr 0.9953, Str 0.9375, Cge 0.9948, Sge 0.7700
# 07472, T 2783.6, Ltr 0.1244, Lge 0.0604, Ctr 0.9955, Str 0.9375, Cge 0.9929, Sge 0.7300
# 07528, T 2803.8, Ltr 0.1206, Lge 0.0890, Ctr 1.0000, Str 1.0000, Cge 0.9875, Sge 0.5300
# 07583, T 2824.1, Ltr 0.1248, Lge 0.0860, Ctr 0.9993, Str 0.9688, Cge 0.9910, Sge 0.7300
# 07636, T 2844.1, Ltr 0.1737, Lge 0.1036, Ctr 0.9909, Str 0.9062, Cge 0.9891, Sge 0.6600
# 07689, T 2864.1, Ltr 0.1297, Lge 0.0718, Ctr 0.9974, Str 0.9375, Cge 0.9933, Sge 0.7400
# 07744, T 2884.4, Ltr 0.1139, Lge 0.0905, Ctr 0.9962, Str 0.9375, Cge 0.9888, Sge 0.5700
# 07797, T 2904.6, Ltr 0.1405, Lge 0.0703, Ctr 0.9975, Str 0.9062, Cge 0.9905, Sge 0.6900
# 07852, T 2924.8, Ltr 0.1078, Lge 0.0566, Ctr 0.9986, Str 0.9375, Cge 0.9926, Sge 0.7600
# 07906, T 2944.9, Ltr 0.1498, Lge 0.0772, Ctr 0.9923, Str 0.8750, Cge 0.9902, Sge 0.6400
# 07959, T 2965.1, Ltr 0.1378, Lge 0.0657, Ctr 0.9919, Str 0.9375, Cge 0.9919, Sge 0.6900
# 08012, T 2985.2, Ltr 0.1468, Lge 0.0639, Ctr 0.9872, Str 0.8438, Cge 0.9919, Sge 0.6600
# 08067, T 3005.5, Ltr 0.1473, Lge 0.0555, Ctr 0.9967, Str 0.9688, Cge 0.9930, Sge 0.7400
# 08121, T 3025.6, Ltr 0.0928, Lge 0.0502, Ctr 0.9963, Str 0.9375, Cge 0.9930, Sge 0.6500
# 08174, T 3045.9, Ltr 0.1561, Lge 0.0637, Ctr 0.9906, Str 0.8750, Cge 0.9929, Sge 0.7300
# 08229, T 3066.2, Ltr 0.1539, Lge 0.1363, Ctr 0.9885, Str 0.8438, Cge 0.9887, Sge 0.6600
# 08283, T 3086.3, Ltr 0.1270, Lge 0.0807, Ctr 0.9930, Str 0.9375, Cge 0.9889, Sge 0.5900
# 08337, T 3107.0, Ltr 0.1001, Lge 0.0721, Ctr 0.9948, Str 0.9375, Cge 0.9909, Sge 0.6900
# 08391, T 3126.5, Ltr 0.1344, Lge 0.0732, Ctr 0.9944, Str 0.9375, Cge 0.9917, Sge 0.7000
# 08446, T 3146.8, Ltr 0.1127, Lge 0.0597, Ctr 0.9994, Str 0.9688, Cge 0.9907, Sge 0.6600
# 08502, T 3167.0, Ltr 0.1328, Lge 0.0496, Ctr 0.9959, Str 0.9375, Cge 0.9928, Sge 0.7600
# 08556, T 3187.1, Ltr 0.1424, Lge 0.0844, Ctr 0.9953, Str 0.9062, Cge 0.9901, Sge 0.6900
# 08612, T 3207.4, Ltr 0.1434, Lge 0.0986, Ctr 0.9949, Str 0.9062, Cge 0.9863, Sge 0.5700
# 08666, T 3227.6, Ltr 0.1772, Lge 0.0893, Ctr 0.9865, Str 0.8438, Cge 0.9894, Sge 0.6100
# 08720, T 3247.8, Ltr 0.1491, Lge 0.0896, Ctr 0.9906, Str 0.9375, Cge 0.9893, Sge 0.6800
# 08775, T 3268.2, Ltr 0.1873, Lge 0.0815, Ctr 0.9895, Str 0.8438, Cge 0.9905, Sge 0.6900
# 08830, T 3288.4, Ltr 0.1128, Lge 0.0790, Ctr 0.9988, Str 0.9688, Cge 0.9881, Sge 0.6400
# 08882, T 3308.5, Ltr 0.1164, Lge 0.1097, Ctr 0.9957, Str 0.9688, Cge 0.9896, Sge 0.7200
# 08935, T 3328.4, Ltr 0.1509, Lge 0.0733, Ctr 0.9891, Str 0.8438, Cge 0.9901, Sge 0.6100
# 08990, T 3348.7, Ltr 0.1394, Lge 0.0357, Ctr 0.9969, Str 0.9375, Cge 0.9954, Sge 0.7800
# 09045, T 3368.9, Ltr 0.1263, Lge 0.0905, Ctr 0.9933, Str 0.8750, Cge 0.9882, Sge 0.6400
# 09098, T 3388.9, Ltr 0.1421, Lge 0.0697, Ctr 0.9929, Str 0.9062, Cge 0.9935, Sge 0.7900
# 09152, T 3409.1, Ltr 0.1357, Lge 0.0765, Ctr 0.9904, Str 0.7812, Cge 0.9904, Sge 0.6100
# 09207, T 3429.5, Ltr 0.1691, Lge 0.0696, Ctr 0.9950, Str 0.9688, Cge 0.9917, Sge 0.6700
# 09260, T 3449.4, Ltr 0.1421, Lge 0.0924, Ctr 0.9928, Str 0.9062, Cge 0.9896, Sge 0.6700
# 09315, T 3469.8, Ltr 0.1280, Lge 0.0687, Ctr 0.9941, Str 0.9375, Cge 0.9906, Sge 0.6900
# 09370, T 3490.1, Ltr 0.1428, Lge 0.0758, Ctr 0.9968, Str 0.9062, Cge 0.9920, Sge 0.7100
# 09423, T 3510.1, Ltr 0.1391, Lge 0.0665, Ctr 0.9956, Str 0.9062, Cge 0.9915, Sge 0.7100
# 09479, T 3530.1, Ltr 0.1644, Lge 0.1169, Ctr 0.9915, Str 0.9062, Cge 0.9883, Sge 0.6300
# 09535, T 3550.4, Ltr 0.1296, Lge 0.0786, Ctr 0.9898, Str 0.9062, Cge 0.9901, Sge 0.6600
# 09590, T 3570.6, Ltr 0.1611, Lge 0.0532, Ctr 0.9871, Str 0.8750, Cge 0.9927, Sge 0.7000
# 09644, T 3590.7, Ltr 0.1599, Lge 0.0585, Ctr 0.9943, Str 0.9375, Cge 0.9918, Sge 0.6900
# 09698, T 3610.8, Ltr 0.1409, Lge 0.0598, Ctr 0.9944, Str 0.8750, Cge 0.9926, Sge 0.7200
# 09754, T 3631.0, Ltr 0.1637, Lge 0.0523, Ctr 0.9900, Str 0.9375, Cge 0.9926, Sge 0.7300
# 09808, T 3651.2, Ltr 0.1112, Lge 0.0894, Ctr 0.9931, Str 0.9375, Cge 0.9924, Sge 0.7200
# 09863, T 3671.4, Ltr 0.1136, Lge 0.0938, Ctr 0.9979, Str 0.9688, Cge 0.9877, Sge 0.5300
# 09918, T 3691.5, Ltr 0.1333, Lge 0.0761, Ctr 0.9890, Str 0.8750, Cge 0.9905, Sge 0.6200
# 09972, T 3711.8, Ltr 0.1063, Lge 0.0603, Ctr 0.9975, Str 0.9375, Cge 0.9933, Sge 0.7100
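Each evaluation draws a fresh random batch of test graphs through `create_feed_dict`, so the test columns fluctuate from one log line to the next. For example, Sge drops from 0.6700 at iteration 07041 to 0.2800 at iteration 07095. To compare evaluations numerically rather than by eye, the log can be loaded back into records.

The following is a minimal sketch. It assumes the log has been saved as plain text in exactly the format printed above. `parse_log` and `parse_log_line` are hypothetical helpers, not part of graph_nets.

```python
import re


def parse_log_line(line):
    """Parse one progress line such as
    "# 00029, T 23.8, Ltr 0.8731, ..., Sge 0.0000" into a dict."""
    fields = line.lstrip("#").strip().split(", ")
    record = {"iteration": int(fields[0])}
    for field in fields[1:]:
        name, value = field.split()
        record[name] = float(value)
    return record


def parse_log(text):
    """Return one dict per numbered progress line; skip the legend line."""
    records = []
    for line in text.splitlines():
        line = line.strip()
        if re.match(r"#\s*\d+,", line):
            records.append(parse_log_line(line))
    return records


# Three lines copied from the log above.
sample = """\
# 00029, T 23.8, Ltr 0.8731, Lge 0.6658, Ctr 0.8596, Str 0.0000, Cge 0.9481, Sge 0.0000
# 09098, T 3388.9, Ltr 0.1421, Lge 0.0697, Ctr 0.9929, Str 0.9062, Cge 0.9935, Sge 0.7900
# 09972, T 3711.8, Ltr 0.1063, Lge 0.0603, Ctr 0.9975, Str 0.9375, Cge 0.9933, Sge 0.7100
"""
records = parse_log(sample)
best = max(records, key=lambda r: r["Sge"])
print(best["iteration"], best["Sge"])  # 9098 0.79
```

Because the test batch changes at every evaluation, a single peak in Sge is best read alongside its neighbors rather than as a final score.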

 

Visualize the results

# This cell visualizes the results of training. You can visualize the
# intermediate results by interrupting execution of the cell above, and running
# this cell. You can then resume training by simply executing the above cell
# again.

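def softmax_prob_last_dim(x):  # pylint: disable=redefined-outer-name
  # Softmax over the last axis; returns the probability of the last class
  # ("on the shortest path"). Subtracting the row-wise max avoids overflow.
  e = np.exp(x - np.max(x, axis=-1, keepdims=True))
  return e[:, -1] / np.sum(e, axis=-1)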


# Plot results curves.
fig = plt.figure(1, figsize=(18, 3))
fig.clf()
x = np.array(logged_iterations)
# Loss.
y_tr = losses_tr
y_ge = losses_ge
ax = fig.add_subplot(1, 3, 1)
ax.plot(x, y_tr, "k", label="Training")
ax.plot(x, y_ge, "k--", label="Test/generalization")
ax.set_title("Loss across training")
ax.set_xlabel("Training iteration")
ax.set_ylabel("Loss (binary cross-entropy)")
ax.legend()
# Correct.
y_tr = corrects_tr
y_ge = corrects_ge
ax = fig.add_subplot(1, 3, 2)
ax.plot(x, y_tr, "k", label="Training")
ax.plot(x, y_ge, "k--", label="Test/generalization")
ax.set_title("Fraction correct across training")
ax.set_xlabel("Training iteration")
ax.set_ylabel("Fraction nodes/edges correct")
# Solved.
y_tr = solveds_tr
y_ge = solveds_ge
ax = fig.add_subplot(1, 3, 3)
ax.plot(x, y_tr, "k", label="Training")
ax.plot(x, y_ge, "k--", label="Test/generalization")
ax.set_title("Fraction solved across training")
ax.set_xlabel("Training iteration")
ax.set_ylabel("Fraction examples solved")

# Plot graphs and results after each processing step.
# The white node is the start, and the black is the end. Other nodes are colored
# from red to purple to blue, where red means the model is confident the node is
# off the shortest path, blue means the model is confident the node is on the
# shortest path, and purplish colors mean the model isn't sure.
max_graphs_to_plot = 6
num_steps_to_plot = 4
node_size = 120
min_c = 0.3
num_graphs = len(raw_graphs)
targets = utils_np.graphs_tuple_to_data_dicts(test_values["target"])
step_indices = np.floor(
    np.linspace(0, num_processing_steps_ge - 1,
                num_steps_to_plot)).astype(int).tolist()
outputs = list(
    zip(*(utils_np.graphs_tuple_to_data_dicts(test_values["outputs"][i])
          for i in step_indices)))
h = min(num_graphs, max_graphs_to_plot)
w = num_steps_to_plot + 1
fig = plt.figure(101, figsize=(18, h * 3))
fig.clf()
ncs = []
for j, (graph, target, output) in enumerate(zip(raw_graphs, targets, outputs)):
  if j >= h:
    break
  pos = get_node_dict(graph, "pos")
  ground_truth = target["nodes"][:, -1]
  # Ground truth.
  iax = j * (1 + num_steps_to_plot) + 1
  ax = fig.add_subplot(h, w, iax)
  plotter = GraphPlotter(ax, graph, pos)
  color = {}
  for i, n in enumerate(plotter.nodes):
    color[n] = np.array([1.0 - ground_truth[i], 0.0, ground_truth[i], 1.0
                        ]) * (1.0 - min_c) + min_c
  plotter.draw_graph_with_solution(node_size=node_size, node_color=color)
  ax.set_axis_on()
  ax.set_xticks([])
  ax.set_yticks([])
  try:
    ax.set_facecolor([0.9] * 3 + [1.0])
  except AttributeError:
    ax.set_axis_bgcolor([0.9] * 3 + [1.0])
  ax.grid(None)
  ax.set_title("Ground truth\nSolution length: {}".format(
      plotter.solution_length))
  # Prediction.
  for k, outp in enumerate(output):
    iax = j * (1 + num_steps_to_plot) + 2 + k
    ax = fig.add_subplot(h, w, iax)
    plotter = GraphPlotter(ax, graph, pos)
    color = {}
    prob = softmax_prob_last_dim(outp["nodes"])
    for i, n in enumerate(plotter.nodes):
      color[n] = np.array([1.0 - prob[n], 0.0, prob[n], 1.0
                          ]) * (1.0 - min_c) + min_c
    plotter.draw_graph_with_solution(node_size=node_size, node_color=color)
    ax.set_title("Model-predicted\nStep {:02d} / {:02d}".format(
        step_indices[k] + 1, step_indices[-1] + 1))
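The cell above turns each step's node outputs into colors in three moves:

1. Pick a few evenly spaced processing steps.
2. Convert each node's two-class logits into the probability of being on the shortest path.
3. Blend that probability between red (off the path) and blue (on the path).

The same color formula is also applied to the 0/1 ground-truth labels. The sketch below isolates these moves as standalone functions for experimentation. The function names are hypothetical, and `select_steps(10, 4)` uses 10 processing steps purely as an assumed example value. `prob_on_path` subtracts each row's maximum before exponentiating. This leaves the softmax unchanged mathematically, but it keeps large logits from overflowing.

```python
import numpy as np


def select_steps(num_processing_steps, num_steps_to_plot):
    # Same expression as step_indices above: evenly spaced processing steps,
    # from the first (index 0) to the last.
    return np.floor(
        np.linspace(0, num_processing_steps - 1,
                    num_steps_to_plot)).astype(int).tolist()


def prob_on_path(logits):
    # Softmax over the last axis, keeping the probability of the last class
    # ("on the shortest path"). Subtracting each row's max leaves the result
    # unchanged but keeps np.exp from overflowing.
    z = logits - np.max(logits, axis=-1, keepdims=True)
    e = np.exp(z)
    return e[:, -1] / np.sum(e, axis=-1)


def prob_to_rgba(p, min_c=0.3):
    # p = 0 -> reddish (off the path), p = 1 -> bluish (on the path), in
    # between -> purplish. min_c lifts every channel to at least min_c,
    # which gives lighter tints.
    return np.array([1.0 - p, 0.0, p, 1.0]) * (1.0 - min_c) + min_c


# Example values only: 10 processing steps is an assumption here.
print(select_steps(10, 4))  # [0, 3, 6, 9]
logits = np.array([[0.0, 0.0], [1000.0, 0.0], [0.0, 1000.0]])
print(prob_on_path(logits))  # a naive np.exp(logits) overflows on the last two rows
print(prob_to_rgba(0.5))  # roughly [0.65, 0.3, 0.65, 1.0]
```

Keeping `min_c` above zero means even a confident "off the path" node stays a visible light red against the gray background, instead of saturating to pure red.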

 

That's all.






June 2019
© 2025 ClasCat® AI Research | Powered by Minimalist Blog WordPress Theme