关于算法:PGL图学习之项目实践UniMP算法实现论文节点分类新冠疫苗项目实战助力疫情系列九

33次阅读

共计 11304 个字符,预计需要花费 29 分钟才能阅读完成。

原我的项目链接:https://aistudio.baidu.com/aistudio/projectdetail/5100049?contributionType=1

1. 图学习技术与利用

图是一个简单世界的通用语言,社交网络中人与人之间的连贯、蛋白质分子、举荐零碎中用户与物品之间的连贯等等,都能够应用图来表白。图神经网络将神经网络使用至图构造中,能够被形容成消息传递的范式。百度开发了 PGL2.2,基于底层深度学习框架 paddle,给用户裸露了编程接口来实现图网络。与此同时,百度也应用了前沿的图神经网络技术针对一些利用进行模型算法的落地。本次将介绍百度的 PGL 图学习技术与利用。

1.1 图起源与建模

首先和大家分享下图学习支流的图神经网络建模形式。

14 年左右开始,学术界呈现了一些基于图谱合成的技术,通过频域变换,将图变换至频域进行解决,再将处理结果变换回空域来失去图上节点的示意。起初,空域卷积借鉴了图像的二维卷积,并逐步取代了频域图学习办法。图构造上的卷积是对节点街坊的聚合。

基于空间的图神经网络次要须要思考两个问题:

  • 怎么表白节点特色;
  • 怎么表白一整张图。

第一个问题能够应用街坊聚合的办法,第二问题应用节点聚合来解决。

目前大部分支流的图神经网络都能够形容成消息传递的模式。须要思考节点如何将音讯发送至指标节点,而后指标节点如何对收到的节点特色进行接管。

1.2 PGL2.2 回顾介绍

PGL2.2 基于消息传递的思路构建整体框架。PGL 最底层是飞浆外围 paddle 深度学习框架。在此之上,搭建了 CPU 图引擎和 GPU 上进行 tensor 化的图引擎,来不便对图进行如图切分、图存储、图采样、图游走的算法。再上一层,会对用户裸露一些编程接口,包含底层的消息传递接口和图网络实现接口,以及高层的同构图、异构图的编程接口。框架顶层会反对几大类图模型,包含传统图示意学习中的图游走模型、消息传递类模型、常识嵌入类模型等,去撑持上游的利用场景。

最后的 PGL 是基于 paddle1.x 的版本进行开发的,所以那时候还是像 tensorflow 一样的动态图模式。目前 paddle2.0 曾经进行了全面动态化,那么 PGL 也相应地做了动态图的降级。当初去定义一个图神经网络就只须要定义节点数量、边数量以及节点特色,而后将图 tensor 化即可。能够自定义如何将音讯进行发送以及指标节点如何接管音讯。

上图是应用 PGL 构建一个 GAT 网络的例子。最开始会去计算节点的权重,在发送音讯的时候 GAT 会将原节点和指标节点特色进行求和,再加上一个非线性激活函数。在接管的时候,能够通过 reduce_softmax 对边上的权重进行归一化,再乘上 hidden state 进行加权求和。这样就能够很不便地实现一个 GAT 网络。

对于图神经网络来讲,在构建完网络后,要对它进行训练。训练形式和个别机器学习有所不同,须要依据图的规模抉择实用的训练计划。

例如在小图,即图规模小于 GPU 显存的状况下,会应用 full batch 模式进行训练。它其实就是把一整张图的所有节点都搁置在 GPU 上,通过一个图网络来输入所有点的特色。它的益处在于能够跑一个很深的图。这一训练计划会被利用于中小型数据集,例如 Cora、Pubmed、Citeseer、ogbn-arxiv 等。最近在 ICML 上发现了能够重叠至 1000 层的图神经网络,同样也是在这种中小型数据集上做评估。

对于中等规模的图,即图规模大于 GPU 单卡显存,常识能够进行分片训练,每一次将一张子图塞入 GPU 上。PGL 提供了另一个计划,应用分片技术来升高显存应用的峰值。例如对一个简单图进行计算时,它的计算复杂度取决于边计算时显存应用的峰值,此时如果有多块 GPU 就能够把边计算进行分块,每台机器只负责一小部分的计算,这样就能够大大地缩小图神经网络的计算峰值,从而达到更深的图神经网络的训练。分块训练结束后,须要通过 NCCL 来同步节点特色。

