ホーム » DGL » DGL 0.5 ユーザガイド : 5 章 訓練 : 5.1 ノード分類/回帰

DGL 0.5 ユーザガイド : 5 章 訓練 : 5.1 ノード分類/回帰

DGL 0.5ユーザガイド : 5 章 訓練 : 5.1 ノード分類/回帰 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 09/20/2020 (0.5.2)

* 本ページは、DGL の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

ユーザガイド : 5 章 訓練 : 5.1 ノード分類/回帰

グラフ・ニューラルネットワークのための最もポピュラーで広く採用されているタスクの一つはノード分類で、そこでは訓練/検証/テストの各ノードには事前定義されたカテゴリーのセットから正解カテゴリーが割当てられます。ノード回帰も同様で、そこでは訓練/検証/テストセットの各ノードには正解数字 (= number) が割当てられます。

 

概要

ノードを分類するために、グラフ・ニューラルネットワークはノード自身の特徴、更にはその近傍ノードとエッジ特徴を利用して 2 章: メッセージ・パッシング で議論されたメッセージ・パッシングを遂行します。メッセージ・パッシングは近傍のより大きい範囲からの情報を組込みために複数ラウンド繰り返すことができます。

 

ニューラルネットワーク・モデルを書く

DGL はメッセージ・パッシングの 1 ラウンドを遂行できる 2, 3 の組込みグラフ畳込みモジュールを提供します。このガイドでは、dgl.nn.pytorch.SAGEConv (MXNet と TensorFlow でもまた利用可能です) を選択します、GraphSAGE のためのグラフ畳込みモジュールです。

グラフ上の深層学習モデルのために通常は多層グラフ・ニューラルネットワークを必要とします、そこではマルチラウンドのメッセージ・パッシングを行ないます。これは次のようにグラフ畳込みモジュールをスタックして成すことができます。

