基于图形的深度学习应该怎样做

47次阅读

共计 2937 个字符,预计需要花费 8 分钟才能阅读完成。

因为能从各类简单数据(例如自在文本、图像、视频)中提取简单模式,深度学习技术曾经实现了十分宽泛的利用。然而实际中会发现,很多数据集用图形(Graph)的形式更容易表白,例如社交网络上人们之间的互相关系等。

对于这类数据,深度学习技术中所应用的卷积神经网络、递归神经网络等传统神经网络架构就有些难以应答了,兴许是时候引入一种新的办法。

图形神经网络(GNN)请理解一下

图神经网络(GNN)是当今机器学习畛域最蓬勃发展的方向之一,这种技术通常可用于在图形类型的数据集上训练预测模型,例如:

  • 社交网络数据集,其中用图形显示熟人之间的分割;
  • 举荐零碎数据集,其中用图形显示顾客与物品之间的互动;
  • 化学分析数据集,其中化合物被建模为由原子和化学键组成的图形;
  • 网络安全数据集,其中用图形形容了源 IP 地址和指标 IP 地址之间的连贯……

大部分状况下,这些数据集十分宏大,并且只有局部标记。以一个典型的欺诈检测场景为例,在该场景中,咱们将尝试通过剖析某人与已知欺诈者的分割,来预测其为欺诈行为者的可能性。这个问题能够定义为半监督学习工作,其中只有一小部分图节点将被标记(“欺诈者”或“非法者”)。与尝试构建一个大型的手工标记数据集,并对其进行“线性化”以利用传统的机器学习算法相比,这应该是一个更好的解决方案。

无关 GNN 的进一步介绍,无妨参考这些参考文献。

Amazon SageMaker 现已反对开源的 Deep Graph Library

在 GNN 的实际利用中,通常须要咱们具备特定畛域(批发、金融、化学等)的常识、计算机科学常识(Python、深度学习、开源工具)以及基础设施相干常识(培训、部署和扩大模型)。要求比拟多且简单,很少人可能把握所有这些技能。
而 Amazon SageMaker 通过对 Deep Graph Library 的反对解决了这些问题。
Deep Graph Library(DGL)于 2018 年 12 月首次在 Github 上公布,是一个 Python 开源库,可帮忙钻研人员和科学家利用其数据集疾速构建、训练和评估 GNN。

DGL 建设在风行的深度学习框架之上,例如 PyTorch 和 Apache MXNet。如果相熟其中之一,就会发现它应用起来得心应手。无论应用哪种框架,咱们都能够通过这些对初学者敌对的示例轻松入门。此外,GTC 2019 研讨会的提供的幻灯片和代码也能帮忙咱们疾速上手。
一旦实现了玩具示例,就能够开始摸索 DGL 中已实现的各种前沿模型了。例如,咱们能够通过运行以下命令,应用图卷积网络(GCN)和 CORA 数据集来训练文档分类模型:

`$ python3 train.py --dataset cora --gpu 0 --self-loop`

所有模型的代码均可供查看和调整。这些实现曾经过 AWS 团队认真验证,他们验证了性能申明并确保能够重现后果。

DGL 还蕴含图形数据集的汇合,咱们能够轻松下载并用于试验。
当然,咱们也能够在本地装置和运行 DGL,然而为了更不便,AWS 曾经将其增加到了 PyTorch 和 Apache MXNet 的深度学习容器中。这样就能够轻松地在 Amazon SageMaker 上应用 DGL,以在任意规模上训练和部署模型,而不用治理服务器。

在 Amazon SageMaker 上应用 DGL

AWS 在 Github 存储库中为 SageMaker 增加了残缺示例:在其中一个示例中,咱们应用 Tox21 数据集训练了一个用于分子毒性预测的简略 GNN。

咱们尝试解决的问题是:计算出新化合物对 12 种不同靶标(生物细胞内的受体等)的潜在毒性。能够设想,这种剖析在设计新药时至关重要,而且无需进行体外试验就能疾速预测后果,这有助于钻研人员将精力集中在最有心愿的候选药物上。

数据集蕴含 8,000 多种化合物:每种化合物均建模为图形(原子是顶点,原子键是边),并标记 12 次(每个指标一个标记)。咱们将应用 GNN 建设一个多标签的二元分类模型,使咱们可能预测所考查分子的潜在毒性。
在训练脚本中,咱们能够轻松地从 DGL 汇合中下载所需数据集。

from dgl.data.chem import Tox21
dataset = Tox21()

相似的,咱们也能够应用 DGL Model zoo 轻松构建一个 GNN 分类器:

from dgl import model_zoo
model = model_zoo.chem.GCNClassifier(in_feats=args['n_input'],
    gcn_hidden_feats=[args['n_hidden'] for _ in range(args['n_layers'])],
    n_tasks=dataset.n_tasks,
    classifier_hidden_feats=args['n_hidden']).to(args['device'])

其余代码大部分是原始的 PyTorch,如果您相熟此库,则应用起来就应该可能轻车熟路。

要在 Amazon SageMaker 上运行此代码,咱们要做的就是应用 SageMaker 模拟器,传递 DGL 容器的全名并将训练脚本的名称作为超参数。

estimator = sagemaker.estimator.Estimator(container,
    role,
    train_instance_count=1,
    train_instance_type='ml.p3.2xlarge',
    hyperparameters={'entrypoint': 'main.py'},
    sagemaker_session=sess)
code_location = sess.upload_data(CODE_PATH,
bucket=bucket,
key_prefix=custom_code_upload_location)
estimator.fit({'training-code': code_location})
<output removed>
epoch 23/100, batch 48/49, loss 0.4684
epoch 23/100, batch 49/49, loss 0.5389
epoch 23/100, training roc-auc 0.9451
EarlyStopping counter: 10 out of 10
epoch 23/100, validation roc-auc 0.8375, best validation roc-auc 0.8495
Best validation score 0.8495
Test score 0.8273
2019-11-21 14:11:03 Uploading - Uploading generated training model
2019-11-21 14:11:03 Completed - Training job completed
Training seconds: 209
Billable seconds: 209

当初,咱们能够获取 S3 中经过训练的模型,并将其用于预测大量化合物的毒性,而无需进行理论试验。

立刻尝试

大家当初曾经能够 Amazon SageMaker 上应用 DGL。
自行体验的同时,无妨通过 DGL 论坛、Amazon SageMaker 的 AWS 平台或您罕用的 AWS Support 联系方式向咱们发送反馈。

正文完
 0