在 PGL 中,只须要一行 DistGPUGraph 命令就能够在原来 full batch 的训练代码中退出这样一个新个性,使得能够在多 GPU 中运行一个深层图神经网络。例如在 obgn-arxiv 中尝试了比较复杂的 TransformerConv 网络,如果应用单卡训练一个三层网络,其 GPU 显存会被占用近 30G,而应用分片训练就能够将它的显存峰值升高。同时,还实现了并行的计算减速,例如原来跑 100 epoch 须要十分钟,当初只须要 200 秒。

在大图的状况下,又回归到平时做数据并行的 mini batch 模式。Mini batch 与 full batch 相比最次要的问题在于它须要做街坊的采样,而街坊数目的晋升会对模型的深度进行限度。这一模式实用于一些巨型数据集,包含 ogbn-products 和 ogbn-papers100m。

发现 PyG 的作者的新工作 GNNAutoScale 可能把一个图神经网络进行主动的深度扩大。它的次要思路是利用 CPU 的缓存技术,将街坊节点的特色缓存至 CPU 内存中。当训练图网络时,能够不必实时获取所有街坊的最新表白,而是获取它的历史 embedding 进行街坊聚合计算。试验发现这样做的成果还是不错的。

在工业界的状况下可能会存在更大的图规模的场景,那么这时候可能单 CPU 也存不下如此图规模的数据,这时须要一个分布式的多机存储和采样。PGL 有一套分布式的图引擎接口,使得能够轻松地在 MPI 以及 K8S 集群上通过 PGL launch 接口进行一键的分布式图引擎部署。目前也反对不同类型的街坊采样、节点遍历和图游走算法。

整体的大规模训练形式包含一个大规模分布式图引擎,两头会蕴含一些图采样的算子和神经网络的开发算子。顶层针对工业界大规模场景,往往须要一个 parameter server 来存储上亿级别的稠密特色。借助 paddlefleet 的大规模参数服务器来反对超大规模的 embedding 存储。

1.3 图神经网络技术

1.3.1 节点分类工作

在算法上也进行了一些钻研。图神经网络与个别机器学习场景有很大的区别。个别的机器学习假如数据之间独立同散布,然而在图网络的场景下,样本是有关联的。预测样本和训练样本有时会存在边关系。通常称这样的工作为半监督节点分类问题。

解决节点分类问题的传统办法是 LPA 标签流传算法,思考链接关系以及标签之间的关系。另外一类办法是以 GCN 为代表的特色流传算法,只思考特色与链接的关系。

通过试验发现在很多数据集下,训练集很难通过过拟合达到 99% 的分类准确率。也就是说,训练集中的特色其实蕴含很大的噪声,使得网络不足过拟合能力。所以,想要显示地将训练 label 退出模型,因为标签能够消减大部分歧义。在训练过程中,为了防止标签泄露,提出了 UniMP 算法,把标签流传和特色流传交融起来。这一办法在三个 open graph benchmark 数据集上获得了 SOTA 的后果。

后续还把 UniMP 利用到更大规模的 KDDCup 21 的较量中,将 UniMP 同构算法做了异构图的拓展,使其在异构图场景下进行分类工作。具体地,在节点街坊采样、批归一化和注意力机制中思考节点之间的关系类型。

1.3.2 链接预测工作

第二个比拟经典的工作是链接预测工作。目前很多人尝试应用 GNN 与 link prediction 进行交融,然而这存在两个瓶颈。首先,GNN 的深度和街坊采样的数量无关;其次,当训练像常识图谱的工作时,每一轮训练都须要遍历训练集的三元组,此时训练的复杂度和街坊节点数量存在线性关系,这就导致了如果街坊比拟多,训练一个 epoch 的耗时很长。

借鉴了最近基于纯特色流传的算法,如 SGC 等图神经网络的简化形式,提出了基于关系的 embedding 流传。发现独自应用 embedding 进行特色流传在常识图谱上是行不通的。因为常识图谱上存在简单的边关系。所以,依据不同关系下 embedding 设计了不同的 score function 进行特色流传。此外,发现之前有一篇论文提出了 OTE 的算法,在图神经网络上进行了两阶段的训练。

应用 OGBL-WikiKG2 数据集训练 OTE 模型须要超过 100 个小时,而如果切换到的特色流传算法,即先跑一次 OTE 算法,再进行 REP 特色流传,只须要 1.7 个小时就能够使模型收敛。所以 REP 带来了近 50 倍的训练效率的晋升。还发现只须要正确设定 score function,大部分常识图谱算法应用的特色流传算法都会有成果上的晋升;不同的算法应用 REP 也能够减速它们的收敛。

