ホーム » DGL » DGL 0.5 ユーザガイド : 2 章 メッセージ・パッシング

DGL 0.5 ユーザガイド : 2 章 メッセージ・パッシング

DGL 0.5ユーザガイド : 2 章 メッセージ・パッシング (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 09/18/2020 (0.5.1)

* 本ページは、DGL の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

ユーザガイド : 2 章 メッセージ・パッシング

メッセージ・パッシング・パラダイム

\(x_v\in\mathbb{R}^{d_1}\) をノード $v$ のための特徴とし、そして \(w_{e}\in\mathbb{R}^{d_2}\) をエッジ \(({u}, {v})\) のための特徴とします。メッセージ・パッシング・パラダイム は次のステップ \(t+1\) におけるノード-wise とエッジ-wise な計算を定義します。

\[
\text{Edge-wise: } m_{e}^{(t+1)} = \phi \left( x_v^{(t)}, x_u^{(t)}, w_{e}^{(t)} \right) , ({u}, {v},{e}) \in \mathcal{E}.\\
\text{Node-wise: } x_v^{(t+1)} = \psi \left(x_v^{(t)}, \rho\left(\left\lbrace m_{e}^{(t+1)} : ({u}, {v},{e}) \in \mathcal{E} \right\rbrace \right) \right).
\]

上の等式で、\(\phi\) はエッジ特徴をその付随するノードの特徴と結合することによりメッセージを生成するために 各エッジ上で定義される メッセージ関数 です ; \(\psi\) は reduce 関数 \(\rho\) を使用して incoming メッセージを収集することによりノード特徴を更新するために 各ノード上で定義される 更新関数 です。

 

組込み関数とメッセージ・パッシング API

DGL では、メッセージ関数 は単一引数 edges を取ります、これはソースノード、destination ノードとエッジの特徴にアクセスするために 3 つのメンバー src, dst と data をそれぞれ取ります。

reduce 関数 は単一引数 nodes を取ります。ノードは、その近傍がエッジを通してそれに送るメッセージを集めるためにその mailbox にアクセスできます。最も一般的な reduce 演算の幾つかは sum, max, min 等を含みます。

更新関数は単一引数 nodes を取ります。この関数は、典型的には最後のステップにおけるノードの特徴と結合された、reduce 関数からの集合結果上で作用し、出力をノード特徴としてセーブします。

DGL は一般に使用されるメッセージ関数と reduce 関数を名前空間 dgl.function で 組込み として実装しました。一般に、可能なときにはいつでも 組込み関数を使用することを提案します、何故ならばそれらは大いに最適化されて次元ブロードキャスティングを自動的に扱うからです。

貴方のメッセージ・パッシング関数が組込みで実装できない場合には、ユーザ定義メッセージ/reduce 関数を実装できます (aka. UDF)。

組込みメッセージ関数は unary (単項) かバイナリであり得ます。unary については今のところ copy をサポートしています。バイナリ関数については、今は add, sub, mul, div, dot をサポートします。メッセージ組込み関数のための名前付け慣習としては u は src ノードを表し、v は dst ノードを表し、e はエッジを表します。これらの関数のためのパラメータは相当するノードとエッジのための入力と出力フィールド名を示す文字列です。ここにサポートされる組込み関数の dgl.function があります。例えば、src ノードから hu 特徴をそして dst ノードから hv 特徴を追加してからフィールドのエッジ上の結果をセーブするために、組込み関数 dgl.function.u_add_v(‘hu’, ‘hv’, ‘he’) を利用できます、これは次のメッセージ UDF に等値です :

def message_func(edges):
     return {'he': edges.src['hu'] + edges.dst['hv']}

組込みの reduce 関数は演算 sum, max, min, prod と mean をサポートします。reduce 関数は通常は 2 つのパラメータを持ちます、一つは mailbox のフィールド名のため、一つは destination のフィールド名のためで、両者は文字列です。例えば、dgl.function.sum(‘m’, ‘h’) はメッセージ m を合計する Reduce UDF に等値です :

import torch
def reduce_func(nodes):
     return {'h': torch.sum(nodes.mailbox['m'], dim=1)}

DGL では、エッジ-wise な計算を呼び出すインターフェイスは apply_edges() です。apply_edges のためのパラメータは API Doc で説明されているようにメッセージ関数と正当なエッジ型です (デフォルトでは、総てのエッジは更新されます)。例えば :

import dgl.function as fn
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))