# Contruct a two-layer GNN model
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        self.conv1 = dglnn.SAGEConv(
            in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')
        self.conv2 = dglnn.SAGEConv(
            in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = F.relu(h)
        h = self.conv2(graph, h)
        return h

上のモデルをノード分類のためだけでなく、5.2 Edge Classification/Regression, 5.3 Link Prediction5.4 Graph Classification のような他のダウンストリーム・タスクのための隠れノード表現を得るためにも利用できます。

組込みグラフ畳込みモジュールの完全なリストについては、dgl.nn を参照してください。

DGL ニューラルネットワークがどのように動作するか、そしてメッセージ・パッシングでカスタム・ニューラルネットワーク・モジュールをどのように書くかのより詳細については、3 章: GNN モジュールを構築する のサンプルを参照してください。

 

訓練ループ

full グラフ上の訓練は上で定義されたモデルの順伝播と、予測を訓練ノード上の正解ラベルに対して比較することによる損失を計算することを単純に伴います。

このセクションは訓練ループを示すために DGL 組込みデータセット dgl.data.CiteseerGraphDataset を使用します。ノード特徴とラベルはそのグラフ・インスタンスにストアされて、訓練-検証-テスト分割もまた boolean マスクとしてグラフ上でストアされます。これは 4 章: グラフ・データパイプライン で見たものと同様です。

node_features = graph.ndata['feat']
node_labels = graph.ndata['label']
train_mask = graph.ndata['train_mask']
valid_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
n_features = node_features.shape[1]
n_labels = int(node_labels.max().item() + 1)

次は貴方のモデルを精度で評価するサンプルです。

def evaluate(model, graph, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(graph, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

それから以下のように訓練ループを書くことができます。

model = SAGE(in_feats=n_features, hid_feats=100, out_feats=n_labels)
opt = torch.optim.Adam(model.parameters())

for epoch in range(10):
    model.train()
    # forward propagation by using all nodes
    logits = model(graph, node_features)
    # compute loss
    loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])
    # compute validation accuracy
    acc = evaluate(model, graph, node_features, node_labels, valid_mask)
    # backward propagation
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

    # Save model if necessary.  Omitted in this example.

GraphSAGE は end-to-end な均質グラフ・ノード分類サンプルを提供します。対応するモデル実装は、調整可能な数の層、dropout 確率とカスタマイズ可能な集約関数と非線形を伴うサンプルの GraphSAGE クラスにあることを見れるでしょう。

 

異質グラフ

貴方のグラフが異質である場合、総てのエッジ型に沿った近傍からメッセージを集めることを望むかもしれません。総てのエッジ型上でメッセージ・パッシングを遂行するためにモジュール dgl.nn.pytorch.HeteroGraphConv (MXNet と Tensorflow でも利用可能です) を利用できます、それから各エッジ型のために異なるグラフ畳込みモジュールを結合します。

次のコードは異質グラフ畳込みモジュールを定義します、それは最初に各エッジ型上で個別のグラフ畳込みを遂行してから、総てのノード型のための最終的な結果として各エッジ型上でメッセージ集約を総計します。

# Define a Heterograph Conv model
import dgl.nn as dglnn

class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h

dgl.nn.HeteroGraphConv は入力としてノード型とノード特徴 tensor の辞書を取り、そしてノード型とノード特徴のもう一つの辞書を返します。

そこで 異質グラフサンプル でユーザと項目 (= item) 特徴を持つと仮定します。

model = RGCN(n_hetero_features, 20, n_user_classes, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
labels = hetero_graph.nodes['user'].data['label']
train_mask = hetero_graph.nodes['user'].data['train_mask']

次のように単純に順伝播を遂行できます :

node_features = {'user': user_feats, 'item': item_feats}
h_dict = model(hetero_graph, {'user': user_feats, 'item': item_feats})
h_user = h_dict['user']
h_item = h_dict['item']

訓練ループは均質グラフのためのものと同じです、今は (そこから予測を計算する) ノード表現の辞書を持つことを除いて。例えば、ユーザノードだけを予測している場合、返された辞書からユーザノード埋め込みを単に抽出できます :

opt = torch.optim.Adam(model.parameters())

for epoch in range(5):
    model.train()
    # forward propagation by using all nodes and extracting the user embeddings
    logits = model(hetero_graph, node_features)['user']
    # compute loss
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])
    # Compute validation accuracy.  Omitted in this example.
    # backward propagation
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

    # Save model if necessary.  Omitted in the example.

DGL はノード分類のための RGCN の end-to-end なサンプルを提供します。モデル実装ファイル の RelGraphConvLayer で異質グラフ畳込みの定義を見ることができます。

 

以上






AI導入支援 #2 ウェビナー

スモールスタートを可能としたAI導入支援   Vol.2
[無料 WEB セミナー] [詳細]
「画像認識 AI PoC スターターパック」の紹介
既に AI 技術を実ビジネスで活用し、成果を上げている日本企業も多く存在しており、競争優位なビジネスを展開しております。
しかしながら AI を導入したくとも PoC (概念実証) だけでも高額な費用がかかり取組めていない企業も少なくないようです。A I導入時には欠かせない PoC を手軽にしかも短期間で認知度を確認可能とするサービの紹介と共に、AI 技術の特性と具体的な導入プロセスに加え運用時のポイントについても解説いたします。
日時:2021年10月13日(水)
会場:WEBセミナー
共催:クラスキャット、日本FLOW(株)
後援:働き方改革推進コンソーシアム
参加費: 無料 (事前登録制)
人工知能開発支援
◆ クラスキャットは 人工知能研究開発支援 サービスを提供しています :
  • テクニカルコンサルティングサービス
  • 実証実験 (プロトタイプ構築)
  • アプリケーションへの実装
  • 人工知能研修サービス
◆ お問合せ先 ◆
(株)クラスキャット
セールス・インフォメーション
E-Mail:sales-info@classcat.com