将这一套办法利用到 KDDCup 21 Wiki90M 的较量中。为了实现较量中要求的超大规模常识图谱的示意,做了一套大规模的常识示意工具 Graph4KG,最终在 KDDCup 中获得了冠军。

1.4 算法利用落地

PGL 在百度外部曾经进行了广泛应用。包含百度搜寻中的网页品质评估,会把网页形成一个动态图,并在图上进行图分类的工作。百度搜寻还应用 PGL 进行网页反作弊,即对大规模节点进行检测。在文本检索利用中,尝试应用图神经网络与自然语言解决中的语言模型相结合。在其余状况下,的落地场景有举荐零碎、风控、百度地图中的流量预测、POI 检索等。

本文以举荐零碎为例,介绍一下平时如何将图神经网络在利用中进行落地。

举荐零碎罕用的算法是基于 item-based 和 user-based 协同过滤算法。Item-based 协同过滤就是举荐和 item 类似的内容,而 user-based 就是举荐类似的用户。这里最重要的是如何去掂量物品与物品之间、用户与用户之间的相似性。

能够将其与图学习联合,应用点击日志来结构图关系(包含社交关系、用户行为、物品关联),而后通过示意学习结构用户物品的向量空间。在这个空间上就能够度量物品之间的相似性,以及用户之间的相似性,进而应用其进行举荐。

罕用的办法有传统的矩阵合成办法,和阿里提出的基于随机游走 + Word2Vec 的 EGES 算法。近几年衰亡了应用图比照学习来取得节点示意。

在举荐算法中,次要的需要是反对简单的构造,反对大规模的实现和疾速的试验老本。心愿有一个工具包能够解决 GNN + 示意学习的问题。所以,对现有的图示意学习算法进行了形象。具体地,将图示意学习分成了四个局部。第一局部是图的类型,将其分为同构图、异构图、二部图,并在图中定义了多种关系,例如点击关系、关注关系等。第二,实现了不同的样本采样的办法,包含在同构图中罕用的 node2Vec 以及异构图中依照用户自定义的 meta path 进行采样。第三局部是节点的示意。能够依据 id 去示意节点,也能够通过图采样应用子图来示意一个节点。还结构了四种 GNN 的聚合形式。

发现不同场景以及不同的图示意的训练形式下,模型成果差别较大。所以的工具还反对大规模稠密特色 side-info 的反对来进行更丰盛的特色组合。用户可能有很多不同的字段,有些字段可能是缺失的,此时只须要通过一个配置表来配置节点蕴含的特色以及字段即可。还反对 GNN 的异构图主动扩大。你能够自定义边关系,如点击关系、购买关系、关注关系等,并选取适合的聚合形式,如 lightgcn,就能够主动的对 GNN 进行异构图扩大,使 lightgcn 变为 relation-wise 的 lightgcn。

对工具进行了瓶颈剖析,发现它次要集中在分布式训练中图采样和负样本结构中。能够通过应用 In-Batch Negative 的办法进行优化,即在 batch 内走负采样,缩小通信开销。这一优化能够使得训练速度晋升四至五倍,而且在训练成果上简直是无损的。此外,在图采样中能够通过对样本重构来升高采样的次数,失去两倍左右的速度晋升,且训练成果根本持平。相比于市面上现有的分布式图示意工具,还能够实现单机、双机、四机甚至更多机器的扩大。

不仅如此,还发现游走类模型训练速度较快,比拟适宜作为优良的热启动参数。具体地,能够先运行一次 metapath2Vce 算法,将训练失去的 embedding 作为初始化参数送入 GNN 中作为热启动的节点示意。发现这样做在成果上有肯定的晋升。

1.5 Q&A

Q1:在特色在多卡之间传递的训练模式中,应用 push 和 pull 的形式通信工夫占比大略有多大?

A:通信工夫的占比挺大的。如果是特地简略的模型,如 GCN 等,那么应用这种办法训练,通信工夫甚至会比间接跑这个模型的训练工夫还要久。所以这一办法适宜简单模型,即模型计算较多,且通信中特色传递的数据量相比来说较小,这种状况下就比拟适宜这种分布式计算。

Q2:图学习中节点街坊数较多会不会导致特色过平滑?

