ホーム » DGL » DGL 0.5 ユーザガイド : 5 章 訓練 : 5.4 グラフ分類

DGL 0.5 ユーザガイド : 5 章 訓練 : 5.4 グラフ分類

DGL 0.5ユーザガイド : 5 章 訓練 : 5.4 グラフ分類 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 09/22/2020 (0.5.2)

* 本ページは、DGL の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

ユーザガイド : 5 章 訓練 : 5.4 グラフ分類

大きい単一グラフの代わりに、時に複数のグラフの形式のデータを持つかもしれません、例えば人々のコミュニティの異なる型のリストです。同じコミュニティの人々の間の友好関係をグラフで特徴付けることにより、分類するグラフのリストを得ます。このシナリオでは、グラフ分類モデルはコミュニティの型を識別する手助けができるでしょう、i.e. 構造と全体的な情報に基づいて各グラフを分類します。

 

概要

グラフ分類とノード分類やリンク予測間の主要な相違は予測結果が入力グラフ全体の特性を特徴付けることです。ちょうど前のタスクのようにノード/エッジに渡るメッセージパッシングを遂行しますが、グラフレベル表現を取得しようともします。

グラフ分類は次のように進みます :

グラフ分類プロセス

左から右へ、一般的な実践は :

  • グラフをグラフのバッチに準備する。
  • ノード/エッジ特徴を更新するためにバッチ化されたグラフ上でメッセージ・パッシングする
  • ノード/エッジ特徴をグラフレベル表現に集約する
  • タスクの方に進む分類

 

グラフのバッチ

通常はグラフ分類タスクは多くのグラフを訓練し、そしてモデルを訓練するとき一度に一つのグラフだけを使用する場合それは非常に非効率です。一般的な深層学習実践からミニバッチ訓練のアイデアを拝借して、複数のグラフのバッチをビルドしてそれらをまとめて一つの訓練反復のために送ることができます。

DGL では、グラフのリストの単一のバッチ化されたグラフをビルドできます。このバッチ化グラフは単純に単一の巨大なグラフとして利用できます、この際に個々のコンポーネントは対応する元の小さいグラフを表しています。


バッチ化されたグラフ

 

グラフ読み出し

データの総てのグラフはそのノードとエッジ特徴に加えて、独自の構造を持つかもしれません。単一の予測を行なうために、通常は可能な限り豊富な情報に渡り集約して要約します。このタイプの演算は Readout (読み出し) と命名されます。一般的な集約は総てのノードやエッジ特徴に渡る summation, average, maximum や minimum を含みます。

グラフ $g$ が与えられたとき、平均的な readout 集約を次のように定義できます :

\[
h_g = \frac{1}{|\mathcal{V}|}\sum_{v\in \mathcal{V}}h_v
\]

DGL では対応する関数呼び出しは dgl.readout_nodes() です。

ひとたび \(h_g\) が利用可能であれば、分類出力のためにそれを MLP 層に渡すことができます。

 

ニューラルネットワーク・モデルを書く

モデルへの入力はノードとエッジ特徴を持つバッチ化グラフです。注意すべき一つのことはバッチ化グラフのノードとエッジ特徴はバッチ次元を持たないことです。少しの特別なケアがモデルに置かれるべきです :

 

バッチ化グラフ上の計算

次に、バッチ化グラフの計算プロパティを議論します。

最初に、バッチの異なるグラフは完全に分離しています、i.e. 2 つのグラフに接続するエッジはありません。この良いプロパティにより、総てのメッセージ・パッシング関数は依然として同じ結果を持ちます。

2 番目に、バッチ化グラフ上の readout 関数は各グラフに渡り個別に処理されます。バッチサイズが $B$ で集約される特徴が次元 $D$ を持つと仮定すると、readout 結果の shape は \((B, D)\) となります。

g1 = dgl.graph(([0, 1], [1, 0]))
g1.ndata['h'] = torch.tensor([1., 2.])
g2 = dgl.graph(([0, 1], [1, 2]))
g2.ndata['h'] = torch.tensor([1., 2., 3.])

dgl.readout_nodes(g1, 'h')
# tensor([3.])  # 1 + 2

bg = dgl.batch([g1, g2])
dgl.readout_nodes(bg, 'h')
# tensor([3., 6.])  # [1 + 2, 1 + 2 + 3]

最後に、バッチ化グラフ上の各ノード/エッジ特徴 tensor は総てのグラフからの対応する特徴 tensor を結合した形式にあります。

bg.ndata['h']
# tensor([1., 2., 1., 2., 3.])

 

モデル定義

上の計算ルールに気付けば、非常に単純なモデルを定義できます。

