DGL 0.5ユーザガイド : 5 章 訓練 : 5.3 リンク予測 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 09/22/2020 (0.5.2)
* 本ページは、DGL の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
ユーザガイド : 5 章 訓練 : 5.3 リンク予測
幾つかの他の設定では 2 つの与えられたノードの間にエッジが存在するか否かを予測することを望むかもしれません。そのようなモデルはリンク予測モデルと呼称します。
概要
GNN ベースのリンク予測モデルは 2 つのノード $u$ と $v$ の間の接続性の尤度を (それらの多層 GNN から計算されたノード表現) \(\boldsymbol{h}_u^{(L)}\) と \(\boldsymbol{h}_v^{(L)}\) の関数として表します。
\[
y_{u,v} = \phi(\boldsymbol{h}_u^{(L)}, \boldsymbol{h}_v^{(L)})
\]
このセクションでは \(y_{u,v}\) ノード $u$ とノード $v$ の間のスコアを参照します。
リンク予測モデルを訓練することは、エッジにより接続されるノード間のスコアをノードの任意のペアの間のスコアに対して比較することを伴います。例えば、$u$ と $v$ に接続するエッジが与えられたとき、ノード $u$ と $v$ の間のスコアがノード $u$ と (任意のノイズ分布 \(v’ \sim P_n(v)\) から) サンプリングされたノード \(v’\) の間のスコアよりも高いことを促進します (= encourage)。そのような方法はネガティブ・サンプリングと呼ばれます。
最小化されたときに上の挙動を獲得できる多くの損失関数があります。完全ではないリストは以下を含みます :
- 交差エントロピー損失: \(\mathcal{L} = – \log \sigma (y_{u,v}) – \sum_{v_i \sim P_n(v), i=1,\dots,k}\log \left[ 1 – \sigma (y_{u,v_i})\right]\)
- BPR 損失: \(\mathcal{L} = \sum_{v_i \sim P_n(v), i=1,\dots,k} – \log \sigma (y_{u,v} – y_{u,v_i})\)
- Margin 損失: \(\mathcal{L} = \sum_{v_i \sim P_n(v), i=1,\dots,k} \max(0, M – y_{u, v} + y_{u, v_i})\), ここで \(M\) は定数ハイパーパラメータ。
暗黙的フィードバック (= implicit feedback) や ノイズ-contrastive 推定 が何であるかを知っていれば、このアイデアに馴染みがあることを見出すかもしれません。
エッジ分類と異なるモデル実装の差異
$u$ と $v$ の間のスコアを計算するニューラルネットワーク・モデルは 上で 説明されたエッジ回帰モデルと同一です。
エッジ上のスコアを計算する dot 積を使用するサンプルがここにあります。
class DotProductPredictor(nn.Module): def forward(self, graph, h): # h contains the node representations computed from the GNN defined # in the node classification section (Section 5.1). with graph.local_scope(): graph.ndata['h'] = h graph.apply_edges(fn.u_dot_v('h', 'h', 'score')) return graph.edata['score']
訓練ループ
スコア予測モデルはグラフ上で動作しますので、ネガティブ・サンプルをもう一つのグラフとして表現する必要があります。グラフはエッジとして総てのネガティブ・ノードペアを含みます。
次はネガティブ・サンプルをグラフとして表すサンプルを示します。各エッジ \((u,v)\) は $k$ ネガティブ・サンプル \((u,v_i)\) を得ます、そこでは \(v_i\) は一様分布からサンプリングされます。
def construct_negative_graph(graph, k): src, dst = graph.edges() neg_src = src.repeat_interleave(k) neg_dst = torch.randint(0, graph.number_of_nodes(), (len(src) * k,)) return dgl.graph((neg_src, neg_dst), num_nodes=graph.number_of_nodes())
エッジ・スコアを予測するモデルはエッジ分類/回帰のそれと同じです。
class Model(nn.Module): def __init__(self, in_features, hidden_features, out_features): super().__init__() self.sage = SAGE(in_features, hidden_features, out_features) self.pred = DotProductPredictor() def forward(self, g, neg_g, x): h = self.sage(g, x) return self.pred(g, h), self.pred(neg_g, h)
それから訓練ループはネガティブ・グラフを繰り返し構築して損失を計算します。
def compute_loss(pos_score, neg_score): # Margin loss n_edges = pos_score.shape[0] return (1 - neg_score.view(n_edges, -1) + pos_score.unsqueeze(1)).clamp(min=0).mean() node_features = graph.ndata['feat'] n_features = node_features.shape[1] k = 5 model = Model(n_features, 100, 100) opt = torch.optim.Adam(model.parameters()) for epoch in range(10): negative_graph = construct_negative_graph(graph, k) pos_score, neg_score = model(graph, negative_graph, node_features) loss = compute_loss(pos_score, neg_score) opt.zero_grad() loss.backward() opt.step() print(loss.item())
訓練後、ノード表現は次を通して得られます :
node_embeddings = model.sage(graph, node_features)
ノード埋め込みを利用する複数の方法があります。サンプルは訓練ダウンストリーム分類器、あるいは適切な (= relevant) エンティティ・レコメンデーションのための最近傍探索や最大内積探索を行なうことを含みます。
異質グラフ
異質グラフ上のリンク予測は均質グラフ上のそれと大きくは違いません。以下は一つのエッジ型上で予測していることを仮定しますが、それを多重エッジ型に拡張することは容易です。
例えば、リンク予測のためのエッジ型のエッジのスコアを計算するために 上の HeteroDotProductPredictor を再利用できます。
class HeteroDotProductPredictor(nn.Module): def forward(self, graph, h, etype): # h contains the node representations for each node type computed from # the GNN defined in the previous section (Section 5.1). with graph.local_scope(): graph.ndata['h'] = h graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype) return graph.edges[etype].data['score']
ネガティブ・サンプリングを遂行するために、(その上でリンク予測を遂行している) エッジ型のためのネガティブ・グラフを構築することができます。
def construct_negative_graph(graph, k, etype): utype, _, vtype = etype src, dst = graph.edges(etype=etype) neg_src = src.repeat_interleave(k) neg_dst = torch.randint(0, graph.number_of_nodes(vtype), (len(src) * k,)) return dgl.heterograph( {etype: (neg_src, neg_dst)}, num_nodes_dict={ntype: graph.number_of_nodes(ntype) for ntype in graph.ntypes})
モデルは異質グラフ上のエッジ分類のそれとは少し異なります、何故ならばリンク予測を遂行するところのエッジ型を指定する必要があるからです。
class Model(nn.Module): def __init__(self, in_features, hidden_features, out_features, rel_names): super().__init__() self.sage = RGCN(in_features, hidden_features, out_features, rel_names) self.pred = HeteroDotProductPredictor() def forward(self, g, neg_g, x, etype): h = self.sage(g, x) return self.pred(g, h, etype), self.pred(neg_g, h, etype)
訓練ループは均質グラフのそれと同様です。
def compute_loss(pos_score, neg_score): # Margin loss n_edges = pos_score.shape[0] return (1 - neg_score.view(n_edges, -1) + pos_score.unsqueeze(1)).clamp(min=0).mean() k = 5 model = Model(10, 20, 5, hetero_graph.etypes) user_feats = hetero_graph.nodes['user'].data['feature'] item_feats = hetero_graph.nodes['item'].data['feature'] node_features = {'user': user_feats, 'item': item_feats} opt = torch.optim.Adam(model.parameters()) for epoch in range(10): negative_graph = construct_negative_graph(hetero_graph, k, ('user', 'click', 'item')) pos_score, neg_score = model(hetero_graph, negative_graph, node_features, ('user', 'click', 'item')) loss = compute_loss(pos_score, neg_score) opt.zero_grad() loss.backward() opt.step() print(loss.item())
以上