A:这里采纳的办法很多时候都很暴力,即间接应用 attention 加多头的机制,这样会极大地减缓过平滑问题。因为应用 attention 机制会使得大量特色被 softmax 激活;多头的形式能够使得每个头学到的激活特色不一样。所以这样做肯定比间接应用 GCN 进行聚合会好。

Q3:百度有没有应用图学习在自然语言解决畛域的成功经验?

A:之前有相似的工作,你能够关注 ERINESage 这篇论文。它次要是将图网络和预训练语言模型进行联合。也将图神经网络落地到了例如搜寻、举荐的场景。因为语言模型自身很难对用户日志中蕴含的点击关系进行建模,通过图神经网络就能够将点击日志中的后验关系融入语言模型,进而失去较大的晋升。

Q4:能具体介绍一下 KDD 较量中将同构图拓展至异构图的 UniMP 办法吗?

A:首先,每一个关系类型其实应该有不同的街坊采样办法。例如 paper 到 author 的关系,会独自地依据它来采样街坊节点。如果依照同构图的形式来采样,指标节点的街坊节点可能是论文,也可能是作者或者机构,那么采样的节点是不平均的。其次,在批归一化中依照关系 channel 来进行归一化,因为如果你将 paper 节点和 author 节点同时归一化,因为它们的统计均值和方差不一样,那么这种做法会把两者的统计量同时带骗。同理,在聚合操作中,不同的关系对两个节点的作用不同,须要依照不同关系应用不同的 attention 注意力权重来聚合特色。

2. 基于 UniMP 算法实现论文援用网络节点分类工作

图学习之基于 PGL-UniMP 算法的论文援用网络节点分类工作:https://aistudio.baidu.com/aistudio/projectdetail/5116458?contributionType=1

因为文章篇幅问题,为了让学习者有更好的体验,这里新开一个我的项目实现这个工作。

Epoch 987 Train Acc 0.7554459 Valid Acc 0.7546095
Epoch 988 Train Acc 0.7537374 Valid Acc 0.75717235
Epoch 989 Train Acc 0.75497127 Valid Acc 0.7573859
Epoch 990 Train Acc 0.7611409 Valid Acc 0.75653166
Epoch 991 Train Acc 0.75316787 Valid Acc 0.75489426
Epoch 992 Train Acc 0.749561 Valid Acc 0.7547519
Epoch 993 Train Acc 0.7571544 Valid Acc 0.7551079
Epoch 994 Train Acc 0.7516492 Valid Acc 0.75581974
Epoch 995 Train Acc 0.7563476 Valid Acc 0.7563181
Epoch 996 Train Acc 0.7504627 Valid Acc 0.7538976
Epoch 997 Train Acc 0.7476152 Valid Acc 0.75439596
Epoch 998 Train Acc 0.7539272 Valid Acc 0.7528298
Epoch 999 Train Acc 0.7532153 Valid Acc 0.75396883

3. 新冠疫苗我的项目实战,助力疫情

Kaggle 新冠疫苗研发比赛:https://www.kaggle.com/c/stan…

mRNA 疫苗曾经成为 2019 冠状病毒最快的候选疫苗,但目前它们面临着要害的潜在限度。目前最大的挑战之一是如何设计超稳定的 RNA 分子(mRNA)。传统疫苗是装在注射器里通过冷藏运输到世界各地,但 mRNA 疫苗目前还不可能做到这一点。

钻研人员曾经察看到 RNA 分子有降解的偏向。这是一个重大的限度,降解会使 mRNA 疫苗生效。目前,对于特定 RNA 的骨干中哪个部位最容易受影响的细节知之甚少。在不理解这些状况的状况下,目前针对 COVID-19 的 mRNA 疫苗必须在高度冷藏条件下筹备和运输,它们必须可能失去稳固,否则不太可能送达地球上的每个人。

由斯坦福大学医学院 (Stanford’s School of Medicine) 计算生物学家瑞朱·达斯 (Rhiju Das) 传授领导的永恒星系 (Eterna) 社区将科学家和比赛玩家汇集在一起,解决谜题并创造药物。Eterna 是一款在线比赛平台,通过谜题挑战玩家解决诸如 mRNA 设计等迷信问题。由斯坦福大学的钻研人员合成并进行试验测试,以取得对于 RNA 分子的新见解。Eterna 社区之前曾经开启了新的迷信原理,对致命疾病做出了新的诊断,并利用世界上最弱小的智力资源改善公众生存。Eterna 社区通过其在 20 多份出版物上的奉献推动了生物技术,包含 RNA 生物技术停顿。