class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = dglnn.GraphConv(in_dim, hidden_dim)
        self.conv2 = dglnn.GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g, feat):
        # Apply graph convolution and activation.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        with g.local_scope():
            g.ndata['h'] = h
            # Calculate graph representation by average readout.
            hg = dgl.mean_nodes(g, 'h')
            return self.classify(hg)

 

訓練ループ

データ・ローディング

ひとたびモデルが定義されれば、訓練を開始できます。グラフ分類は大きな単一のグラフの代わりに多くの関連する小さいグラフを扱いますので、洗練されたグラフサンプリング・アルゴリズムを設計する必要なく、通常はグラフの確率的ミニバッチ上で効率的に訓練できます。

4 章: グラフ・データパイプライン で紹介されたようにグラフ分類データセットを持つことを仮定します。

import dgl.data
dataset = dgl.data.GINDataset('MUTAG', False)

グラフ分類データセットの各項目はグラフとそのラベルのペアです。グラフをバッチ処理するために collate 関数をカスタマイズし、DataLoader を活用することによりデータローディング過程を高速化できます :

def collate(samples):
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    batched_labels = torch.tensor(labels)
    return batched_graph, batched_labels

それから DataLoader を作成することがでけいます、これはミニバッチでグラフのデータセットに渡り反復します。

from torch.utils.data import DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=1024,
    collate_fn=collate,
    drop_last=False,
    shuffle=True)

 

ループ

それから訓練ループは dataloader に渡り反復してモデルを更新することを単純に伴います。

model = Classifier(10, 20, 5)
opt = torch.optim.Adam(model.parameters())
for epoch in range(20):
    for batched_graph, labels in dataloader:
        feats = batched_graph.ndata['feats']
        logits = model(batched_graph, feats)
        loss = F.cross_entropy(logits, labels)
        opt.zero_grad()
        loss.backward()
        opt.step()

DGL はグラフ分類のサンプルとして GIN を実装します。訓練ループは main.py の関数 train の内側にあります。モデル実装は、グラフ畳込み層として dgl.nn.pytorch.GINConv (MXNet と Tensorflow でも利用可能です) を使用して、バッチ正規化等のようなより多くのコンポーネントを持つ gin.py の内側にあります。

 

異質グラフ

異質グラフを持つグラフ分類は均質グラフを持つそれとは少し異なります。異質グラフ畳込みモジュールを必要とすることを除いて、readout 関数で異なる型のノードに渡り集約する必要もあります。

次は各ノード型のためのノード表現の平均を合計するサンプルを示します。

class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h

class HeteroClassifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes, rel_names):
        super().__init__()

        self.rgcn = RGCN(in_dim, hidden_dim, hidden_dim, rel_names)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        h = g.ndata['feat']
        h = self.rgcn(g, h)
        with g.local_scope():
            g.ndata['h'] = h
            # Calculate graph representation by average readout.
            hg = 0
            for ntype in g.ntypes:
                hg = hg + dgl.mean_nodes(g, 'h', ntype=ntype)
            return self.classify(hg)

コードの残りは均質グラフのためのそれと異なりません。

# etypes is the list of edge types as strings.
model = HeteroClassifier(10, 20, 5, etypes)
opt = torch.optim.Adam(model.parameters())
for epoch in range(20):
    for batched_graph, labels in dataloader:
        logits = model(batched_graph)
        loss = F.cross_entropy(logits, labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
 

以上






AI導入支援 #2 ウェビナー

スモールスタートを可能としたAI導入支援   Vol.2
[無料 WEB セミナー] [詳細]
「画像認識 AI PoC スターターパック」の紹介
既に AI 技術を実ビジネスで活用し、成果を上げている日本企業も多く存在しており、競争優位なビジネスを展開しております。
しかしながら AI を導入したくとも PoC (概念実証) だけでも高額な費用がかかり取組めていない企業も少なくないようです。A I導入時には欠かせない PoC を手軽にしかも短期間で認知度を確認可能とするサービの紹介と共に、AI 技術の特性と具体的な導入プロセスに加え運用時のポイントについても解説いたします。
日時:2021年10月13日(水)
会場:WEBセミナー
共催:クラスキャット、日本FLOW(株)
後援:働き方改革推進コンソーシアム
参加費: 無料 (事前登録制)
人工知能開発支援
◆ クラスキャットは 人工知能研究開発支援 サービスを提供しています :
  • テクニカルコンサルティングサービス
  • 実証実験 (プロトタイプ構築)
  • アプリケーションへの実装
  • 人工知能研修サービス
◆ お問合せ先 ◆
(株)クラスキャット
セールス・インフォメーション
E-Mail:sales-info@classcat.com