模型训练部署 Pipeline 介绍:以一个简单图神经网络为例
流程
-
获取图数据:通过采购(“钞能力”)或使用公开数据集(“白嫖”)获取有标注的数据。
-
图数据预处理:
-
将原始数据转换成图结构,包括节点特征、边特征和图结构信息。
-
对节点特征和边特征进行归一化、标准化等预处理操作。
-
-
图神经网络模型构建:定义层和模型。
-
模型训练:设置超参数、优化器,得到训练后的权重。
-
模型评估: 在测试集上评估训练好的模型性能,如分类准确率、F1 分数等。
-
模型优化: 保存模型,对模型进行必要的优化,如量化、剪枝等, 减小模型大小并提高推理速度。
-
模型部署和维护:将模型部署到生产环境中。监控模型在生产环境中的运行状态和性能指标。
训练
我们用的是之前自己实现的 GraphConv 层。
graph_conv.py
1import torch
2from torch import Tensor
3from torch_geometric.nn import MessagePassing
4from torch_geometric.utils import add_self_loops, degree
5torch.manual_seed(42)
6
class GraphConv(MessagePassing):
    """Symmetrically normalized graph convolution layer (GCN style).

    ``forward`` adds a self-loop to every node, weights each edge by
    ``deg(src)^-1/2 * deg(dst)^-1/2`` (degree counted on the target index,
    self-loops included), sums the weighted neighbour features
    (``aggr='add'``) and finally applies ``self.lin`` plus ``self.bias``.

    NOTE(review): ``self.lin`` already has its own bias, so ``self.bias`` is
    redundant (the two simply add up). It is kept so that saved state_dicts
    and the exported ONNX graph keep the same parameter names.
    """

    def __init__(self, in_channels: int, out_channels: int):
        # Messages arriving at a node are summed.
        super().__init__(aggr='add')
        self.lin = torch.nn.Linear(in_channels, out_channels)
        # Allocated uninitialized; reset_parameters() sets it to zero.
        self.bias = torch.nn.Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialize the linear weight and zero the extra bias."""
        torch.nn.init.xavier_uniform_(self.lin.weight)
        torch.nn.init.zeros_(self.bias)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        """Run one propagation step over ``edge_index`` (``x`` has one row per node)."""
        num_nodes = x.size(0)
        # A + I: every node also receives its own features.
        looped_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # Symmetric normalization; thanks to the self-loops every degree is >= 1,
        # so the inverse square root never divides by zero.
        src, dst = looped_index
        inv_sqrt_deg = degree(dst, num_nodes, dtype=x.dtype).pow(-0.5)
        edge_weight = inv_sqrt_deg[src] * inv_sqrt_deg[dst]

        # The keyword name ``norm`` must match the ``message`` parameter name,
        # because MessagePassing routes propagate() kwargs by name.
        return self.propagate(looped_index, size=(num_nodes, num_nodes), x=x, norm=edge_weight)

    def message(self, x_j: Tensor, norm: Tensor) -> Tensor:
        """Scale each source-node feature row by its edge weight."""
        scale = norm.view(-1, 1)
        return scale * x_j

    def update(self, aggr_out: Tensor):
        """Apply the linear transform and the extra bias to the aggregated features."""
        return self.lin(aggr_out) + self.bias
train.py
读取数据集、设置超参数、训练、评估、保存模型。
1```python
2import sys
3sys.path.append('./')
4from graph_conv import GraphConv
5
6import torch
7import torch.nn.functional as F
8from torch_geometric.datasets import Planetoid
9from torch_geometric.data import Data, Dataset
10
class GCN(torch.nn.Module):
    """Two-layer GCN: GraphConv -> ReLU -> dropout(p=0.5) -> GraphConv -> log-softmax."""

    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int):
        super().__init__()
        self.conv1 = GraphConv(in_channels, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, out_channels)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Return per-node log-probabilities over the output classes."""
        hidden = F.relu(self.conv1(x, edge_index))
        # Dropout is only active while self.training is True.
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        logits = self.conv2(hidden, edge_index)
        return F.log_softmax(logits, dim=1)
23
def train(model: GCN, data: Data, lr: float, epochs: int) -> None:
    """Train ``model`` full-batch with Adam on the nodes in ``data.train_mask``.

    After every epoch the loss and the accuracy reported by ``evaluate`` are
    printed.

    Args:
        model: The GCN to optimize (updated in place).
        data: Graph providing ``x``, ``edge_index``, ``y`` and ``train_mask``
            (plus ``test_mask``, which ``evaluate`` reads).
        lr: Adam learning rate.
        epochs: Number of full-batch optimization steps.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        # evaluate() switches the model to eval mode. Without switching back,
        # dropout would silently be disabled from the second epoch onward.
        model.train()
        optimizer.zero_grad()
        output = model(data.x, data.edge_index)
        loss = F.nll_loss(output[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        # NOTE(review): this monitors the *test* split; a validation mask would
        # avoid peeking at test performance during training.
        acc = evaluate(model, data)
        print(f'Epoch: {epoch:03d}, Loss: {loss.item():.4f}, Test Accuracy: {acc:.4f}')
37
def evaluate(model: 'GCN', data: 'Data') -> float:
    """Return the accuracy of ``model`` on the nodes selected by ``data.test_mask``.

    The model is switched to eval mode (dropout off) and left in that mode.
    The forward pass runs under ``torch.no_grad()`` because no gradients are
    needed for evaluation.
    """
    model.eval()
    with torch.no_grad():
        pred = model(data.x, data.edge_index).max(dim=1).indices
    mask = data.test_mask
    correct = int((pred[mask] == data.y[mask]).sum().item())
    return correct / int(mask.sum())
44
def main():
    """Train a GCN on Cora, print its test accuracy and save its weights."""
    # Hyper-parameters.
    hidden_dim, lr, epochs = 16, 0.001, 100

    cora = Planetoid(root='./dataset', name='Cora')
    graph = cora[0]

    model = GCN(cora.num_features, hidden_dim, cora.num_classes)
    train(model, graph, lr, epochs)

    final_acc = evaluate(model, graph)
    print(f'Test Accuracy: {final_acc:.4f}')

    # Only the parameters (state_dict) are stored, not the module object.
    torch.save(model.state_dict(), './gcn_cora.pth')


if __name__ == '__main__':
    main()
推理模型
训练得到的 .pth 文件(这里只保存了 state_dict,即模型参数)更像一个 checkpoint,而不是真正用来推理的模型,因此还需要导出。
导出推理模型
PyTorch 一般使用 torch.onnx.export 将模型导出为 ONNX 格式,这个格式比较通用。
1$ pip3 install onnx
def main():
    ...  # dataset loading and training exactly as in train.py (elided)
    export_onnx(model, data)


def export_onnx(model: 'GCN', data: 'Data', onnx_path: str = 'gcn_cora.onnx',
                opset_version: int = 16) -> None:
    """Export a trained GCN to ONNX, using ``data`` as the example input.

    The model is traced with ``data.x`` / ``data.edge_index`` and no
    ``dynamic_axes``, so the exported input shapes are those of ``data``.

    Args:
        model: Trained GCN; switched to eval mode so dropout is not exported.
        data: Graph whose ``x`` and ``edge_index`` serve as example inputs.
        onnx_path: Output file path.
        opset_version: Target ONNX opset, must be >= 16. GraphConv computes
            node degrees and aggregates messages with scatter-add. Before
            opset 16, ``ScatterElements`` has no ``reduction`` attribute, and
            scatter-add is exported as zeros + ScatterElements + Add. That
            *overwrites* duplicate indices instead of summing them, so the
            exported model would silently compute wrong degrees and
            aggregations.

    Raises:
        ValueError: If ``opset_version`` is below 16.
    """
    if opset_version < 16:
        raise ValueError(
            f'opset_version must be >= 16 (got {opset_version}): earlier opsets '
            'export scatter-add as an overwriting ScatterElements, which breaks GraphConv'
        )
    model.eval()
    x = data.x
    edge_index = data.edge_index
    torch.onnx.export(
        model,
        (x, edge_index),
        onnx_path,
        input_names=['x', 'edge_index'],
        output_names=['output'],
        opset_version=opset_version,
    )

    print('ONNX model exported successfully')
查看模型结构
ONNX 库提供了模型结构探索的 API:
1import onnx
2import argparse
3
def dim_str(dims) -> str:
    """Format ONNX tensor-shape dims as a comma-separated string.

    Concrete sizes are printed as numbers and symbolic dims by their
    ``dim_param`` name. Dims with neither set are printed as ``?``.
    """
    parts = []
    for dim in dims:
        kind = dim.WhichOneof('value')
        if kind == 'dim_value':
            parts.append(str(dim.dim_value))
        elif kind == 'dim_param':
            parts.append(dim.dim_param)
        else:
            parts.append('?')
    return ', '.join(parts)


def print_model_summary(model) -> None:
    """Print the graph inputs, outputs and nodes of a loaded ONNX model."""
    print("Inputs:")
    for value_info in model.graph.input:
        print(value_info.name, dim_str(value_info.type.tensor_type.shape.dim))
    print("Outputs:")
    for value_info in model.graph.output:
        print(value_info.name, dim_str(value_info.type.tensor_type.shape.dim))

    print("Nodes:")
    for node in model.graph.node:
        print(node.name, node.op_type, ":", ", ".join(node.input), "=>", ", ".join(node.output))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, required=True)
    args = parser.parse_args()

    print_model_summary(onnx.load(args.model_path))
另外,最方便的方法是使用 onnx.helper.printable_graph(model.graph) 函数直接打印整个计算图。
Netron 是一个神经网络模型可视化工具(支持 ONNX 等多种格式),可以直接在浏览器中查看 ONNX 模型的结构。访问:https://netron.app
图中的节点就是 Op,除开常见的,简要介绍一些:
-
Gather: 从输入张量中根据索引张量提取元素。常用于从大张量中提取部分元素,如从一个批次的数据中提取某些样本。
-
输入: 一个数据张量和一个索引张量
-
输出:根据索引提取的子张量。
-
-
Expand: 按照广播(broadcast)规则把输入张量扩展到目标形状,只有大小为 1 的维度可以被扩展。常用于将小张量扩展到与其他张量匹配的大小,以便进行后续的计算。
-
输入: 一个张量和一个目标形状
-
输出:扩展后的张量。
-
-
Gemm: 执行通用矩阵乘法 (General Matrix Multiplication)。是神经网络中常见的线性变换操作。
-
输入: 两个矩阵和一个可选的偏置向量
-
输出:矩阵乘法的结果。
-
-
Squeeze: 从输入张量中删除指定轴(axes)上大小为 1 的维度;未指定 axes 时删除所有大小为 1 的维度(本模型中使用的是 axes = [0])。常用于去除不必要的维度,以便于后续的计算和处理。
-
输入:一个张量
-
输出:删除了大小为 1 的维度的张量。
-
-
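ScatterElements: 根据索引张量,把更新张量中的元素写入输入张量的对应位置。在本模型中,它用于统计节点的度和聚合邻居消息(scatter-add)。注意:opset 16 之前的 ScatterElements 没有 reduction 属性,索引重复时只会覆盖写入而不会累加。因此 opset 11 下 scatter_add 会被导出为“零张量 + ScatterElements + Add”(见下文计算图),结果与 PyTorch 不一致;导出本模型应使用 opset_version >= 16(此时使用 reduction = "add")。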
-
输入: 一个数据张量、一个索引张量和一个更新张量
-
输出:更新后的数据张量。
-
-
Tile: 将输入张量沿各个维度按给定次数整体重复拼接,重复次数可以是任意正整数。它与 Expand 的区别是:Expand 依赖广播,只能把大小为 1 的维度扩展到目标大小;Tile 则是把整个张量重复多次。
-
输入: 一个张量和一个重复次数张量
-
输出:扩展后的张量。
-
模型小型化
之前的步骤我们得知了模型的结构如下:
1graph torch_jit (
2 %x[FLOAT, 2708x1433]
3 %edge_index[INT64, 2x10556]
4) initializers (
5 %conv1.bias[FLOAT, 16]
6 %conv1.lin.weight[FLOAT, 16x1433]
7 %conv1.lin.bias[FLOAT, 16]
8 %conv2.bias[FLOAT, 7]
9 %conv2.lin.weight[FLOAT, 7x16]
10 %conv2.lin.bias[FLOAT, 7]
11) {
12 %/conv1/Constant_output_0 = Constant[value = <Scalar Tensor []>]()
13 %/conv1/Constant_1_output_0 = Constant[value = <Tensor>]()
14 %onnx::Tile_10 = Constant[value = <Tensor>]()
15 %/conv1/Constant_2_output_0 = Constant[value = <Tensor>]()
16 %/conv1/ConstantOfShape_output_0 = ConstantOfShape[value = <Tensor>](%/conv1/Constant_2_output_0)
17 %/conv1/Expand_output_0 = Expand(%/conv1/Constant_1_output_0, %/conv1/ConstantOfShape_output_0)
18 %/conv1/Tile_output_0 = Tile(%/conv1/Expand_output_0, %onnx::Tile_10)
19 %/conv1/Constant_3_output_0 = Constant[value = <Scalar Tensor []>]()
20 %/conv1/Concat_output_0 = Concat[axis = 1](%edge_index, %/conv1/Tile_output_0)
21 %/conv1/Split_output_0, %/conv1/Split_output_1 = Split[axis = 0, split = [1, 1]](%/conv1/Concat_output_0)
22 %/conv1/Squeeze_output_0 = Squeeze[axes = [0]](%/conv1/Split_output_0)
23 %/conv1/Squeeze_1_output_0 = Squeeze[axes = [0]](%/conv1/Split_output_1)
24 %/conv1/Constant_4_output_0 = Constant[value = <Tensor>]()
25 %/conv1/Constant_5_output_0 = Constant[value = <Tensor>]()
26 %/conv1/ScatterElements_output_0 = ScatterElements[axis = 0](%/conv1/Constant_5_output_0, %/conv1/Squeeze_1_output_0, %/conv1/Constant_4_output_0)
27 %/conv1/Constant_6_output_0 = Constant[value = <Tensor>]()
28 %/conv1/Add_output_0 = Add(%/conv1/Constant_6_output_0, %/conv1/ScatterElements_output_0)
29 %/conv1/Constant_7_output_0 = Constant[value = <Scalar Tensor []>]()
30 %/conv1/Pow_output_0 = Pow(%/conv1/Add_output_0, %/conv1/Constant_7_output_0)
31 %/conv1/Gather_output_0 = Gather[axis = 0](%/conv1/Pow_output_0, %/conv1/Squeeze_output_0)
32 %/conv1/Gather_1_output_0 = Gather[axis = 0](%/conv1/Pow_output_0, %/conv1/Squeeze_1_output_0)
33 %/conv1/Mul_output_0 = Mul(%/conv1/Gather_output_0, %/conv1/Gather_1_output_0)
34 %/conv1/Gather_2_output_0 = Gather[axis = 0](%/conv1/Concat_output_0, %/conv1/Constant_3_output_0)
35 %/conv1/Gather_3_output_0 = Gather[axis = 0](%/conv1/Concat_output_0, %/conv1/Constant_output_0)
36 %/conv1/Gather_4_output_0 = Gather[axis = -2](%x, %/conv1/Gather_3_output_0)
37 %/conv1/Constant_8_output_0 = Constant[value = <Tensor>]()
38 %/conv1/Reshape_output_0 = Reshape(%/conv1/Mul_output_0, %/conv1/Constant_8_output_0)
39 %/conv1/Mul_1_output_0 = Mul(%/conv1/Reshape_output_0, %/conv1/Gather_4_output_0)
40 %/conv1/aggr_module/Constant_output_0 = Constant[value = <Tensor>]()
41 %/conv1/aggr_module/Reshape_output_0 = Reshape(%/conv1/Gather_2_output_0, %/conv1/aggr_module/Constant_output_0)
42 %/conv1/aggr_module/Shape_output_0 = Shape(%/conv1/Mul_1_output_0)
43 %/conv1/aggr_module/Expand_output_0 = Expand(%/conv1/aggr_module/Reshape_output_0, %/conv1/aggr_module/Shape_output_0)
44 %/conv1/aggr_module/Constant_1_output_0 = Constant[value = <Tensor>]()
45 %/conv1/aggr_module/ScatterElements_output_0 = ScatterElements[axis = 0](%/conv1/aggr_module/Constant_1_output_0, %/conv1/aggr_module/Expand_output_0, %/conv1/Mul_1_output_0)
46 %/conv1/aggr_module/Constant_2_output_0 = Constant[value = <Tensor>]()
47 %/conv1/aggr_module/Add_output_0 = Add(%/conv1/aggr_module/Constant_2_output_0, %/conv1/aggr_module/ScatterElements_output_0)
48 %/conv1/lin/Gemm_output_0 = Gemm[alpha = 1, beta = 1, transB = 1](%/conv1/aggr_module/Add_output_0, %conv1.lin.weight, %conv1.lin.bias)
49 %/conv1/Add_1_output_0 = Add(%/conv1/lin/Gemm_output_0, %conv1.bias)
50 %/Relu_output_0 = Relu(%/conv1/Add_1_output_0)
51 %/conv2/Constant_output_0 = Constant[value = <Tensor>]()
52 %/conv2/Constant_1_output_0 = Constant[value = <Tensor>]()
53 %/conv2/ConstantOfShape_output_0 = ConstantOfShape[value = <Tensor>](%/conv2/Constant_1_output_0)
54 %/conv2/Expand_output_0 = Expand(%/conv2/Constant_output_0, %/conv2/ConstantOfShape_output_0)
55 %/conv2/Tile_output_0 = Tile(%/conv2/Expand_output_0, %onnx::Tile_10)
56 %/conv2/Concat_output_0 = Concat[axis = 1](%edge_index, %/conv2/Tile_output_0)
57 %/conv2/Split_output_0, %/conv2/Split_output_1 = Split[axis = 0, split = [1, 1]](%/conv2/Concat_output_0)
58 %/conv2/Squeeze_output_0 = Squeeze[axes = [0]](%/conv2/Split_output_0)
59 %/conv2/Squeeze_1_output_0 = Squeeze[axes = [0]](%/conv2/Split_output_1)
60 %/conv2/Constant_2_output_0 = Constant[value = <Tensor>]()
61 %/conv2/Constant_3_output_0 = Constant[value = <Tensor>]()
62 %/conv2/ScatterElements_output_0 = ScatterElements[axis = 0](%/conv2/Constant_3_output_0, %/conv2/Squeeze_1_output_0, %/conv2/Constant_2_output_0)
63 %/conv2/Constant_4_output_0 = Constant[value = <Tensor>]()
64 %/conv2/Add_output_0 = Add(%/conv2/Constant_4_output_0, %/conv2/ScatterElements_output_0)
65 %/conv2/Constant_5_output_0 = Constant[value = <Scalar Tensor []>]()
66 %/conv2/Pow_output_0 = Pow(%/conv2/Add_output_0, %/conv2/Constant_5_output_0)
67 %/conv2/Gather_output_0 = Gather[axis = 0](%/conv2/Pow_output_0, %/conv2/Squeeze_output_0)
68 %/conv2/Gather_1_output_0 = Gather[axis = 0](%/conv2/Pow_output_0, %/conv2/Squeeze_1_output_0)
69 %/conv2/Mul_output_0 = Mul(%/conv2/Gather_output_0, %/conv2/Gather_1_output_0)
70 %/conv2/Gather_2_output_0 = Gather[axis = 0](%/conv2/Concat_output_0, %/conv1/Constant_3_output_0)
71 %/conv2/Gather_3_output_0 = Gather[axis = 0](%/conv2/Concat_output_0, %/conv1/Constant_output_0)
72 %/conv2/Gather_4_output_0 = Gather[axis = -2](%/Relu_output_0, %/conv2/Gather_3_output_0)
73 %/conv2/Constant_6_output_0 = Constant[value = <Tensor>]()
74 %/conv2/Reshape_output_0 = Reshape(%/conv2/Mul_output_0, %/conv2/Constant_6_output_0)
75 %/conv2/Mul_1_output_0 = Mul(%/conv2/Reshape_output_0, %/conv2/Gather_4_output_0)
76 %/conv2/aggr_module/Constant_output_0 = Constant[value = <Tensor>]()
77 %/conv2/aggr_module/Reshape_output_0 = Reshape(%/conv2/Gather_2_output_0, %/conv2/aggr_module/Constant_output_0)
78 %/conv2/aggr_module/Shape_output_0 = Shape(%/conv2/Mul_1_output_0)
79 %/conv2/aggr_module/Expand_output_0 = Expand(%/conv2/aggr_module/Reshape_output_0, %/conv2/aggr_module/Shape_output_0)
80 %/conv2/aggr_module/Constant_1_output_0 = Constant[value = <Tensor>]()
81 %/conv2/aggr_module/ScatterElements_output_0 = ScatterElements[axis = 0](%/conv2/aggr_module/Constant_1_output_0, %/conv2/aggr_module/Expand_output_0, %/conv2/Mul_1_output_0)
82 %/conv2/aggr_module/Constant_2_output_0 = Constant[value = <Tensor>]()
83 %/conv2/aggr_module/Add_output_0 = Add(%/conv2/aggr_module/Constant_2_output_0, %/conv2/aggr_module/ScatterElements_output_0)
84 %/conv2/lin/Gemm_output_0 = Gemm[alpha = 1, beta = 1, transB = 1](%/conv2/aggr_module/Add_output_0, %conv2.lin.weight, %conv2.lin.bias)
85 %/conv2/Add_1_output_0 = Add(%/conv2/lin/Gemm_output_0, %conv2.bias)
86 %output = LogSoftmax[axis = 1](%/conv2/Add_1_output_0)
87 return %output
88}
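这是一个两层的图卷积网络。每个 GraphConv 层由两部分组成:一是消息传递,即按对称归一化系数对邻居特征加权并求和,对应图中的 Gather、Pow、Mul、ScatterElements 等算子;二是线性变换(Gemm,对应 lin)再加一个偏置(Add,对应 bias)。两层之间用 ReLU 连接;推理时 dropout 不生效,所以图中没有对应的算子。最后接一个 LogSoftmax 作为输出层。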
常见的小型化(压缩)技术:
-
参数量剪枝
-
对全连接层的权重矩阵进行剪枝,移除一些绝对值较小的权重
-
对图卷积层的权重向量进行剪枝
-
-
网络结构剪枝
- 移除其中一层的图卷积层和全连接层,从而减少模型深度
-
量化
- 将权重和激活从 FP32 量化到 INT8 等更低比特的定点数
下面以量化为例进行实现(剪枝未在此实现)。onnxruntime 已经封装得很好,不需要深入了解原理,直接调用即可:
1from __future__ import annotations
2import argparse
3import onnxruntime
4from onnxruntime.quantization import quantize_dynamic, QuantType
5from onnxruntime.quantization.shape_inference import quant_pre_process
6from onnxruntime.capi.onnxruntime_inference_collection import InferenceSession
7from torch_geometric.datasets import Planetoid
8
def evaluate_model(model: InferenceSession, data) -> float:
    """
    Evaluate the accuracy of the given model on the test set.

    The session's first two graph inputs are fed ``data.x`` and
    ``data.edge_index`` (in that order, matching the export), and its first
    output is read as per-node class scores.

    Args:
        model (InferenceSession): The ONNX model to be evaluated.
        data: The dataset containing features, edge indices, labels, and test mask.

    Returns:
        float: The accuracy of the model on the test set.

    Raises:
        ValueError: If ``data.test_mask`` selects no nodes.
    """
    inputs = model.get_inputs()
    output_name = model.get_outputs()[0].name
    feed = {
        inputs[0].name: data.x.cpu().numpy(),
        inputs[1].name: data.edge_index.cpu().numpy(),
    }
    output = model.run([output_name], feed)[0]

    # Convert the torch mask once, so the numpy output is indexed with a numpy
    # boolean array rather than relying on implicit torch -> numpy conversion.
    test_mask = data.test_mask.cpu().numpy()
    total = int(test_mask.sum())
    if total == 0:
        raise ValueError('test_mask selects no nodes; accuracy is undefined')

    pred = output[test_mask].argmax(axis=1)
    correct = int((pred == data.y.cpu().numpy()[test_mask]).sum())
    return correct / total
31
def load_data(dataset_path: str):
    """Return the single graph of the Planetoid ``Cora`` dataset stored under ``dataset_path``."""
    cora = Planetoid(root=dataset_path, name='Cora')
    graph = cora[0]
    return graph
35
def main(model_path: str, dataset_path: str):
    """
    Main function to evaluate, preprocess, and quantize the ONNX model.

    The intermediate and final models are written next to ``model_path`` as
    ``<stem>_preprocessed.onnx`` and ``<stem>_quantized.onnx``, where
    ``<stem>`` is ``model_path`` without its final extension.

    Args:
        model_path (str): The path to the ONNX model file.
        dataset_path (str): The path to the dataset directory.
    """
    import os  # only needed to derive the output file names

    # Load data
    data = load_data(dataset_path)

    # Evaluate original model
    session = onnxruntime.InferenceSession(model_path)
    original_acc = evaluate_model(session, data)
    print(f"Original model accuracy on test set: {original_acc:.4f}")

    # Strip only the final extension. model_path.split('.')[0] breaks on any
    # path containing another dot: "./gcn_cora.onnx" would become "", and the
    # outputs would land in the CWD as "_preprocessed.onnx" / "_quantized.onnx".
    prefix = os.path.splitext(model_path)[0]

    # Preprocess model
    preprocessed_model_path = prefix + '_preprocessed.onnx'
    quant_pre_process(model_path, preprocessed_model_path)
    preprocessed_session = onnxruntime.InferenceSession(preprocessed_model_path)
    preprocessed_acc = evaluate_model(preprocessed_session, data)
    print(f"Preprocessed model accuracy on test set: {preprocessed_acc:.4f}")

    # Quantize model (dynamic quantization, weights stored as unsigned 8-bit)
    quantized_model_path = prefix + '_quantized.onnx'
    quantize_dynamic(preprocessed_model_path, quantized_model_path, weight_type=QuantType.QUInt8)
    quantized_session = onnxruntime.InferenceSession(quantized_model_path)
    quantized_acc = evaluate_model(quantized_session, data)
    print(f"Quantized model accuracy on test set: {quantized_acc:.4f}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate and quantize ONNX model.")
    parser.add_argument("--model_path", required=False, type=str, help="Path to the ONNX model file.", default="gcn_cora.onnx")
    parser.add_argument("--dataset_path", required=False, type=str, help="Path to the dataset directory.", default="./dataset")
    args = parser.parse_args()

    main(args.model_path, args.dataset_path)
部署和推理
最后可以把模型部署到具体的硬件上,并进行推理,以 ORT 为例:
import onnxruntime as ort

sess_options = ort.SessionOptions()
# Enable all graph-level optimizations ORT offers for this session.
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# Execution providers are tried in the listed order; CPU is listed last as an
# explicit fallback for machines (or ops) without CUDA support.
session = ort.InferenceSession(
    "optimized_model.onnx",
    sess_options,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
当然也可以用 NV 研发的 TensorRT。另外,目前国内有很多芯片公司研发了自己的推理芯片,也产生了大量的岗位。
ONNX 虽然通用性很好,但各家框架和硬件支持的算子并不统一,而且每个公司都想定义自己的 IR,导致迁移工作非常困难。于是 Google 提出了 MLIR(现为 LLVM 项目的一部分)。MLIR 支持 Dialect(方言)机制:各种 IR 都可以表示为 MLIR 中的方言,再逐级转换为更底层的方言(这个过程称为 lowering),最终翻译为机器码执行。
onnx/onnx-mlir 实现了将 ONNX lowering 到 MLIR.
附:DGL 版本
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4from dgl.data import CoraGraphDataset
5from dgl.nn import GraphConv
6from dgl import AddSelfLoop
7
class GCN(nn.Module):
    """Two-layer DGL GCN: GraphConv -> ReLU -> dropout -> GraphConv -> log-softmax."""

    def __init__(self, in_feats, h_feats, num_classes):
        super().__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        """Return per-node log-probabilities for graph ``g`` with node features ``in_feat``."""
        hidden = F.relu(self.conv1(g, in_feat))
        # Uses F.dropout's default probability; only active while training.
        hidden = F.dropout(hidden, training=self.training)
        return F.log_softmax(self.conv2(g, hidden), dim=1)
20
def evaluate(model, g, features, labels, mask):
    """Return the fraction of ``mask``-selected nodes whose predicted class equals ``labels``.

    Puts ``model`` into eval mode and runs the forward pass without autograd.
    """
    model.eval()
    with torch.no_grad():
        selected_logits = model(g, features)[mask]
        selected_labels = labels[mask]
        predictions = selected_logits.max(dim=1).indices
        n_correct = (predictions == selected_labels).sum()
        return n_correct.item() * 1.0 / len(selected_labels)
30
def main():
    """Train a two-layer DGL GCN on Cora, logging validation accuracy, then report test accuracy."""
    # Load the dataset; the AddSelfLoop transform adds a self-loop to every node.
    transform = AddSelfLoop()
    data = CoraGraphDataset(transform=transform)
    g = data[0]
    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = g.num_edges()
    print(
        "----Data statistics------\n"
        f"  Edges {n_edges}\n"
        f"  Classes {n_classes}\n"
        f"  Train samples {train_mask.int().sum().item()}\n"
        f"  Val samples {val_mask.int().sum().item()}\n"
        f"  Test samples {test_mask.int().sum().item()}"
    )

    # create GCN model
    model = GCN(in_feats, 16, n_classes)

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    for epoch in range(100):
        # evaluate() switches the model to eval mode, so training mode (and
        # dropout) has to be re-enabled at the start of every epoch.
        model.train()
        # forward
        logits = model(g, features)
        loss = F.nll_loss(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Per-epoch monitoring uses the validation split, not the test split.
        acc = evaluate(model, g, features, labels, val_mask)
        print("Epoch {:05d} | Loss {:.4f} | Val Accuracy {:.4f}".format(epoch, loss.item(), acc))

    print()
    acc = evaluate(model, g, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))


if __name__ == '__main__':
    main()