在这次比赛中,咱们心愿利用 Kaggle 社区的数据迷信专业知识来开发模型和设计 RNA 降解规定。模型将预测 RNA 分子每个碱基的可能降解率,训练的对象是由超过 3000 个 RNA 分子组成的 Eterna 数据集子集(它们逾越了一整套序列和构造),以及它们在每个地位的降解率。而后,咱们将依据 Eterna 玩家刚刚为 COVID-19 mRNA 疫苗设计的第二代 RNA 序列为模型评分。这些最终的测试序列目前正在合成和试验表征在斯坦福大学与建模工作并行——天然将评分模型!

进步 mRNA 疫苗的稳定性曾经在摸索,咱们必须解决这一粗浅的迷信挑战,以减速 mRNA 疫苗钻研,并提供一种针对 COVID-19 背地病毒 SARS-CoV- 2 的冰箱稳固疫苗。咱们正在试图解决的问题心愿失去学术实验室、工业研发团队和超级计算机的帮忙,你能够退出电子比赛玩家、科学家和开发者的团队,在 Eterna 永恒星球上反抗这一毁灭性病毒。

3.1 案例简介

将编码的 DNA 送到细胞中,细胞应用 mRNA(Messenger RNA)组装蛋白,免疫系统检测到组装蛋白质当前,利用构建病毒蛋白的编码基因激活免疫系统产生抗体,加强针对冠状病毒的抵挡能力。

不同的 mRNA 生成同一个蛋白质,

mRNA 随着工夫的流逝及温度的变动产生了降解,

如何找到构造更加稳固的 mRNA?利用图神经网络找到更稳固的 mRNA, 色彩越深越稳固.

3.2 新冠疫苗我的项目拔高实战

数据分布特色

查看以后挂载的数据集目录

# 加载一些须要用到的模块,设置随机数
import json
import random
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import networkx as nx

from utils.config import prepare_config, make_dir
from utils.logger import prepare_logger, log_to_file
from data_parser import GraphParser

seed = 123
np.random.seed(seed)
random.seed(seed)
# https://www.kaggle.com/c/stanford-covid-vaccine/data
# 加载训练用的数据
df = pd.read_json('../data/data179441/train.json', lines=True)
# 查看一下数据集的内容
sample = df.loc[0]
print(sample)

