Skip to content

ClasCat® AI Research

クラスキャット – 生成 AI, AI エージェント, MCP

Menu
  • ホーム
    • ClassCat® AI Research ホーム
    • クラスキャット・ホーム
  • OpenAI API
    • OpenAI Python ライブラリ 1.x : 概要
    • OpenAI ブログ
      • GPT の紹介
      • GPT ストアの紹介
      • ChatGPT Team の紹介
    • OpenAI platform 1.x
      • Get Started : イントロダクション
      • Get Started : クイックスタート (Python)
      • Get Started : クイックスタート (Node.js)
      • Get Started : モデル
      • 機能 : 埋め込み
      • 機能 : 埋め込み (ユースケース)
      • ChatGPT : アクション – イントロダクション
      • ChatGPT : アクション – Getting started
      • ChatGPT : アクション – アクション認証
    • OpenAI ヘルプ : ChatGPT
      • ChatGPTとは何ですか?
      • ChatGPT は真実を語っていますか?
      • GPT の作成
      • GPT FAQ
      • GPT vs アシスタント
      • GPT ビルダー
    • OpenAI ヘルプ : ChatGPT > メモリ
      • FAQ
    • OpenAI ヘルプ : GPT ストア
      • 貴方の GPT をフィーチャーする
    • OpenAI Python ライブラリ 0.27 : 概要
    • OpenAI platform
      • Get Started : イントロダクション
      • Get Started : クイックスタート
      • Get Started : モデル
      • ガイド : GPT モデル
      • ガイド : 画像生成 (DALL·E)
      • ガイド : GPT-3.5 Turbo 対応 微調整
      • ガイド : 微調整 1.イントロダクション
      • ガイド : 微調整 2. データセットの準備 / ケーススタディ
      • ガイド : 埋め込み
      • ガイド : 音声テキスト変換
      • ガイド : モデレーション
      • ChatGPT プラグイン : イントロダクション
    • OpenAI Cookbook
      • 概要
      • API 使用方法 : レート制限の操作
      • API 使用方法 : tiktoken でトークンを数える方法
      • GPT : ChatGPT モデルへの入力をフォーマットする方法
      • GPT : 補完をストリームする方法
      • GPT : 大規模言語モデルを扱う方法
      • 埋め込み : 埋め込みの取得
      • GPT-3 の微調整 : 分類サンプルの微調整
      • DALL-E : DALL·E で 画像を生成して編集する方法
      • DALL·E と Segment Anything で動的マスクを作成する方法
      • Whisper プロンプティング・ガイド
  • Gemini API
    • Tutorials : クイックスタート with Python (1) テキスト-to-テキスト生成
    • (2) マルチモーダル入力 / 日本語チャット
    • (3) 埋め込みの使用
    • (4) 高度なユースケース
    • クイックスタート with Node.js
    • クイックスタート with Dart or Flutter (1) 日本語動作確認
    • Gemma
      • 概要 (README)
      • Tutorials : サンプリング
      • Tutorials : KerasNLP による Getting Started
  • Keras 3
    • 新しいマルチバックエンド Keras
    • Keras 3 について
    • Getting Started : エンジニアのための Keras 入門
    • Google Colab 上のインストールと Stable Diffusion デモ
    • コンピュータビジョン – ゼロからの画像分類
    • コンピュータビジョン – 単純な MNIST convnet
    • コンピュータビジョン – EfficientNet を使用した微調整による画像分類
    • コンピュータビジョン – Vision Transformer による画像分類
    • コンピュータビジョン – 最新の MLPモデルによる画像分類
    • コンピュータビジョン – コンパクトな畳込み Transformer
    • Keras Core
      • Keras Core 0.1
        • 新しいマルチバックエンド Keras (README)
        • Keras for TensorFlow, JAX, & PyTorch
        • 開発者ガイド : Getting started with Keras Core
        • 開発者ガイド : 関数型 API
        • 開発者ガイド : シーケンシャル・モデル
        • 開発者ガイド : サブクラス化で新しい層とモデルを作成する
        • 開発者ガイド : 独自のコールバックを書く
      • Keras Core 0.1.1 & 0.1.2 : リリースノート
      • 開発者ガイド
      • Code examples
      • Keras Stable Diffusion
        • 概要
        • 基本的な使い方 (テキスト-to-画像 / 画像-to-画像変換)
        • 混合精度のパフォーマンス
        • インペインティングの簡易アプリケーション
        • (参考) KerasCV – Stable Diffusion を使用した高性能画像生成
  • TensorFlow
    • TF 2 : 初級チュートリアル
    • TF 2 : 上級チュートリアル
    • TF 2 : ガイド
    • TF 1 : チュートリアル
    • TF 1 : ガイド
  • その他
    • 🦜️🔗 LangChain ドキュメント / ユースケース
    • Stable Diffusion WebUI
      • Google Colab で Stable Diffusion WebUI 入門
      • HuggingFace モデル / VAE の導入
      • LoRA の利用
    • Diffusion Models / 拡散モデル
  • クラスキャット
    • 会社案内
    • お問合せ
    • Facebook
    • ClassCat® Blog
