DGL 0.5ユーザガイド : 5 章 グラフ・ニューラルネットワークを訓練する (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 09/20/2020 (0.5.1)
* 本ページは、DGL の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
ユーザガイド : 5 章 グラフ・ニューラルネットワークを訓練する
概要
この章は 2章: メッセージ・パッシング で紹介されたメッセージ・パッシング法と 3 章: GNN モジュールを構築する で導入されたニューラルネットワークによりノード分類、エッジ分類、リンク予測、そして小さいグラフのためのグラフ分類のためにグラフ・ニューラルネットワークをどのように訓練するかを議論します。
この章は貴方のグラフそしてそのノードとエッジ特徴の総てが GPU に収められることを仮定しています ;そうでない場合には Chapter 6: Stochastic Training on Large Graphs を見てください。
以下のテキストはグラフとノード/エッジ特徴が既に準備されていることを仮定しています。DGL が提供するデータセットや 4 章: グラフ・データパイプライン で説明されている他の互換な DGLDataset を使用することを計画している場合、次のような何かで単一グラフのデータセットのためにグラフを得ることができます :
import dgl dataset = dgl.data.CiteseerGraphDataset() graph = dataset[0]
Note: この章ではバックエンドとして PyTorch を利用します。
異質グラフ
時に異質グラフ上で作業したいでしょう。ここではノード分類、エッジ分類とリンク予測タスクのためのサンプルとして合成の異質グラフを取ります。
合成異質グラフ hetero_graph はこれらのエッジ型を持ちます :
- (‘user’, ‘follow’, ‘user’)
- (‘user’, ‘followed-by’, ‘user’)
- (‘user’, ‘click’, ‘item’)
- (‘item’, ‘clicked-by’, ‘user’)
- (‘user’, ‘dislike’, ‘item’)
- (‘item’, ‘disliked-by’, ‘user’)
import numpy as np
import torch
n_users = 1000
n_items = 500
n_follows = 3000
n_clicks = 5000
n_dislikes = 500
n_hetero_features = 10
n_user_classes = 5
n_max_clicks = 10
follow_src = np.random.randint(0, n_users, n_follows)
follow_dst = np.random.randint(0, n_users, n_follows)
click_src = np.random.randint(0, n_users, n_clicks)
click_dst = np.random.randint(0, n_items, n_clicks)
dislike_src = np.random.randint(0, n_users, n_dislikes)
dislike_dst = np.random.randint(0, n_items, n_dislikes)
hetero_graph = dgl.heterograph({
('user', 'follow', 'user'): (follow_src, follow_dst),
('user', 'followed-by', 'user'): (follow_dst, follow_src),
('user', 'click', 'item'): (click_src, click_dst),
('item', 'clicked-by', 'user'): (click_dst, click_src),
('user', 'dislike', 'item'): (dislike_src, dislike_dst),
('item', 'disliked-by', 'user'): (dislike_dst, dislike_src)})
hetero_graph.nodes['user'].data['feature'] = torch.randn(n_users, n_hetero_features)
hetero_graph.nodes['item'].data['feature'] = torch.randn(n_items, n_hetero_features)
hetero_graph.nodes['user'].data['label'] = torch.randint(0, n_user_classes, (n_users,))
hetero_graph.edges['click'].data['label'] = torch.randint(1, n_max_clicks, (n_clicks,)).float()
# randomly generate training masks on user nodes and click edges
hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(n_users, dtype=torch.bool).bernoulli(0.6)
hetero_graph.edges['click'].data['train_mask'] = torch.zeros(n_clicks, dtype=torch.bool).bernoulli(0.6)
以上