Keras 2 : examples : Graph Data – Node Classification with Graph Neural Networks (Translation/Commentary)
Translation: ClassCat Co., Ltd. Sales Information
Created: 08/03/2022 (keras 2.9.0)
* This page is a translation of the following Keras document, with supplementary notes added where appropriate:
- Code examples : Graph Data : Node Classification with Graph Neural Networks (Author: Khalid Salama)
* The sample code has been verified to run, but has been extended or modified where necessary.
Keras 2 : examples : Graph Data – Node Classification with Graph Neural Networks
Description: Implementing a graph neural network model to predict the topic of a paper given its citations.
Introduction
Many datasets in various machine learning (ML) applications have structural relationships between their entities, which can be represented as graphs. Such applications include social and communication network analysis, traffic prediction, and fraud detection. Graph representation learning aims to build and train models for graph datasets, to be used for a variety of ML tasks.
This example demonstrates a simple implementation of a Graph Neural Network (GNN) model. The model is used for a node prediction task on the Cora dataset to predict the subject of a paper given its words and its citation network.
Note that we implement a graph convolution layer from scratch to provide a better understanding of how it works. However, there are a number of specialized TensorFlow-based libraries that provide rich GNN APIs, such as Spektral, StellarGraph, and GraphNets.
Setup
import os
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
Prepare the Dataset
The Cora dataset consists of 2,708 scientific papers, each classified into one of 7 classes. The citation network consists of 5,429 links. Each paper has a binary word vector of size 1,433, indicating the presence of the corresponding word.
Download the Dataset
The dataset has two tab-separated files: cora.cites and cora.content.
- cora.cites includes the citation records, with two columns: cited_paper_id (target) and citing_paper_id (source).
- cora.content includes the paper content records, with 1,435 columns: paper_id, subject, and 1,433 binary features.
Let's download the dataset.
zip_file = keras.utils.get_file(
    fname="cora.tgz",
    origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
    extract=True,
)
data_dir = os.path.join(os.path.dirname(zip_file), "cora")
Process and Visualize the Dataset
Then we load the citation data into a Pandas DataFrame.
citations = pd.read_csv(
    os.path.join(data_dir, "cora.cites"),
    sep="\t",
    header=None,
    names=["target", "source"],
)
print("Citations shape:", citations.shape)
Citations shape: (5429, 2)
Now let's display a sample of the citations DataFrame. The target column includes the paper ids cited by the paper ids in the source column.
citations.sample(frac=1).head()
Now let's load the papers data into a Pandas DataFrame.
column_names = ["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"]
papers = pd.read_csv(
    os.path.join(data_dir, "cora.content"), sep="\t", header=None, names=column_names,
)
print("Papers shape:", papers.shape)
Papers shape: (2708, 1435)
Now let's display a sample of the papers DataFrame. The DataFrame includes the paper_id and subject columns, as well as 1,433 binary columns representing whether a term exists in the paper or not.
print(papers.sample(5).T)
                        1                133                   2425  \
paper_id          1061127              34355                1108389
term_0                  0                  0                      0
term_1                  0                  0                      0
term_2                  0                  0                      0
term_3                  0                  0                      0
...                   ...                ...                    ...
term_1429               0                  0                      0
term_1430               0                  0                      0
term_1431               0                  0                      0
term_1432               0                  0                      0
subject     Rule_Learning    Neural_Networks  Probabilistic_Methods

                         2103             1346
paper_id              1153942            80491
term_0                      0                0
term_1                      0                0
term_2                      1                0
term_3                      0                0
...                       ...              ...
term_1429                   0                0
term_1430                   0                0
term_1431                   0                0
term_1432                   0                0
subject    Genetic_Algorithms  Neural_Networks

[1435 rows x 5 columns]
Let's display the count of papers in each subject.
print(papers.subject.value_counts())
Neural_Networks           818
Probabilistic_Methods     426
Genetic_Algorithms        418
Theory                    351
Case_Based                298
Reinforcement_Learning    217
Rule_Learning             180
Name: subject, dtype: int64
We convert the paper ids and the subjects into zero-based indices.
class_values = sorted(papers["subject"].unique())
class_idx = {name: id for id, name in enumerate(class_values)}
paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}
papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])
Now let's visualize the citation graph. Each node in the graph represents a paper, and the color of a node corresponds to its subject. Note that we only show a sample of the papers in the dataset.
plt.figure(figsize=(10, 10))
colors = papers["subject"].tolist()
cora_graph = nx.from_pandas_edgelist(citations.sample(n=1500))
subjects = list(papers[papers["paper_id"].isin(list(cora_graph.nodes))]["subject"])
nx.draw_spring(cora_graph, node_size=15, node_color=subjects)
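Note that subjects is gathered with an isin filter, so its order is not guaranteed to match the order of cora_graph.nodes, and the node colors can be slightly off. A minimal, optional variant (an illustrative addition, not part of the original example) that looks each node's subject up by paper id:

# Map each networkx node (a zero-based paper id) to its subject, preserving
# the graph's own node order so that colors line up with nodes.
subject_by_paper = papers.set_index("paper_id")["subject"]
node_colors = [subject_by_paper[node] for node in cora_graph.nodes]
nx.draw_spring(cora_graph, node_size=15, node_color=node_colors)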
Split the Dataset into Stratified Train and Test Sets
train_data, test_data = [], []
for _, group_data in papers.groupby("subject"):
    # Select around 50% of the dataset for training.
    random_selection = np.random.rand(len(group_data.index)) <= 0.5
    train_data.append(group_data[random_selection])
    test_data.append(group_data[~random_selection])
train_data = pd.concat(train_data).sample(frac=1)
test_data = pd.concat(test_data).sample(frac=1)
print("Train data shape:", train_data.shape)
print("Test data shape:", test_data.shape)
Train data shape: (1360, 1435)
Test data shape: (1348, 1435)
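Because the split is performed per subject group, each class should keep roughly the same share in both splits. A quick, optional sanity check (an illustrative addition, not in the original):

# Compare per-class proportions across the two splits; they should be close.
print(train_data["subject"].value_counts(normalize=True).round(2))
print(test_data["subject"].value_counts(normalize=True).round(2))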
Implement Train and Evaluate Experiment
hidden_units = [32, 32]
learning_rate = 0.01
dropout_rate = 0.5
num_epochs = 300
batch_size = 256
This function compiles and trains the given model using the given training data.
def run_experiment(model, x_train, y_train):
    # Compile the model.
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
    )
    # Create an early stopping callback.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_acc", patience=50, restore_best_weights=True
    )
    # Fit the model.
    history = model.fit(
        x=x_train,
        y=y_train,
        epochs=num_epochs,
        batch_size=batch_size,
        validation_split=0.15,
        callbacks=[early_stopping],
    )
    return history
This function displays the loss and accuracy curves of the model during training.
def display_learning_curves(history):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    ax1.plot(history.history["loss"])
    ax1.plot(history.history["val_loss"])
    ax1.legend(["train", "test"], loc="upper right")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")

    ax2.plot(history.history["acc"])
    ax2.plot(history.history["val_acc"])
    ax2.legend(["train", "test"], loc="upper right")
    ax2.set_xlabel("Epochs")
    ax2.set_ylabel("Accuracy")
    plt.show()
Implement a Feedforward Network (FFN) Module
We will use this module in the baseline and the GNN models.
def create_ffn(hidden_units, dropout_rate, name=None):
    fnn_layers = []
    for units in hidden_units:
        fnn_layers.append(layers.BatchNormalization())
        fnn_layers.append(layers.Dropout(dropout_rate))
        fnn_layers.append(layers.Dense(units, activation=tf.nn.gelu))
    return keras.Sequential(fnn_layers, name=name)
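For intuition: whatever the input width, the module's output width is the last entry of hidden_units. A quick shape check (illustrative values, not part of the original example):

# A 4-sample batch of 1,433-dim word vectors comes out as 4 x 32.
ffn = create_ffn(hidden_units=[32, 32], dropout_rate=0.5)
print(ffn(tf.zeros((4, 1433))).shape)  # (4, 32)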
Build a Baseline Neural Network Model
Prepare the Data for the Baseline Model
# Use a sorted list (not a set), since pandas indexing requires an ordered list-like.
feature_names = sorted(set(papers.columns) - {"paper_id", "subject"})
num_features = len(feature_names)
num_classes = len(class_idx)
# Create train and test features as a numpy array.
x_train = train_data[feature_names].to_numpy()
x_test = test_data[feature_names].to_numpy()
# Create train and test targets as a numpy array.
y_train = train_data["subject"]
y_test = test_data["subject"]
Implement a Baseline Classifier
We add five FFN blocks with skip connections, so that we generate a baseline model with roughly the same number of parameters as the GNN model we will build later.
def create_baseline_model(hidden_units, num_classes, dropout_rate=0.2):
    inputs = layers.Input(shape=(num_features,), name="input_features")
    x = create_ffn(hidden_units, dropout_rate, name="ffn_block1")(inputs)
    for block_idx in range(4):
        # Create an FFN block.
        x1 = create_ffn(hidden_units, dropout_rate, name=f"ffn_block{block_idx + 2}")(x)
        # Add skip connection.
        x = layers.Add(name=f"skip_connection{block_idx + 2}")([x, x1])
    # Compute logits.
    logits = layers.Dense(num_classes, name="logits")(x)
    # Create the model.
    return keras.Model(inputs=inputs, outputs=logits, name="baseline")
baseline_model = create_baseline_model(hidden_units, num_classes, dropout_rate)
baseline_model.summary()
Model: "baseline" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_features (InputLayer) [(None, 1433)] 0 __________________________________________________________________________________________________ ffn_block1 (Sequential) (None, 32) 52804 input_features[0][0] __________________________________________________________________________________________________ ffn_block2 (Sequential) (None, 32) 2368 ffn_block1[0][0] __________________________________________________________________________________________________ skip_connection2 (Add) (None, 32) 0 ffn_block1[0][0] ffn_block2[0][0] __________________________________________________________________________________________________ ffn_block3 (Sequential) (None, 32) 2368 skip_connection2[0][0] __________________________________________________________________________________________________ skip_connection3 (Add) (None, 32) 0 skip_connection2[0][0] ffn_block3[0][0] __________________________________________________________________________________________________ ffn_block4 (Sequential) (None, 32) 2368 skip_connection3[0][0] __________________________________________________________________________________________________ skip_connection4 (Add) (None, 32) 0 skip_connection3[0][0] ffn_block4[0][0] __________________________________________________________________________________________________ ffn_block5 (Sequential) (None, 32) 2368 skip_connection4[0][0] __________________________________________________________________________________________________ skip_connection5 (Add) (None, 32) 0 skip_connection4[0][0] ffn_block5[0][0] __________________________________________________________________________________________________ logits (Dense) (None, 7) 231 skip_connection5[0][0] ================================================================================================== Total params: 62,507 Trainable params: 59,065 Non-trainable params: 3,442 _______________________________________________________________________________________
Train the Baseline Classifier
history = run_experiment(baseline_model, x_train, y_train)
Epoch 1/300
5/5 [==============================] - 3s 203ms/step - loss: 4.1695 - acc: 0.1660 - val_loss: 1.9008 - val_acc: 0.3186
Epoch 2/300
5/5 [==============================] - 0s 15ms/step - loss: 2.9269 - acc: 0.2630 - val_loss: 1.8906 - val_acc: 0.3235
Epoch 3/300
5/5 [==============================] - 0s 15ms/step - loss: 2.5669 - acc: 0.2424 - val_loss: 1.8713 - val_acc: 0.3186
...
Epoch 118/300
5/5 [==============================] - 0s 21ms/step - loss: 0.3639 - acc: 0.8783 - val_loss: 1.0206 - val_acc: 0.7010
Epoch 119/300
5/5 [==============================] - 0s 19ms/step - loss: 0.3660 - acc: 0.8696 - val_loss: 1.0260 - val_acc: 0.6912
Epoch 120/300
5/5 [==============================] - 0s 18ms/step - loss: 0.3624 - acc: 0.8708 - val_loss: 1.0619 - val_acc: 0.6814
Let's plot the learning curves.
display_learning_curves(history)
Now we evaluate the baseline model on the test data split.
_, test_accuracy = baseline_model.evaluate(x=x_test, y=y_test, verbose=0)
print(f"Test accuracy: {round(test_accuracy * 100, 2)}%")
Test accuracy: 73.52%
Examine the Baseline Model Predictions
Let's create new data instances by randomly generating binary word vectors with respect to the word-presence probabilities.
def generate_random_instances(num_instances):
    token_probability = x_train.mean(axis=0)
    instances = []
    for _ in range(num_instances):
        probabilities = np.random.uniform(size=len(token_probability))
        instance = (probabilities <= token_probability).astype(int)
        instances.append(instance)
    return np.array(instances)


def display_class_probabilities(probabilities):
    for instance_idx, probs in enumerate(probabilities):
        print(f"Instance {instance_idx + 1}:")
        for class_idx, prob in enumerate(probs):
            print(f"- {class_values[class_idx]}: {round(prob * 100, 2)}%")
Now we show the baseline model predictions given these randomly generated instances.
new_instances = generate_random_instances(num_classes)
logits = baseline_model.predict(new_instances)
probabilities = keras.activations.softmax(tf.convert_to_tensor(logits)).numpy()
display_class_probabilities(probabilities)
Instance 1:
- Case_Based: 13.02%
- Genetic_Algorithms: 6.89%
- Neural_Networks: 23.32%
- Probabilistic_Methods: 47.89%
- Reinforcement_Learning: 2.66%
- Rule_Learning: 1.18%
- Theory: 5.03%
Instance 2:
- Case_Based: 1.64%
- Genetic_Algorithms: 59.74%
- Neural_Networks: 27.13%
- Probabilistic_Methods: 9.02%
- Reinforcement_Learning: 1.05%
- Rule_Learning: 0.12%
- Theory: 1.31%
Instance 3:
- Case_Based: 1.35%
- Genetic_Algorithms: 77.41%
- Neural_Networks: 9.56%
- Probabilistic_Methods: 7.89%
- Reinforcement_Learning: 0.42%
- Rule_Learning: 0.46%
- Theory: 2.92%
Instance 4:
- Case_Based: 0.43%
- Genetic_Algorithms: 3.87%
- Neural_Networks: 92.88%
- Probabilistic_Methods: 0.97%
- Reinforcement_Learning: 0.56%
- Rule_Learning: 0.09%
- Theory: 1.2%
Instance 5:
- Case_Based: 0.11%
- Genetic_Algorithms: 0.17%
- Neural_Networks: 10.26%
- Probabilistic_Methods: 0.5%
- Reinforcement_Learning: 0.35%
- Rule_Learning: 0.63%
- Theory: 87.97%
Instance 6:
- Case_Based: 0.98%
- Genetic_Algorithms: 23.37%
- Neural_Networks: 70.76%
- Probabilistic_Methods: 1.12%
- Reinforcement_Learning: 2.23%
- Rule_Learning: 0.21%
- Theory: 1.33%
Instance 7:
- Case_Based: 0.64%
- Genetic_Algorithms: 2.42%
- Neural_Networks: 27.19%
- Probabilistic_Methods: 14.07%
- Reinforcement_Learning: 1.62%
- Rule_Learning: 9.35%
- Theory: 44.7%
Build a Graph Neural Network Model
Prepare the Data for the Graph Model
Preparing and loading the graph data into the model for training is the most challenging part in GNN models, and it is addressed in different ways by the specialized libraries. In this example, we show a simple approach to preparing and using graph data that is suitable when your dataset consists of a single graph that fits entirely in memory.
The graph data is represented by the graph_info tuple, which consists of the following three elements:
- node_features: This is a [num_nodes, num_features] NumPy array that includes the node features. In this dataset, the nodes are the papers, and node_features are the word-presence binary vectors of each paper.
- edges: This is a [2, num_edges] NumPy array representing a sparse adjacency matrix of the links between the nodes. In this example, the links are the citations between papers.
- edge_weights (optional): This is a [num_edges] NumPy array that includes the edge weights, which quantify the relationships between nodes in the graph. In this example, there are no weights for the paper citations.
# Create an edges array (sparse adjacency matrix) of shape [2, num_edges].
edges = citations[["source", "target"]].to_numpy().T
# Create an edge weights array of ones.
edge_weights = tf.ones(shape=edges.shape[1])
# Create a node features array of shape [num_nodes, num_features].
node_features = tf.cast(
    papers.sort_values("paper_id")[feature_names].to_numpy(), dtype=tf.dtypes.float32
)
# Create graph info tuple with node_features, edges, and edge_weights.
graph_info = (node_features, edges, edge_weights)
print("Edges shape:", edges.shape)
print("Nodes shape:", node_features.shape)
Edges shape: (2, 5429)
Nodes shape: (2708, 1433)
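The [2, num_edges] edges array is simply a compact encoding of the sparse adjacency matrix. For illustration only (SciPy is an extra assumption here; the model itself never materializes this matrix):

import scipy.sparse as sp

# Each (source, target) column becomes a 1 entry in a 2708 x 2708 matrix.
adjacency = sp.coo_matrix(
    (np.ones(edges.shape[1]), (edges[0], edges[1])),
    shape=(len(papers), len(papers)),
)
print(adjacency.shape)  # (2708, 2708)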
Implement a Graph Convolution Layer
We implement a graph convolution module as a Keras Layer. Our GraphConvLayer performs the following steps:
- Prepare: The input node representations are processed using an FFN to produce a message. You can simplify the processing by only applying a linear transformation to the representations.
- Aggregate: The messages of the neighbours of each node are aggregated, with respect to the edge_weights, using a permutation-invariant pooling operation such as sum, mean, or max, to prepare a single aggregated message for each node. See, for example, the tf.math.unsorted_segment_sum API used to aggregate neighbour messages (a small illustration follows this list).
- Update: The node_repesentations and aggregated_messages, both of shape [num_nodes, representation_dim], are combined and processed to produce the new state of the node representations (node embeddings). If combination_type is gru, the node_repesentations and aggregated_messages are stacked to create a sequence, then processed by a GRU layer. Otherwise, the node_repesentations and aggregated_messages are added or concatenated, then processed using an FFN.
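As a tiny illustration of the pooling step (toy values, not part of the model), tf.math.unsorted_segment_sum pools per-edge messages into per-node aggregates:

messages = tf.constant([[1.0], [2.0], [4.0]])  # one message per edge
targets = tf.constant([0, 0, 1])  # destination node of each edge
print(tf.math.unsorted_segment_sum(messages, targets, num_segments=2))
# [[3.], [4.]] -- node 0 receives 1 + 2, node 1 receives 4.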
The technique implemented uses ideas from Graph Convolutional Networks, GraphSage, Graph Isomorphism Networks, Simple Graph Networks, and Gated Graph Sequence Neural Networks. Two other key techniques that are not covered are Graph Attention Networks and Message Passing Neural Networks.
class GraphConvLayer(layers.Layer):
    def __init__(
        self,
        hidden_units,
        dropout_rate=0.2,
        aggregation_type="mean",
        combination_type="concat",
        normalize=False,
        *args,
        **kwargs,
    ):
        super(GraphConvLayer, self).__init__(*args, **kwargs)

        self.aggregation_type = aggregation_type
        self.combination_type = combination_type
        self.normalize = normalize

        self.ffn_prepare = create_ffn(hidden_units, dropout_rate)
        if self.combination_type == "gru":
            self.update_fn = layers.GRU(
                # GRU expects an integer width (hidden_units is a list).
                units=hidden_units[-1],
                activation="tanh",
                recurrent_activation="sigmoid",
                dropout=dropout_rate,
                return_state=True,
                recurrent_dropout=dropout_rate,
            )
        else:
            self.update_fn = create_ffn(hidden_units, dropout_rate)

    def prepare(self, node_repesentations, weights=None):
        # node_repesentations shape is [num_edges, embedding_dim].
        messages = self.ffn_prepare(node_repesentations)
        if weights is not None:
            messages = messages * tf.expand_dims(weights, -1)
        return messages

    def aggregate(self, node_indices, neighbour_messages):
        # node_indices shape is [num_edges].
        # neighbour_messages shape: [num_edges, representation_dim].
        num_nodes = tf.math.reduce_max(node_indices) + 1
        if self.aggregation_type == "sum":
            aggregated_message = tf.math.unsorted_segment_sum(
                neighbour_messages, node_indices, num_segments=num_nodes
            )
        elif self.aggregation_type == "mean":
            aggregated_message = tf.math.unsorted_segment_mean(
                neighbour_messages, node_indices, num_segments=num_nodes
            )
        elif self.aggregation_type == "max":
            aggregated_message = tf.math.unsorted_segment_max(
                neighbour_messages, node_indices, num_segments=num_nodes
            )
        else:
            raise ValueError(f"Invalid aggregation type: {self.aggregation_type}.")
        return aggregated_message

    def update(self, node_repesentations, aggregated_messages):
        # node_repesentations shape is [num_nodes, representation_dim].
        # aggregated_messages shape is [num_nodes, representation_dim].
        if self.combination_type == "gru":
            # Create a sequence of two elements for the GRU layer.
            h = tf.stack([node_repesentations, aggregated_messages], axis=1)
        elif self.combination_type == "concat":
            # Concatenate the node_repesentations and aggregated_messages.
            h = tf.concat([node_repesentations, aggregated_messages], axis=1)
        elif self.combination_type == "add":
            # Add node_repesentations and aggregated_messages.
            h = node_repesentations + aggregated_messages
        else:
            raise ValueError(f"Invalid combination type: {self.combination_type}.")

        # Apply the processing function.
        node_embeddings = self.update_fn(h)
        if self.combination_type == "gru":
            node_embeddings = tf.unstack(node_embeddings, axis=1)[-1]

        if self.normalize:
            node_embeddings = tf.nn.l2_normalize(node_embeddings, axis=-1)
        return node_embeddings

    def call(self, inputs):
        """Process the inputs to produce the node_embeddings.

        inputs: a tuple of three elements: node_repesentations, edges, edge_weights.
        Returns: node_embeddings of shape [num_nodes, representation_dim].
        """
        node_repesentations, edges, edge_weights = inputs
        # Get node_indices (source) and neighbour_indices (target) from edges.
        node_indices, neighbour_indices = edges[0], edges[1]
        # neighbour_repesentations shape is [num_edges, representation_dim].
        neighbour_repesentations = tf.gather(node_repesentations, neighbour_indices)

        # Prepare the messages of the neighbours.
        neighbour_messages = self.prepare(neighbour_repesentations, edge_weights)
        # Aggregate the neighbour messages.
        aggregated_messages = self.aggregate(node_indices, neighbour_messages)
        # Update the node embedding with the neighbour messages.
        return self.update(node_repesentations, aggregated_messages)
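A quick smoke test of the layer on a toy graph (all values below are illustrative assumptions, not from the original example):

# 4 nodes with 8-dim representations and 4 directed edges [sources, targets].
toy_nodes = tf.random.normal((4, 8))
toy_edges = np.array([[0, 1, 2, 3], [1, 2, 3, 0]])
toy_weights = tf.ones((4,))
toy_layer = GraphConvLayer(hidden_units=[16])
print(toy_layer((toy_nodes, toy_edges, toy_weights)).shape)  # (4, 16)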
Implement a Graph Neural Network Node Classifier
The GNN classification model follows the Design Space for Graph Neural Networks approach, as follows:
- Apply preprocessing to the node features using an FFN to generate initial node representations.
- Apply one or more graph convolutional layers, with skip connections, to the node representations to produce node embeddings.
- Apply post-processing to the node embeddings using an FFN to generate the final node embeddings.
- Feed the node embeddings into a Softmax layer to predict the node class.
Each added graph convolutional layer captures information from a further level of neighbours. However, adding many graph convolutional layers can cause oversmoothing, where the model produces similar embeddings for all nodes.
Note that graph_info is passed to the constructor of the Keras model and used as a property of the Keras model object, rather than as input data for training or prediction. The model will accept a batch of node_indices, which are used to look up the node features and neighbours from graph_info.
class GNNNodeClassifier(tf.keras.Model):
    def __init__(
        self,
        graph_info,
        num_classes,
        hidden_units,
        aggregation_type="sum",
        combination_type="concat",
        dropout_rate=0.2,
        normalize=True,
        *args,
        **kwargs,
    ):
        super(GNNNodeClassifier, self).__init__(*args, **kwargs)

        # Unpack graph_info to three elements: node_features, edges, and edge_weight.
        node_features, edges, edge_weights = graph_info
        self.node_features = node_features
        self.edges = edges
        self.edge_weights = edge_weights
        # Set edge_weights to ones if not provided.
        if self.edge_weights is None:
            self.edge_weights = tf.ones(shape=edges.shape[1])
        # Scale edge_weights to sum to 1.
        self.edge_weights = self.edge_weights / tf.math.reduce_sum(self.edge_weights)

        # Create a process layer.
        self.preprocess = create_ffn(hidden_units, dropout_rate, name="preprocess")
        # Create the first GraphConv layer.
        self.conv1 = GraphConvLayer(
            hidden_units,
            dropout_rate,
            aggregation_type,
            combination_type,
            normalize,
            name="graph_conv1",
        )
        # Create the second GraphConv layer.
        self.conv2 = GraphConvLayer(
            hidden_units,
            dropout_rate,
            aggregation_type,
            combination_type,
            normalize,
            name="graph_conv2",
        )
        # Create a postprocess layer.
        self.postprocess = create_ffn(hidden_units, dropout_rate, name="postprocess")
        # Create a compute logits layer.
        self.compute_logits = layers.Dense(units=num_classes, name="logits")

    def call(self, input_node_indices):
        # Preprocess the node_features to produce node representations.
        x = self.preprocess(self.node_features)
        # Apply the first graph conv layer.
        x1 = self.conv1((x, self.edges, self.edge_weights))
        # Skip connection.
        x = x1 + x
        # Apply the second graph conv layer.
        x2 = self.conv2((x, self.edges, self.edge_weights))
        # Skip connection.
        x = x2 + x
        # Postprocess node embedding.
        x = self.postprocess(x)
        # Fetch node embeddings for the input node_indices.
        node_embeddings = tf.gather(x, input_node_indices)
        # Compute logits.
        return self.compute_logits(node_embeddings)
Let's test instantiating and calling the GNN model. Notice that if you provide N node indices, the output will be a tensor of shape [N, num_classes], regardless of the size of the graph.
gnn_model = GNNNodeClassifier(
    graph_info=graph_info,
    num_classes=num_classes,
    hidden_units=hidden_units,
    dropout_rate=dropout_rate,
    name="gnn_model",
)
print("GNN output shape:", gnn_model([1, 10, 100]))
gnn_model.summary()
GNN output shape: tf.Tensor(
[[ 0.00620723  0.06162593  0.0176599   0.00830251 -0.03019211 -0.00402163  0.00277454]
 [ 0.01705155 -0.0467547   0.01400987 -0.02146192 -0.11757397  0.10820404 -0.0375765 ]
 [-0.02516522 -0.05514468 -0.03842098 -0.0495692  -0.05128997 -0.02241635 -0.07738923]], shape=(3, 7), dtype=float32)
Model: "gnn_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
preprocess (Sequential)      (2708, 32)                52804
_________________________________________________________________
graph_conv1 (GraphConvLayer) multiple                  5888
_________________________________________________________________
graph_conv2 (GraphConvLayer) multiple                  5888
_________________________________________________________________
postprocess (Sequential)     (2708, 32)                2368
_________________________________________________________________
logits (Dense)               multiple                  231
=================================================================
Total params: 67,179
Trainable params: 63,481
Non-trainable params: 3,698
_________________________________________________________________
Train the GNN Model
Note that we use the standard supervised cross-entropy loss to train the model. However, we could add another self-supervised loss term for the generated node embeddings, which makes sure that neighbouring nodes in the graph have similar representations while faraway nodes have dissimilar representations.
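One possible sketch of such a term (a hedged illustration, not implemented in this example) is a negative cosine similarity over edge endpoints, which rewards linked papers for having similar embeddings:

def neighbour_similarity_loss(node_embeddings, edges):
    # Gather and normalize the embeddings of each edge's two endpoints.
    source = tf.nn.l2_normalize(tf.gather(node_embeddings, edges[0]), axis=-1)
    target = tf.nn.l2_normalize(tf.gather(node_embeddings, edges[1]), axis=-1)
    # Higher cosine similarity between linked papers -> lower loss.
    return -tf.reduce_mean(tf.reduce_sum(source * target, axis=-1))

A contrastive variant would additionally sample random non-linked pairs and push their similarity down. For the supervised training here, the inputs are simply the paper ids: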
x_train = train_data.paper_id.to_numpy()
history = run_experiment(gnn_model, x_train, y_train)
Epoch 1/300
5/5 [==============================] - 4s 188ms/step - loss: 2.2529 - acc: 0.1793 - val_loss: 1.8933 - val_acc: 0.2941
Epoch 2/300
5/5 [==============================] - 0s 83ms/step - loss: 1.9866 - acc: 0.2601 - val_loss: 1.8753 - val_acc: 0.3186
Epoch 3/300
5/5 [==============================] - 0s 77ms/step - loss: 1.8794 - acc: 0.2846 - val_loss: 1.8655 - val_acc: 0.3186
...
Epoch 187/300
5/5 [==============================] - 0s 67ms/step - loss: 0.3102 - acc: 0.8908 - val_loss: 0.6456 - val_acc: 0.8529
Epoch 188/300
5/5 [==============================] - 0s 68ms/step - loss: 0.2763 - acc: 0.9140 - val_loss: 0.6151 - val_acc: 0.8431
Epoch 189/300
5/5 [==============================] - 0s 70ms/step - loss: 0.3298 - acc: 0.8964 - val_loss: 0.6119 - ...
val_acc: 0.8676 Epoch 190/300 5/5 [==============================] - 0s 69ms/step - loss: 0.2928 - acc: 0.9094 - val_loss: 0.6141 - val_acc: 0.8480 Epoch 191/300 5/5 [==============================] - 0s 68ms/step - loss: 0.3066 - acc: 0.9093 - val_loss: 0.6393 - val_acc: 0.8480 Epoch 192/300 5/5 [==============================] - 0s 94ms/step - loss: 0.2988 - acc: 0.9060 - val_loss: 0.6380 - val_acc: 0.8431 Epoch 193/300 5/5 [==============================] - 0s 70ms/step - loss: 0.3654 - acc: 0.8800 - val_loss: 0.6102 - val_acc: 0.8578 Epoch 194/300 5/5 [==============================] - 0s 68ms/step - loss: 0.3482 - acc: 0.8981 - val_loss: 0.6396 - val_acc: 0.8480 Epoch 195/300 5/5 [==============================] - 0s 69ms/step - loss: 0.3029 - acc: 0.9083 - val_loss: 0.6410 - val_acc: 0.8431 Epoch 196/300 5/5 [==============================] - 0s 68ms/step - loss: 0.3276 - acc: 0.8931 - val_loss: 0.6209 - val_acc: 0.8529 Epoch 197/300 5/5 [==============================] - 0s 69ms/step - loss: 0.3252 - acc: 0.8989 - val_loss: 0.6153 - val_acc: 0.8578 Epoch 198/300 5/5 [==============================] - 0s 68ms/step - loss: 0.3542 - acc: 0.8917 - val_loss: 0.6079 - val_acc: 0.8627 Epoch 199/300 5/5 [==============================] - 0s 68ms/step - loss: 0.3191 - acc: 0.9006 - val_loss: 0.6087 - val_acc: 0.8578 Epoch 200/300 5/5 [==============================] - 0s 68ms/step - loss: 0.3077 - acc: 0.9008 - val_loss: 0.6209 - val_acc: 0.8529 Epoch 201/300 5/5 [==============================] - 0s 69ms/step - loss: 0.3045 - acc: 0.9076 - val_loss: 0.6609 - val_acc: 0.8333 Epoch 202/300 5/5 [==============================] - 0s 71ms/step - loss: 0.3053 - acc: 0.9058 - val_loss: 0.7324 - val_acc: 0.8284 Epoch 203/300 5/5 [==============================] - 0s 68ms/step - loss: 0.3107 - acc: 0.8985 - val_loss: 0.7755 - val_acc: 0.8235 Epoch 204/300 5/5 [==============================] - 0s 68ms/step - loss: 0.3047 - acc: 0.8995 - val_loss: 0.7936 - val_acc: 0.7941 Epoch 205/300 5/5 [==============================] - 0s 67ms/step - loss: 0.3131 - acc: 0.9098 - val_loss: 0.6453 - val_acc: 0.8529 Epoch 206/300 5/5 [==============================] - 0s 71ms/step - loss: 0.3795 - acc: 0.8849 - val_loss: 0.6213 - val_acc: 0.8529 Epoch 207/300 5/5 [==============================] - 0s 70ms/step - loss: 0.2903 - acc: 0.9114 - val_loss: 0.6354 - val_acc: 0.8578 Epoch 208/300 5/5 [==============================] - 0s 68ms/step - loss: 0.2599 - acc: 0.9164 - val_loss: 0.6390 - val_acc: 0.8676 Epoch 209/300 5/5 [==============================] - 0s 71ms/step - loss: 0.2954 - acc: 0.9041 - val_loss: 0.6376 - val_acc: 0.8775 Epoch 210/300 5/5 [==============================] - 0s 69ms/step - loss: 0.3250 - acc: 0.9023 - val_loss: 0.6206 - val_acc: 0.8725 Epoch 211/300 5/5 [==============================] - 0s 69ms/step - loss: 0.2694 - acc: 0.9149 - val_loss: 0.6177 - val_acc: 0.8676 Epoch 212/300 5/5 [==============================] - 0s 71ms/step - loss: 0.2920 - acc: 0.9054 - val_loss: 0.6438 - val_acc: 0.8627 Epoch 213/300 5/5 [==============================] - 0s 68ms/step - loss: 0.2861 - acc: 0.9048 - val_loss: 0.7128 - val_acc: 0.8480 Epoch 214/300 5/5 [==============================] - 0s 65ms/step - loss: 0.2916 - acc: 0.9083 - val_loss: 0.7030 - val_acc: 0.8431 Epoch 215/300 5/5 [==============================] - 0s 91ms/step - loss: 0.3288 - acc: 0.8887 - val_loss: 0.6593 - val_acc: 0.8529 Epoch 216/300 5/5 [==============================] - 0s 68ms/step - loss: 0.3802 - acc: 0.8875 - val_loss: 
0.6165 - val_acc: 0.8578 Epoch 217/300 5/5 [==============================] - 0s 67ms/step - loss: 0.2905 - acc: 0.9175 - val_loss: 0.6141 - val_acc: 0.8725 Epoch 218/300 5/5 [==============================] - 0s 69ms/step - loss: 0.3078 - acc: 0.9104 - val_loss: 0.6158 - val_acc: 0.8676 Epoch 219/300 5/5 [==============================] - 0s 66ms/step - loss: 0.2757 - acc: 0.9214 - val_loss: 0.6195 - val_acc: 0.8578 Epoch 220/300 5/5 [==============================] - 0s 67ms/step - loss: 0.3159 - acc: 0.8958 - val_loss: 0.6375 - val_acc: 0.8578 Epoch 221/300 5/5 [==============================] - 0s 69ms/step - loss: 0.3348 - acc: 0.8944 - val_loss: 0.6839 - val_acc: 0.8431 Epoch 222/300 5/5 [==============================] - 0s 70ms/step - loss: 0.3239 - acc: 0.8936 - val_loss: 0.6450 - val_acc: 0.8578 Epoch 223/300 5/5 [==============================] - 0s 73ms/step - loss: 0.2783 - acc: 0.9081 - val_loss: 0.6163 - val_acc: 0.8627 Epoch 224/300 5/5 [==============================] - 0s 68ms/step - loss: 0.2852 - acc: 0.9165 - val_loss: 0.6495 - val_acc: 0.8431 Epoch 225/300 5/5 [==============================] - 0s 69ms/step - loss: 0.3073 - acc: 0.8902 - val_loss: 0.6622 - val_acc: 0.8529 Epoch 226/300 5/5 [==============================] - 0s 67ms/step - loss: 0.3127 - acc: 0.9102 - val_loss: 0.6652 - val_acc: 0.8431 Epoch 227/300 5/5 [==============================] - 0s 68ms/step - loss: 0.3248 - acc: 0.9067 - val_loss: 0.6475 - val_acc: 0.8529 Epoch 228/300 5/5 [==============================] - 0s 69ms/step - loss: 0.3155 - acc: 0.9089 - val_loss: 0.6263 - val_acc: 0.8382 Epoch 229/300 5/5 [==============================] - 0s 68ms/step - loss: 0.3585 - acc: 0.8898 - val_loss: 0.6308 - val_acc: 0.8578 Epoch 230/300 5/5 [==============================] - 0s 68ms/step - loss: 0.2812 - acc: 0.9180 - val_loss: 0.6201 - val_acc: 0.8529 Epoch 231/300 5/5 [==============================] - 0s 67ms/step - loss: 0.3070 - acc: 0.8984 - val_loss: 0.6170 - val_acc: 0.8431 Epoch 232/300 5/5 [==============================] - 0s 67ms/step - loss: 0.3433 - acc: 0.8909 - val_loss: 0.6568 - val_acc: 0.8431 Epoch 233/300 5/5 [==============================] - 0s 67ms/step - loss: 0.2844 - acc: 0.9085 - val_loss: 0.6571 - val_acc: 0.8529 Epoch 234/300 5/5 [==============================] - 0s 67ms/step - loss: 0.3122 - acc: 0.9044 - val_loss: 0.6516 - val_acc: 0.8480 Epoch 235/300 5/5 [==============================] - 0s 67ms/step - loss: 0.3047 - acc: 0.9232 - val_loss: 0.6505 - val_acc: 0.8480 Epoch 236/300 5/5 [==============================] - 0s 67ms/step - loss: 0.2913 - acc: 0.9192 - val_loss: 0.6432 - val_acc: 0.8529 Epoch 237/300 5/5 [==============================] - 0s 67ms/step - loss: 0.2505 - acc: 0.9322 - val_loss: 0.6462 - val_acc: 0.8627 Epoch 238/300 5/5 [==============================] - 0s 68ms/step - loss: 0.3033 - acc: 0.9085 - val_loss: 0.6378 - val_acc: 0.8627 Epoch 239/300 5/5 [==============================] - 0s 68ms/step - loss: 0.3418 - acc: 0.8975 - val_loss: 0.6232 - val_acc: 0.8578 Epoch 240/300 5/5 [==============================] - 0s 68ms/step - loss: 0.3167 - acc: 0.9051 - val_loss: 0.6284 - val_acc: 0.8627 Epoch 241/300 5/5 [==============================] - 0s 69ms/step - loss: 0.2637 - acc: 0.9145 - val_loss: 0.6427 - val_acc: 0.8627 Epoch 242/300 5/5 [==============================] - 0s 68ms/step - loss: 0.2678 - acc: 0.9227 - val_loss: 0.6492 - val_acc: 0.8578 Epoch 243/300 5/5 [==============================] - 0s 67ms/step - loss: 0.2730 - acc: 0.9113 - 
val_loss: 0.6736 - val_acc: 0.8578 Epoch 244/300 5/5 [==============================] - 0s 93ms/step - loss: 0.3013 - acc: 0.9077 - val_loss: 0.7138 - val_acc: 0.8333 Epoch 245/300 5/5 [==============================] - 0s 67ms/step - loss: 0.3151 - acc: 0.9096 - val_loss: 0.7278 - val_acc: 0.8382 Epoch 246/300 5/5 [==============================] - 0s 68ms/step - loss: 0.3307 - acc: 0.9058 - val_loss: 0.6944 - val_acc: 0.8627 Epoch 247/300 5/5 [==============================] - 0s 68ms/step - loss: 0.2631 - acc: 0.9236 - val_loss: 0.6789 - val_acc: 0.8529 Epoch 248/300 5/5 [==============================] - 0s 68ms/step - loss: 0.3215 - acc: 0.9027 - val_loss: 0.6790 - val_acc: 0.8529 Epoch 249/300 5/5 [==============================] - 0s 67ms/step - loss: 0.2968 - acc: 0.9038 - val_loss: 0.6864 - val_acc: 0.8480 Epoch 250/300 5/5 [==============================] - 0s 68ms/step - loss: 0.2998 - acc: 0.9078 - val_loss: 0.7079 - val_acc: 0.8480 Epoch 251/300 5/5 [==============================] - 0s 67ms/step - loss: 0.2375 - acc: 0.9197 - val_loss: 0.7252 - val_acc: 0.8529 Epoch 252/300 5/5 [==============================] - 0s 68ms/step - loss: 0.2955 - acc: 0.9178 - val_loss: 0.7298 - val_acc: 0.8284 Epoch 253/300 5/5 [==============================] - 0s 69ms/step - loss: 0.2946 - acc: 0.9039 - val_loss: 0.7172 - val_acc: 0.8284 Epoch 254/300 5/5 [==============================] - 0s 68ms/step - loss: 0.3051 - acc: 0.9087 - val_loss: 0.6861 - val_acc: 0.8382 Epoch 255/300 5/5 [==============================] - 0s 67ms/step - loss: 0.3563 - acc: 0.8882 - val_loss: 0.6739 - val_acc: 0.8480 Epoch 256/300 5/5 [==============================] - 0s 68ms/step - loss: 0.3144 - acc: 0.8969 - val_loss: 0.6970 - val_acc: 0.8382 Epoch 257/300 5/5 [==============================] - 0s 68ms/step - loss: 0.3210 - acc: 0.9152 - val_loss: 0.7106 - val_acc: 0.8333 Epoch 258/300 5/5 [==============================] - 0s 67ms/step - loss: 0.2523 - acc: 0.9214 - val_loss: 0.7111 - val_acc: 0.8431 Epoch 259/300 5/5 [==============================] - 0s 68ms/step - loss: 0.2552 - acc: 0.9236 - val_loss: 0.7258 - val_acc: 0.8382
Let's plot the learning curves.
display_learning_curves(history)
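For reference, display_learning_curves is the plotting helper defined earlier in this example. A minimal sketch of such a helper, assuming a standard Keras History object with "loss"/"acc" and "val_loss"/"val_acc" entries, could look like this:
import matplotlib.pyplot as plt

def display_learning_curves(history):
    # Plot training/validation loss and accuracy side by side.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    ax1.plot(history.history["loss"])
    ax1.plot(history.history["val_loss"])
    ax1.legend(["train", "validation"], loc="upper right")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax2.plot(history.history["acc"])
    ax2.plot(history.history["val_acc"])
    ax2.legend(["train", "validation"], loc="lower right")
    ax2.set_xlabel("Epochs")
    ax2.set_ylabel("Accuracy")
    plt.show()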
Now let's evaluate the GNN model on the test data split. The results may vary depending on the training samples, but the GNN model consistently outperforms the baseline model in terms of test accuracy.
x_test = test_data.paper_id.to_numpy()
_, test_accuracy = gnn_model.evaluate(x=x_test, y=y_test, verbose=0)
print(f"Test accuracy: {round(test_accuracy * 100, 2)}%")
Test accuracy: 80.19%
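For comparison, the baseline model can be evaluated on the same split. The following is a hypothetical snippet that assumes baseline_model, feature_names, test_data, and y_test from the earlier baseline section are still in scope:
# Evaluate the feature-only baseline model on the same test split (hypothetical;
# names are taken from the earlier sections of this example).
x_test_features = test_data[list(feature_names)].to_numpy()
_, baseline_accuracy = baseline_model.evaluate(x=x_test_features, y=y_test, verbose=0)
print(f"Baseline test accuracy: {round(baseline_accuracy * 100, 2)}%")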
Examine the GNN model predictions
We add the new instances as nodes to node_features, and generate links (citations) from them to existing nodes, as shown in the code below.
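The new_instances array used below is assumed to have been created earlier in this example (in the baseline-model section) by a generate_random_instances helper that samples binary term vectors from the empirical term frequencies of the training features. A minimal sketch of that idea, assuming x_train holds the training feature matrix:
def generate_random_instances(num_instances):
    # Empirical probability of each of the 1,433 terms appearing in a paper.
    token_probability = x_train.mean(axis=0)
    instances = []
    for _ in range(num_instances):
        probabilities = np.random.uniform(size=len(token_probability))
        # A term is marked present (1) with its empirical probability.
        instance = (probabilities <= token_probability).astype(int)
        instances.append(instance)
    return np.stack(instances)

new_instances = generate_random_instances(num_classes)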
# First we add the N new instances as nodes to the graph
# by appending the new_instances to node_features.
num_nodes = node_features.shape[0]
new_node_features = np.concatenate([node_features, new_instances])
# Second, we add the M edges (citations) from each new node to a set
# of existing nodes in a particular subject.
new_node_indices = [i + num_nodes for i in range(num_classes)]
new_citations = []
for subject_idx, group in papers.groupby("subject"):
    subject_papers = list(group.paper_id)
    # Select x = 5 random papers from the specific subject.
    selected_paper_indices1 = np.random.choice(subject_papers, 5)
    # Select y = 2 random papers from any subject (where y < x).
    selected_paper_indices2 = np.random.choice(list(papers.paper_id), 2)
    # Merge the selected paper indices.
    selected_paper_indices = np.concatenate(
        [selected_paper_indices1, selected_paper_indices2], axis=0
    )
    # Create edges between the citing paper idx and the selected cited papers.
    citing_paper_idx = new_node_indices[subject_idx]
    for cited_paper_idx in selected_paper_indices:
        new_citations.append([citing_paper_idx, cited_paper_idx])
new_citations = np.array(new_citations).T
new_edges = np.concatenate([edges, new_citations], axis=1)
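As a quick sanity check (hypothetical, not part of the original example): with 7 classes and 5 + 2 citations per new node, we expect num_classes new feature rows and num_classes * 7 new edges, which matches the shapes printed below.
# Hypothetical sanity check: each of the num_classes new nodes cites
# 5 + 2 = 7 existing papers.
assert new_node_features.shape[0] == num_nodes + num_classes
assert new_edges.shape[1] == edges.shape[1] + num_classes * 7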
Now let's update the node_features and edges of the GNN model.
print("Original node_features shape:", gnn_model.node_features.shape)
print("Original edges shape:", gnn_model.edges.shape)
gnn_model.node_features = new_node_features
gnn_model.edges = new_edges
gnn_model.edge_weights = tf.ones(shape=new_edges.shape[1])
print("New node_features shape:", gnn_model.node_features.shape)
print("New edges shape:", gnn_model.edges.shape)
logits = gnn_model.predict(tf.convert_to_tensor(new_node_indices))
probabilities = keras.activations.softmax(tf.convert_to_tensor(logits)).numpy()
display_class_probabilities(probabilities)
Original node_features shape: (2708, 1433)
Original edges shape: (2, 5429)
New node_features shape: (2715, 1433)
New edges shape: (2, 5478)
Instance 1:
- Case_Based: 4.35%
- Genetic_Algorithms: 4.19%
- Neural_Networks: 1.49%
- Probabilistic_Methods: 1.68%
- Reinforcement_Learning: 21.34%
- Rule_Learning: 52.82%
- Theory: 14.14%
Instance 2:
- Case_Based: 0.01%
- Genetic_Algorithms: 99.88%
- Neural_Networks: 0.03%
- Probabilistic_Methods: 0.0%
- Reinforcement_Learning: 0.07%
- Rule_Learning: 0.0%
- Theory: 0.01%
Instance 3:
- Case_Based: 0.1%
- Genetic_Algorithms: 59.18%
- Neural_Networks: 39.17%
- Probabilistic_Methods: 0.38%
- Reinforcement_Learning: 0.55%
- Rule_Learning: 0.08%
- Theory: 0.54%
Instance 4:
- Case_Based: 0.14%
- Genetic_Algorithms: 10.44%
- Neural_Networks: 84.1%
- Probabilistic_Methods: 3.61%
- Reinforcement_Learning: 0.71%
- Rule_Learning: 0.16%
- Theory: 0.85%
Instance 5:
- Case_Based: 0.27%
- Genetic_Algorithms: 0.15%
- Neural_Networks: 0.48%
- Probabilistic_Methods: 0.23%
- Reinforcement_Learning: 0.79%
- Rule_Learning: 0.45%
- Theory: 97.63%
Instance 6:
- Case_Based: 3.12%
- Genetic_Algorithms: 1.35%
- Neural_Networks: 19.72%
- Probabilistic_Methods: 0.48%
- Reinforcement_Learning: 39.56%
- Rule_Learning: 28.0%
- Theory: 7.77%
Instance 7:
- Case_Based: 1.6%
- Genetic_Algorithms: 34.76%
- Neural_Networks: 4.45%
- Probabilistic_Methods: 9.59%
- Reinforcement_Learning: 2.97%
- Rule_Learning: 4.05%
- Theory: 42.6%
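The formatting above comes from the display_class_probabilities helper defined earlier in this example. A minimal sketch, assuming class_values maps zero-based class indices back to subject names:
def display_class_probabilities(probabilities):
    # class_values is assumed to hold the subject names in index order.
    for instance_idx, probs in enumerate(probabilities):
        print(f"Instance {instance_idx + 1}:")
        for class_idx, prob in enumerate(probs):
            print(f"- {class_values[class_idx]}: {round(prob * 100, 2)}%")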
Notice that the probabilities of the expected subjects (the ones to which citations were added) are higher than in the baseline model.
That's all.