Menu

DGL 0.5 ユーザガイド : 3 章 GNN モジュールをビルドする

Posted on 09/19/2020 by Sales Information

DGL 0.5ユーザガイド : 3 章 GNN モジュールをビルドする (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 09/18/2020 (0.5.1)

* 本ページは、DGL の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

  • Chapter 3: Building GNN Modules

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

ユーザガイド : 3 章 GNN モジュールをビルドする

DGL NN モジュールは貴方の GNN モデルのためのビルディング・ブロックです。使用中の DNN フレームワーク・バックエンドに依拠して、それは PyTorch の NN モジュール、MXNet Gluon の NN ブロックと TensorFlow の Keras 層から継承しています。DGL NN モジュールでは、forward 関数の構築関数と tensor 演算のパラメータ登録はバックエンド・フレームワークと同じです。このようにして、DGL コードはバックエンド・フレームワーク・コードにシームレスに統合できます。主要な違いは、DGL に固有なメッセージ・パッシング演算にあります。

DGL は多くの一般に使用される Conv 層, Dense Conv 層, Global Pooling 層、そして ユティリティ・モジュール を統合しました。We welcome your contribution!

このセクションでは、貴方自身の DGL NN モジュールをどのようにビルドするかを紹介するためのサンプルとして SAGEConv を PyTorch バックエンドで利用します。

 

DGL NN モジュール構築 (= Construction) 関数

構築関数は以下を行ないます :

  1. オプションを設定する。
  2. 学習可能なパラメータやサブモジュールを登録する。
  3. パラメータをリセットする。
import torch as th
from torch import nn
from torch.nn import init

from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair, check_eq_shape

class SAGEConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation

構築関数では、最初にデータ次元を設定する必要があります。一般的な PyTorch モジュールのためには、次元は通常は入力次元、出力次元そして隠れ次元です。グラフ・ニューラルに対しては、入力次元はソースノード次元と destination ノード次元に分けることができます。

データ次元に加えて、グラフ・ニューラルネットワークのための典型的なオプションは aggregation 型 (self._aggre_type) です。aggregation 型は異なるエッジ上のメッセージがある destination ノードのためにどのように集約されるかを決定します。一般に利用される aggregation 型は mean, sum, max, min を含みます。幾つかのモジュールは lstm のようなより複雑な aggregation を適用するかもしれません。

ここで norm は特徴正規化のための callable 関数です。SAGEConv ペーパーでは、そのような正規化は l2 norm であり得ます : \(h_v = h_v / \lVert h_v \rVert_2\)。

# aggregator type: mean, max_pool, lstm, gcn
if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
    raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
if aggregator_type == 'max_pool':
    self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
if aggregator_type == 'lstm':
    self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
if aggregator_type in ['mean', 'max_pool', 'lstm']:
    self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
self.reset_parameters()

パラメータとサブモジュールを登録します。SAGEConv では、サブモジュールは aggregation 型に従って様々です。これらのモジュールは nn.Linear, nn.LSTM 等のような純粋な PyTorch nn モジュールです。構築関数の最後に、reset_parameters() を呼び出すことにより重み初期化が適用されます。

def reset_parameters(self):
    """Reinitialize learnable parameters."""
    gain = nn.init.calculate_gain('relu')
    if self._aggre_type == 'max_pool':
        nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
    if self._aggre_type == 'lstm':
        self.lstm.reset_parameters()
    if self._aggre_type != 'gcn':
        nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
    nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

 

DGL NN モジュール Forward 関数

NN モジュールでは、forward() 関数が実際のメッセージ・パッシングと計算を行ないます。パラメータとして通常は tensor を取るPyTorch の NN モジュールと比べて、DGL NN モジュールは追加のパラメータ dgl.DGLGraph を取ります。forward() 関数のための作業負荷は 3 つのパートに分割できます :

  • グラフ確認とグラフ型仕様。
  • メッセージ・パッシングと reducing。
  • 出力のために reduce の後で特徴を更新する。

SAGEConv サンプルの forward() 関数を深く調べましょう。

 

グラフ確認とグラフ型仕様

def forward(self, graph, feat):
    with graph.local_scope():
        # Specify graph type then expand input feature according to graph type
        feat_src, feat_dst = expand_as_pair(feat, graph)

forward() は計算とメッセージ・パッシングで不正な値に導く可能性がある入力の多くの扱いにくいケース (= corner cases) を処理する必要があります。GraphConv のような conv モジュールでの一つの典型的なチェックは入力グラフに 0-in-degree ノードがないことを検証することです。ノードが 0-in-degree を持つとき、mailbox は空となり reduce 関数は総てゼロの値を生成します。これはモデル性能において静かな regression を引き起こすかもしれません。けれども、SAGEConv モジュールでは、集約 (= aggregated) 表現は元のノード特徴と連結されて、forward() の出力は総てゼロではありません。この場合にはそのようなチェックは必要ありません。

DGL NN モジュールは以下を含む異なる型のグラフ入力に渡り再利用可能であるべきです : 均質グラフ、異質グラフ (1.5 異質グラフ)、サブグラフ・ブロック (Chapter 6: Stochastic Training on Large Graphs)。

SAGEConv のための数式は :

\[
h_{\mathcal{N}(dst)}^{(l+1)} = \mathrm{aggregate}
\left(\{h_{src}^{l}, \forall src \in \mathcal{N}(dst) \}\right) \\
h_{dst}^{(l+1)} = \sigma \left(W \cdot \mathrm{concat}
(h_{dst}^{l}, h_{\mathcal{N}(dst)}^{l+1} + b) \right)\\
h_{dst}^{(l+1)} = \mathrm{norm}(h_{dst}^{l})
\]

グラフ型に従ってソースノード特徴 feat_src と destination ノード特徴 feat_dst を指定する必要があります。グラフ型を指定して feat を feat_src と feat_dst 内に拡張 (= expand) するための関数は expand_as_pair() です。この関数の詳細は下で示されます。

def expand_as_pair(input_, g=None):
    if isinstance(input_, tuple):
        # Bipartite graph case
        return input_
    elif g is not None and g.is_block:
        # Subgraph block case
        if isinstance(input_, Mapping):
            input_dst = {
                k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
                for k, v in input_.items()}
        else:
            input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
        return input_, input_dst
    else:
        # Homograph case
        return input_, input_

均質グラフ全体の訓練については、ソースノードと destination ノードは同じです。それらは総てグラフのノードです。

異質 (グラフ) なケースについては、グラフは幾つかの 2 部グラフに分割できます、各リレーションに対して一つです。リレーションは (src_type, edge_type, dst_dtype) として表されます。入力特徴 feat がタプルであることを識別するとき、グラフを 2 部として扱います。タプルの最初の要素はソースノード特徴でそして 2 番目の要素は destination ノード特徴です。

ミニバッチ訓練では、与えられた多くの destination ノードからサンプリングされたサブグラフ上で計算が適用されます。サブグラフは DGL ではブロックと呼ばれます。メッセージ・パッシング後、それらの destination ノードだけが更新されます、何故ならばそれらは元の full グラフで持つのと同じ近傍を持つからです。ブロック作成段階では、dst ノードはノードリストの最前部にあります。feat_dst をインデックス [0:g.number_of_dst_nodes()] で見つけられます。

feat_src と feat_dst を決定した後、上の 3 つのグラフ型のための計算は同じです。

 

メッセージ・パッシングと reducing

if self._aggre_type == 'mean':
    graph.srcdata['h'] = feat_src
    graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
    h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn':
    check_eq_shape(feat)
    graph.srcdata['h'] = feat_src
    graph.dstdata['h'] = feat_dst     # same as above if homogeneous
    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
    # divide in_degrees
    degs = graph.in_degrees().to(feat_dst)
    h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
elif self._aggre_type == 'max_pool':
    graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
    graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
    h_neigh = graph.dstdata['neigh']
else:
    raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

# GraphSAGE GCN does not require fc_self.
if self._aggre_type == 'gcn':
    rst = self.fc_neigh(h_neigh)
else:
    rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

コードは実際にはメッセージ・パッシングと reducing 計算を行ないます。コードのこのパートはモジュール毎に様々です。2 章 メッセージ・パッシング で説明されているように、上のコードの総てのメッセージ・パッシングは DGL のパフォーマンス最適化を完全に利用するために update_all() API と組込みメッセージ/reduce 関数を使用して実装されていることに注意してください。

 

出力のために reducing の後特徴を更新する

# activation
if self.activation is not None:
    rst = self.activation(rst)
# normalization
if self.norm is not None:
    rst = self.norm(rst)
return rst

forward() 関数の最後のパートは reduce 関数の後で特徴を更新することです。一般的な更新演算はオブジェクト構築段階で設定されたオプションに従って活性化関数と正規化を適用します。

 

異質な GraphConv モジュール

dgl.nn.pytorch.HeteroGraphConv は異質グラフ上で DGL NN モジュールを実行するモジュールレベルのカプセル化です。実装ロジックはメッセージ・パッシング・レベルの API multi_update_all() と同じです :

  • 各リレーション r 内の DGL nn モジュール。
  • 複数のリレーションからの同じノード型上の結果をマージする reduction。

これは次のように定式化できます :

\[
h_{dst}^{(l+1)} = \underset{r\in\mathcal{R}, r_{dst}=dst}{AGG} (f_r(g_r, h_{r_{src}}^l, h_{r_{dst}}^l))
\]

ここで $f_r$ は各リレーション $r$ のための NN モジュールで、$AGG$ は aggregation (集約) 関数です。

 

HeteroGraphConv 実装ロジック

class HeteroGraphConv(nn.Module):
    def __init__(self, mods, aggregate='sum'):
        super(HeteroGraphConv, self).__init__()
        self.mods = nn.ModuleDict(mods)
        if isinstance(aggregate, str):
            self.agg_fn = get_aggregate_fn(aggregate)
        else:
            self.agg_fn = aggregate

ヘテログラフ畳込みは各リレーションを nn モジュールにマップする辞書 mods を取ります。そして複数のリレーションから同じノード型上の結果を集約する関数を設定します。

def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
    if mod_args is None:
        mod_args = {}
    if mod_kwargs is None:
        mod_kwargs = {}
    outputs = {nty : [] for nty in g.dsttypes}

入力グラフと入力 tensor に加えて、forward() 関数は 2 つの追加の辞書パラメータ mod_args と mod_kwargs を取ります。これら 2 つの辞書は self.mods と同じキーを持ちます。それらは、異なる型のリレーションのための self.mods の対応する NN モジュールを呼び出すときカスタマイズされたパラメータとして使用されます。

出力辞書は各 destination 型 nty のための出力 tensor を保持するために作成されます。各 nty のための値はリストで、一つ以上のリレーションが nty を destination 型として持つ場合、単一ノード型は複数の出力を得るかもしれないことを示すことに注意してください。更なる集約のためにそれらをリストに保持します。

if g.is_block:
    src_inputs = inputs
    dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
else:
    src_inputs = dst_inputs = inputs

for stype, etype, dtype in g.canonical_etypes:
    rel_graph = g[stype, etype, dtype]
    if rel_graph.number_of_edges() == 0:
        continue
    if stype not in src_inputs or dtype not in dst_inputs:
        continue
    dstdata = self.mods[etype](
        rel_graph,
        (src_inputs[stype], dst_inputs[dtype]),
        *mod_args.get(etype, ()),
        **mod_kwargs.get(etype, {}))
    outputs[dtype].append(dstdata)

入力 g は異質グラフか異質グラフからのサブグラフ・ブロックであり得ます。普通の NN モジュールでのように、forward() 関数は異なる入力グラフ型を個別に扱う必要があります。

各リレーションは canonical_etype として表されます、これは (stype, etype, dtype) です。canonical_etype をキーとして使用し、2 部グラフ rel_graph を抽出できます。2 部グラフについて、入力特徴はタプル (src_inputs[stype], dst_inputs[dtype]) として体系化されます。各リレーションのための NN モジュールが呼び出されて出力はセーブされます。不必要な呼び出しを避けるため、エッジやその src 型を持つノードを持たないリレーションはスキップされます。

rsts = {}
for nty, alist in outputs.items():
    if len(alist) != 0:
        rsts[nty] = self.agg_fn(alist, nty)

最後に、複数のリレーションから同じ destination ノード型上の結果は self.agg_fn 関数を使用して集約されます。サンプルは dgl.nn.pytorch.HeteroGraphConv のための API Doc で見つけられます。

 

以上






クラスキャット

最近の投稿

  • LangGraph Platform : Get started : クイックスタート
  • LangGraph Platform : 概要
  • LangGraph : Prebuilt エージェント : ユーザインターフェイス
  • LangGraph : Prebuilt エージェント : 配備
  • LangGraph : Prebuilt エージェント : マルチエージェント

タグ

AutoGen (13) ClassCat Press Release (20) ClassCat TF/ONNX Hub (11) DGL 0.5 (14) Eager Execution (7) Edward (17) FLUX.1 (16) Gemini (20) HuggingFace Transformers 4.5 (10) HuggingFace Transformers 4.6 (7) HuggingFace Transformers 4.29 (9) Keras 2 Examples (98) Keras 2 Guide (16) Keras 3 (10) Keras Release Note (17) Kubeflow 1.0 (10) LangChain (45) LangGraph (20) MediaPipe 0.8 (11) Model Context Protocol (16) NNI 1.5 (16) OpenAI Agents SDK (8) OpenAI Cookbook (13) OpenAI platform (10) OpenAI platform 1.x (10) OpenAI ヘルプ (8) TensorFlow 2.0 Advanced Tutorials (33) TensorFlow 2.0 Advanced Tutorials (Alpha) (15) TensorFlow 2.0 Advanced Tutorials (Beta) (16) TensorFlow 2.0 Guide (10) TensorFlow 2.0 Guide (Alpha) (16) TensorFlow 2.0 Guide (Beta) (9) TensorFlow 2.0 Release Note (12) TensorFlow 2.0 Tutorials (20) TensorFlow 2.0 Tutorials (Alpha) (14) TensorFlow 2.0 Tutorials (Beta) (12) TensorFlow 2.4 Guide (24) TensorFlow Deploy (8) TensorFlow Get Started (7) TensorFlow Graphics (7) TensorFlow Probability (9) TensorFlow Programmer's Guide (22) TensorFlow Release Note (18) TensorFlow Tutorials (33) TF-Agents 0.4 (11)
2020年9月
月 火 水 木 金 土 日
 123456
78910111213
14151617181920
21222324252627
282930  
« 7月   10月 »
© 2025 ClasCat® AI Research | Powered by Minimalist Blog WordPress Theme