参考自 https://github.com/Ivan0131/gnn-demo/tree/master/gnn1.1.0-tensorflow1.15,本文在源代码基础上增加了测试集的划分和测试集训练精度结果的输出。代码块如下:
# 代码块开始
import tensorflow as tf
from gnn.utils import *
from gnn.models import GNN
from gnn.train import train
from gnn.test import test
# 加载数据集
adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = load_data()
# 设置模型超参数
params = {
"learning_rate": 0.01,
"hidden_dim": 16,
"dropout": 0.5,
"weight_decay": 5e-4,
"num_epochs": 200,
"early_stopping": 10,
"num_classes": y_train.shape[1]
}
# 构建模型
model = GNN(params)
# 训练模型
train(model, adj, features, y_train, train_mask, val_mask, params)
# 测试模型
test(model, adj, features, y_test, test_mask)
# 代码块结束
暂无评论