index                                                                400
id                                                          id_2a7a4496f
sequence               GGAAAGCCCGCGGCGCCGGGCGCCGCGGCCGCCCAGGCCGCCCGGC...
structure              .....(((...)))((((((((((((((((((((.((((....)))...
predicted_loop_type    EEEEESSSHHHSSSSSSSSSSSSSSSSSSSSSSSISSSSHHHHSSS...
signal_to_noise                                                        0
SN_filter                                                              0
seq_length                                                           107
seq_scored                                                            68
reactivity_error       [146151.225, 146151.225, 146151.225, 146151.22...
deg_error_Mg_pH10      [104235.1742, 104235.1742, 104235.1742, 104235...
deg_error_pH10         [222620.9531, 222620.9531, 222620.9531, 222620...
deg_error_Mg_50C       [171525.3217, 171525.3217, 171525.3217, 171525...
deg_error_50C          [191738.0886, 191738.0886, 191738.0886, 191738...
reactivity             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_Mg_pH10            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_pH10               [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_Mg_50C             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_50C                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
Name: 0, dtype: object

例如 deg_50C、deg_Mg_50C 这样的值全为 0 的行,就是咱们须要预测的。

structure 一行,数据中的括号是为了形成边用的。

本案例要预测 RNA 序列不同地位的降解速率,训练数据中提供了多个 ground 值,标签包含以下几项:reactivity, deg_Mg_pH10, and deg_Mg_50

  • reactivity – (1×68 vector 训练集,1×91 测试集) 一个浮点数数组,与 seq_scores 有雷同的长度,是前 68 个碱基的反馈活性值,按程序示意,用于确定 RNA 样本可能的二级构造。
  • deg_Mg_pH10 – (训练集 1×68 向量,1×91 测试集)一个浮点数数组,与 seq_scores 有雷同的长度,是前 68 个碱基的反馈活性值,按程序示意,用于确定在高 pH (pH 10)下的降解可能性。
  • deg_Mg_50 – (训练集 1×68 向量,1×91 测试集)一个浮点数数组,与 seq_scores 有雷同的长度,是前 68 个碱基的反馈活性值,按程序示意,用于确定在低温 (50 摄氏度) 下的降解可能性。

    # 利用 GraphParser 结构图构造的数据
    args = prepare_config("./config.yaml", isCreate=False, isSave=False)
    parser = GraphParser(args) # GraphParser 类来自 data_parser.py
    gdata = parser.parse(sample) # GraphParser 里最次要的函数就是 parse(self, sample)
    

    数据格式:

    {'nfeat': array([[0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 1., 0., ..., 0., 0., 0.],
          ...,
          [1., 0., 0., ..., 0., 0., 0.],
          [1., 0., 0., ..., 0., 0., 0.],
          [1., 0., 0., ..., 0., 0., 0.]], dtype=float32),
     'edges': array([[0,   1],
          [1,   0],
          [1,   2],
          ...,
          [142, 105],
          [106, 142],
          [142, 106]]),
     'efeat': array([[0.,  0.,  0.,  1.,  1.],
          [0.,  0.,  0., -1.,  1.],
          [0.,  0.,  0.,  1.,  1.],
          ...,
          [0.,  1.,  0.,  0.,  0.],
          [0.,  1.,  0.,  0.,  0.],
          [0.,  1.,  0.,  0.,  0.]], dtype=float32),
     'labels': array([[0.    ,  0.    ,  0.],
          [0.    ,  0.    ,  0.],
          ...,
          [0.    ,  0.9213,  0.],
          [6.8894,  3.5097,  5.7754],
          [0.    ,  1.8426,  6.0642],
            ...,        
          [0.    ,  0.    ,  0.],
          [0.    ,  0.    ,  0.]], dtype=float32),
     'mask': array([[True],
          [True],
       ......
         [False]])}
# 图数据可视化
fig = plt.figure(figsize=(24, 12))
nx_G = nx.Graph()
nx_G.add_nodes_from([i for i in range(len(gdata['nfeat']))])

nx_G.add_edges_from(gdata['edges'])
node_color = ['g' for _ in range(sample['seq_length'])] + \
['y' for _ in range(len(gdata['nfeat']) - sample['seq_length'])]
options = {"node_color": node_color,}
pos = nx.spring_layout(nx_G, iterations=400, k=0.2)
nx.draw(nx_G, pos, **options)

plt.show()

从图中能够看到,绿色节点是碱基,黄色节点是密码子。后果返回的是 MCRMSE 和 loss

{'MCRMSE': 0.5496759, 'loss': 0.3025484172316889}

这部分代码实现参考我的项目:[PGL 图学习之基于 GNN 模型新冠疫苗工作[系列九]](https://aistudio.baidu.com/aistudio/projectdetail/5123296?contributionType=1)
# 咱们在 layer.py 里定义了一个新的 gnn 模型(my_gnn),消息传递的过程中退出了边的特色(edge_feat)
# 而后批改 model.py 里的 GNNModel
# 应用批改后的模型,运行 main.py。为节省时间,设置 epochs = 100

# !python main.py --config config.yaml #训练
#!python main.py --mode infer #预测

4. 总结

本我的项目讲了论文节点分类工作和新冠疫苗工作,并在论文节点分类工作中对代码进行具体解说。PGL 八九系列的我的项目耦合性比拟大,也花了挺久工夫钻研心愿对大家有帮忙。

后续将做一次大的总结偏差业务侧该如何落地以及图算法的演绎,之后会进行不定期更新图相干的算法!

  • easydict 库和 collections 库!
  • 从官网数据处理局部,学习到利用 np 的 vstack 实现自环边以及晓得有向边如何增加反向边的数据——这样的一种代码实现边数据转换的形式!
  • 从模型加载局部,学习了多 program 执行的操作,理清了 program 与命名空间之间的分割!
  • 从模型训练局部,强化了执行器执行时,须要传入正确的 program 以及 feed_dict,在 pgl 中能够应用图 Graph 自带的 to_feed 办法返回一个 feed_dict 数据字典作为初始数据,后边再按需增加新数据!
  • 从 model.py 学习了模型的组网,以及 pgl 中 conv 类下的网络模型办法的调用,不便组网!
  • 重点来了:从 build_model.py 学习了模型的参数的加载组合,实现对立的解决和返回对立的算子以及参数!

正文完
 0