DGL 0.5ユーザガイド : 3 章 GNN モジュールをビルドする (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 09/18/2020 (0.5.1)
* 本ページは、DGL の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
ユーザガイド : 3 章 GNN モジュールをビルドする
DGL NN モジュールは貴方の GNN モデルのためのビルディング・ブロックです。使用中の DNN フレームワーク・バックエンドに依拠して、それは PyTorch の NN モジュール、MXNet Gluon の NN ブロックと TensorFlow の Keras 層から継承しています。DGL NN モジュールでは、forward 関数の構築関数と tensor 演算のパラメータ登録はバックエンド・フレームワークと同じです。このようにして、DGL コードはバックエンド・フレームワーク・コードにシームレスに統合できます。主要な違いは、DGL に固有なメッセージ・パッシング演算にあります。
DGL は多くの一般に使用される Conv 層, Dense Conv 層, Global Pooling 層、そして ユティリティ・モジュール を統合しました。We welcome your contribution!
このセクションでは、貴方自身の DGL NN モジュールをどのようにビルドするかを紹介するためのサンプルとして SAGEConv を PyTorch バックエンドで利用します。
DGL NN モジュール構築 (= Construction) 関数
構築関数は以下を行ないます :
- オプションを設定する。
- 学習可能なパラメータやサブモジュールを登録する。
- パラメータをリセットする。
import torch as th from torch import nn from torch.nn import init from .... import function as fn from ....base import DGLError from ....utils import expand_as_pair, check_eq_shape class SAGEConv(nn.Module): def __init__(self, in_feats, out_feats, aggregator_type, bias=True, norm=None, activation=None): super(SAGEConv, self).__init__() self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) self._out_feats = out_feats self._aggre_type = aggregator_type self.norm = norm self.activation = activation
構築関数では、最初にデータ次元を設定する必要があります。一般的な PyTorch モジュールのためには、次元は通常は入力次元、出力次元そして隠れ次元です。グラフ・ニューラルに対しては、入力次元はソースノード次元と destination ノード次元に分けることができます。
データ次元に加えて、グラフ・ニューラルネットワークのための典型的なオプションは aggregation 型 (self._aggre_type) です。aggregation 型は異なるエッジ上のメッセージがある destination ノードのためにどのように集約されるかを決定します。一般に利用される aggregation 型は mean, sum, max, min を含みます。幾つかのモジュールは lstm のようなより複雑な aggregation を適用するかもしれません。
ここで norm は特徴正規化のための callable 関数です。SAGEConv ペーパーでは、そのような正規化は l2 norm であり得ます : \(h_v = h_v / \lVert h_v \rVert_2\)。
# aggregator type: mean, max_pool, lstm, gcn if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']: raise KeyError('Aggregator type {} not supported.'.format(aggregator_type)) if aggregator_type == 'max_pool': self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats) if aggregator_type == 'lstm': self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True) if aggregator_type in ['mean', 'max_pool', 'lstm']: self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias) self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias) self.reset_parameters()
パラメータとサブモジュールを登録します。SAGEConv では、サブモジュールは aggregation 型に従って様々です。これらのモジュールは nn.Linear, nn.LSTM 等のような純粋な PyTorch nn モジュールです。構築関数の最後に、reset_parameters() を呼び出すことにより重み初期化が適用されます。
def reset_parameters(self): """Reinitialize learnable parameters.""" gain = nn.init.calculate_gain('relu') if self._aggre_type == 'max_pool': nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain) if self._aggre_type == 'lstm': self.lstm.reset_parameters() if self._aggre_type != 'gcn': nn.init.xavier_uniform_(self.fc_self.weight, gain=gain) nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
DGL NN モジュール Forward 関数
NN モジュールでは、forward() 関数が実際のメッセージ・パッシングと計算を行ないます。パラメータとして通常は tensor を取るPyTorch の NN モジュールと比べて、DGL NN モジュールは追加のパラメータ dgl.DGLGraph を取ります。forward() 関数のための作業負荷は 3 つのパートに分割できます :
- グラフ確認とグラフ型仕様。
- メッセージ・パッシングと reducing。
- 出力のために reduce の後で特徴を更新する。
SAGEConv サンプルの forward() 関数を深く調べましょう。
グラフ確認とグラフ型仕様
def forward(self, graph, feat): with graph.local_scope(): # Specify graph type then expand input feature according to graph type feat_src, feat_dst = expand_as_pair(feat, graph)
forward() は計算とメッセージ・パッシングで不正な値に導く可能性がある入力の多くの扱いにくいケース (= corner cases) を処理する必要があります。GraphConv のような conv モジュールでの一つの典型的なチェックは入力グラフに 0-in-degree ノードがないことを検証することです。ノードが 0-in-degree を持つとき、mailbox は空となり reduce 関数は総てゼロの値を生成します。これはモデル性能において静かな regression を引き起こすかもしれません。けれども、SAGEConv モジュールでは、集約 (= aggregated) 表現は元のノード特徴と連結されて、forward() の出力は総てゼロではありません。この場合にはそのようなチェックは必要ありません。
DGL NN モジュールは以下を含む異なる型のグラフ入力に渡り再利用可能であるべきです : 均質グラフ、異質グラフ (1.5 異質グラフ)、サブグラフ・ブロック (Chapter 6: Stochastic Training on Large Graphs)。
SAGEConv のための数式は :
\[
h_{\mathcal{N}(dst)}^{(l+1)} = \mathrm{aggregate}
\left(\{h_{src}^{l}, \forall src \in \mathcal{N}(dst) \}\right) \\
h_{dst}^{(l+1)} = \sigma \left(W \cdot \mathrm{concat}
(h_{dst}^{l}, h_{\mathcal{N}(dst)}^{l+1} + b) \right)\\
h_{dst}^{(l+1)} = \mathrm{norm}(h_{dst}^{l})
\]
グラフ型に従ってソースノード特徴 feat_src と destination ノード特徴 feat_dst を指定する必要があります。グラフ型を指定して feat を feat_src と feat_dst 内に拡張 (= expand) するための関数は expand_as_pair() です。この関数の詳細は下で示されます。
def expand_as_pair(input_, g=None): if isinstance(input_, tuple): # Bipartite graph case return input_ elif g is not None and g.is_block: # Subgraph block case if isinstance(input_, Mapping): input_dst = { k: F.narrow_row(v, 0, g.number_of_dst_nodes(k)) for k, v in input_.items()} else: input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes()) return input_, input_dst else: # Homograph case return input_, input_
均質グラフ全体の訓練については、ソースノードと destination ノードは同じです。それらは総てグラフのノードです。
異質 (グラフ) なケースについては、グラフは幾つかの 2 部グラフに分割できます、各リレーションに対して一つです。リレーションは (src_type, edge_type, dst_dtype) として表されます。入力特徴 feat がタプルであることを識別するとき、グラフを 2 部として扱います。タプルの最初の要素はソースノード特徴でそして 2 番目の要素は destination ノード特徴です。
ミニバッチ訓練では、与えられた多くの destination ノードからサンプリングされたサブグラフ上で計算が適用されます。サブグラフは DGL ではブロックと呼ばれます。メッセージ・パッシング後、それらの destination ノードだけが更新されます、何故ならばそれらは元の full グラフで持つのと同じ近傍を持つからです。ブロック作成段階では、dst ノードはノードリストの最前部にあります。feat_dst をインデックス [0:g.number_of_dst_nodes()] で見つけられます。
feat_src と feat_dst を決定した後、上の 3 つのグラフ型のための計算は同じです。
メッセージ・パッシングと reducing
if self._aggre_type == 'mean': graph.srcdata['h'] = feat_src graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh')) h_neigh = graph.dstdata['neigh'] elif self._aggre_type == 'gcn': check_eq_shape(feat) graph.srcdata['h'] = feat_src graph.dstdata['h'] = feat_dst # same as above if homogeneous graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh')) # divide in_degrees degs = graph.in_degrees().to(feat_dst) h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1) elif self._aggre_type == 'max_pool': graph.srcdata['h'] = F.relu(self.fc_pool(feat_src)) graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh')) h_neigh = graph.dstdata['neigh'] else: raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type)) # GraphSAGE GCN does not require fc_self. if self._aggre_type == 'gcn': rst = self.fc_neigh(h_neigh) else: rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
コードは実際にはメッセージ・パッシングと reducing 計算を行ないます。コードのこのパートはモジュール毎に様々です。2 章 メッセージ・パッシング で説明されているように、上のコードの総てのメッセージ・パッシングは DGL のパフォーマンス最適化を完全に利用するために update_all() API と組込みメッセージ/reduce 関数を使用して実装されていることに注意してください。
出力のために reducing の後特徴を更新する
# activation if self.activation is not None: rst = self.activation(rst) # normalization if self.norm is not None: rst = self.norm(rst) return rst
forward() 関数の最後のパートは reduce 関数の後で特徴を更新することです。一般的な更新演算はオブジェクト構築段階で設定されたオプションに従って活性化関数と正規化を適用します。
異質な GraphConv モジュール
dgl.nn.pytorch.HeteroGraphConv は異質グラフ上で DGL NN モジュールを実行するモジュールレベルのカプセル化です。実装ロジックはメッセージ・パッシング・レベルの API multi_update_all() と同じです :
- 各リレーション r 内の DGL nn モジュール。
- 複数のリレーションからの同じノード型上の結果をマージする reduction。
これは次のように定式化できます :
\[
h_{dst}^{(l+1)} = \underset{r\in\mathcal{R}, r_{dst}=dst}{AGG} (f_r(g_r, h_{r_{src}}^l, h_{r_{dst}}^l))
\]
ここで $f_r$ は各リレーション $r$ のための NN モジュールで、$AGG$ は aggregation (集約) 関数です。
HeteroGraphConv 実装ロジック
class HeteroGraphConv(nn.Module): def __init__(self, mods, aggregate='sum'): super(HeteroGraphConv, self).__init__() self.mods = nn.ModuleDict(mods) if isinstance(aggregate, str): self.agg_fn = get_aggregate_fn(aggregate) else: self.agg_fn = aggregate
ヘテログラフ畳込みは各リレーションを nn モジュールにマップする辞書 mods を取ります。そして複数のリレーションから同じノード型上の結果を集約する関数を設定します。
def forward(self, g, inputs, mod_args=None, mod_kwargs=None): if mod_args is None: mod_args = {} if mod_kwargs is None: mod_kwargs = {} outputs = {nty : [] for nty in g.dsttypes}
入力グラフと入力 tensor に加えて、forward() 関数は 2 つの追加の辞書パラメータ mod_args と mod_kwargs を取ります。これら 2 つの辞書は self.mods と同じキーを持ちます。それらは、異なる型のリレーションのための self.mods の対応する NN モジュールを呼び出すときカスタマイズされたパラメータとして使用されます。
出力辞書は各 destination 型 nty のための出力 tensor を保持するために作成されます。各 nty のための値はリストで、一つ以上のリレーションが nty を destination 型として持つ場合、単一ノード型は複数の出力を得るかもしれないことを示すことに注意してください。更なる集約のためにそれらをリストに保持します。
if g.is_block: src_inputs = inputs dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()} else: src_inputs = dst_inputs = inputs for stype, etype, dtype in g.canonical_etypes: rel_graph = g[stype, etype, dtype] if rel_graph.number_of_edges() == 0: continue if stype not in src_inputs or dtype not in dst_inputs: continue dstdata = self.mods[etype]( rel_graph, (src_inputs[stype], dst_inputs[dtype]), *mod_args.get(etype, ()), **mod_kwargs.get(etype, {})) outputs[dtype].append(dstdata)
入力 g は異質グラフか異質グラフからのサブグラフ・ブロックであり得ます。普通の NN モジュールでのように、forward() 関数は異なる入力グラフ型を個別に扱う必要があります。
各リレーションは canonical_etype として表されます、これは (stype, etype, dtype) です。canonical_etype をキーとして使用し、2 部グラフ rel_graph を抽出できます。2 部グラフについて、入力特徴はタプル (src_inputs[stype], dst_inputs[dtype]) として体系化されます。各リレーションのための NN モジュールが呼び出されて出力はセーブされます。不必要な呼び出しを避けるため、エッジやその src 型を持つノードを持たないリレーションはスキップされます。
rsts = {} for nty, alist in outputs.items(): if len(alist) != 0: rsts[nty] = self.agg_fn(alist, nty)
最後に、複数のリレーションから同じ destination ノード型上の結果は self.agg_fn 関数を使用して集約されます。サンプルは dgl.nn.pytorch.HeteroGraphConv のための API Doc で見つけられます。
以上