ノード-wise な計算を呼び出すインターフェイスは update_all() です。update_all のためのパラメータはメッセージ関数、reduce 関数と更新関数です。3 番目のパラメータを空としておくことで更新関数は update_all の外側で呼び出すこともできます。これは提案されます、何故ならば更新関数はコードを簡潔にするために純粋な tensor 演算として通常は書かれるからです。例えば :

def updata_all_example(graph):
    # store the result in graph.ndata['ft']
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                    fn.sum('m', 'ft'))
    # Call update function outside of update_all
    final_ft = graph.ndata['ft'] * 2
    return final_ft

この呼び出しはソースノード特徴 ft とエッジ特徴 a を乗算することによりメッセージ m を生成し、ノード特徴 ft を更新するためにメッセージ m を合計し、最後に結果 final_ft を得るために ft に 2 を乗算します。呼び出し後、中間メッセージ m はクリーンアップされます。上の関数に対する数式は :

\[
{final\_ft}_i = 2 * \sum_{j\in\mathcal{N}(i)} ({ft}_j * a_{ij})
\]

update_all はメッセージ生成をマージする高位 API で、メッセージ reduction とノード更新は単一呼び出し、これは下で説明されるように最適化のための余地を残します。

 

効率的なメッセージ・パッシング・コードを書く

DGL はメッセージ・パッシングのためのメモリ消費と計算スピードを最適化します。最適化は以下を含みます :

  • マルチカーネルを単一の一つにマージする : これは、複数の組込み関数を一度に呼び出すために update_all を使用することにより達成されます。(スピード最適化)
  • ノードとエッジ上の並列性 : DGL はエッジ-wise 計算 apply_edges を一般化されてサンプリングされた dense-dense 行列乗算 (gSDDMM) 演算として抽象化してエッジに渡る計算を並列化します。同様に、DGL はノード-wise 計算 update_all を一般化された sparse-dense 行列乗算 (gSPMM) 演算として抽象化してノードに渡る計算を並列化します。(スピード最適化)
  • エッジへの不要なメモリコピーを回避する : ソースと destination ノードからの特徴を必要とするメッセージを生成するため、一つの選択肢はソースと destination ノード特徴をそのエッジにコピーすることです。幾つかのグラフについては、エッジの数はノードの数よりも遥かに大きいです。このコピーはコスト高であり得ます。DGL 組込みメッセージ関数はノード特徴をエントリ・インデックスを使用してサンプリングすることによりこのメモリコピーを回避します。(メモリとスピード最適化)
  • エッジ上の特徴ベクトルの具体化を回避する : 完全なメッセージパッシング過程はメッセージ生成、メッセージ reduction とノード更新を含みます。update_all 呼び出しでは、メッセージ関数と reduce 関数はそれらの関数が組込みであれば一つのカーネルにマージされます。エッジ上のメッセージ具体化はありません。(メモリ最適化)

上に従えば、それらの最適化を活用する一般的な方法は貴方自身のメッセージ・パッシング機能を update_all 呼び出しのパラメータとしての組込み関数との結合として構築することです。

エッジ上にメッセージをセーブしなければならない GATConv のような幾つかのケースについては、apply_edges を組込み関数とともに呼び出す必要があります。時にエッジ上のメッセージは高次元でありえて、これはメモリ消費的です。edata 次元をできるだけ低く保つことを提案します。

エッジ上の演算をノードに分割することによりこれをどのように成すかのサンプルがここにあります。この選択肢は以下を行ないます : src 特徴と dst 特徴を結合してから、線形層を適用します、i.e. \(W\times (u || v)\)。src と dst 特徴次元が高い一方で、線形層出力次元は低いです。straight forward な実装は次のようなものです :

linear = nn.Parameter(th.FloatTensor(size=(1, node_feat_dim*2)))
def concat_message_function(edges):
    {'cat_feat': torch.cat([edges.src.ndata['feat'], edges.dst.ndata['feat']])}
g.apply_edges(concat_message_function)
g.edata['out'] = g.edata['cat_feat'] * linear

提案される実装は線形演算を 2 つに分割します、一つは src 特徴上で適用されて、他方は dst 特徴上で適用されます。最後のステージでエッジ上の線形演算の出力を追加します、i.e. \(W \times (u||v) = W_l \times u + W_r \times v\) ですから、\(W_l\times u + W_r \times v\) を遂行します、そこでは \(W_l\) と \(W_r\) はそれぞれ行列 \(W\) の左と右半分です :

linear_src = nn.Parameter(th.FloatTensor(size=(1, node_feat_dim)))
linear_dst = nn.Parameter(th.FloatTensor(size=(1, node_feat_dim)))
out_src = g.ndata['feat'] * linear_src
out_dst = g.ndata['feat'] * linear_dst
g.srcdata.update({'out_src': out_src})
g.dstdata.update({'out_dst': out_dst})
g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out'))

上の 2 つの実装は数学的に同値です。後者は遥かに効率的です、何故ならばエッジ上の feat_src and feat_dst をセーブする必要がないからです、これはメモリ効率的ではありません。更に、加算は DGL の組込み関数 u_add_v で最適化できるでしょう、これは計算を更に高速化してメモリフットプリントをセーブします。

 

グラフの一部上でメッセージ・パッシングを適用する

グラフのノードの一部だけを更新することを望む場合、その実践は update に含めたいノードのための id を提供することによりサブグラフを作成してから、サブグラフ上で update_all を呼び出すことです。例えば :

nid = [0, 2, 3, 6, 7, 9]
sg = g.subgraph(nid)
sg.update_all(message_func, reduce_func, apply_node_func)

これはミニバッチ訓練における一般的な使用方法です。より詳細な使用方法については Chapter 6: Stochastic Training on Large Graphs ユーザガイドを確認してください。

 

メッセージ・パッシングでエッジ重みを適用する

GNN モデリングにおける一般に見られる実践はメッセージ集約の前にメッセージ上でエッジ重みを適用することです、例えば GAT と幾つかの GCN 亜種 でです。DGL で、これを扱う方法は :

  • 重みをエッジ特徴としてセーブする。
  • メッセージ関数でエッジ特徴をソースノード特徴で乗算する。

例えば :

graph.edata['a'] = affinity
graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                 fn.sum('m', 'ft'))

上では、エッジ重みとして affinity を使用しています。エッジ重みは通常はスカラーです。

 

異質グラフ上のメッセージ・パッシング

異質 (グラフ) (1.5 異質グラフ のためのユーザガイド)、あるいは短くヘテログラフはノードとエッジの異なる型を含むグラフです。異なる型のノードとエッジは、各ノードとエッジ型の特質を捕捉するために設計された異なる型の属性を持つ傾向があります。グラフ・ニューラルネットワークのコンテキスト内では、それらの複雑さに依拠して、特定のノードとエッジ型が異なる数の次元を持つ表現でモデル化される必要があります。

ヘテログラフ上のメッセージ・パッシングは 2 つのパートに分割できます :

  1. 各関係 r 内のメッセージ計算と集約。
  2. 複数の関係からの同じノード型上の結果をマージする reduction。

ヘテログラフ上のメッセージ・パッシングを呼び出す DGL のインターフェイスは multi_update_all() です。multi_update_all は (関係をキーとして使用して) 各関係内の update_all のためのパラメータを含む辞書、そして 交差 (= cross) 型 reducer を表す文字列を取ります。reducer は sum, min, max, mean, stack の一つであり得ます。ここにサンプルがあります :

for c_etype in G.canonical_etypes:
    srctype, etype, dsttype = c_etype
    Wh = self.weight[etype](feat_dict[srctype])
    # Save it in graph for message passing
    G.nodes[srctype].data['Wh_%s' % etype] = Wh
    # Specify per-relation message passing functions: (message_func, reduce_func).
    # Note that the results are saved to the same destination feature 'h', which
    # hints the type wise reducer for aggregation.
    funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
# Trigger message passing of multiple types.
G.multi_update_all(funcs, 'sum')
# return the updated node feature dictionary
return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}
 

以上






AI導入支援 #2 ウェビナー

スモールスタートを可能としたAI導入支援   Vol.2
[無料 WEB セミナー] [詳細]
「画像認識 AI PoC スターターパック」の紹介
既に AI 技術を実ビジネスで活用し、成果を上げている日本企業も多く存在しており、競争優位なビジネスを展開しております。
しかしながら AI を導入したくとも PoC (概念実証) だけでも高額な費用がかかり取組めていない企業も少なくないようです。A I導入時には欠かせない PoC を手軽にしかも短期間で認知度を確認可能とするサービの紹介と共に、AI 技術の特性と具体的な導入プロセスに加え運用時のポイントについても解説いたします。
日時:2021年10月13日(水)
会場:WEBセミナー
共催:クラスキャット、日本FLOW(株)
後援:働き方改革推進コンソーシアム
参加費: 無料 (事前登録制)
人工知能開発支援
◆ クラスキャットは 人工知能研究開発支援 サービスを提供しています :
  • テクニカルコンサルティングサービス
  • 実証実験 (プロトタイプ構築)
  • アプリケーションへの実装
  • 人工知能研修サービス
◆ お問合せ先 ◆
(株)クラスキャット
セールス・インフォメーション
E-Mail:sales-info@classcat.com