共计 3874 个字符,预计需要花费 10 分钟才能阅读完成。
作者 |Peter Yu
编译 |Flin
起源 |towardsdatascience
最近,我始终在寻找办法来放慢我的钻研和治理我的试验,特地是围绕着写训练管道和治理试验配置文件这两个方面,我发现这两个新我的项目叫做 PyTorch Lightning 和 Hydra。PyTorch Lightning 能够帮忙你疾速编写训练管道,而 Hydra 能够帮忙你无效地治理配置文件。
- PyTorch Lightning:https://github.com/PyTorchLig…
- Hydra:https://hydra.cc/
为了练习应用它们,我决定为 Leela Zero(https://github.com/leela-zero… 编写一个训练管道。我这样做,是因为这是一个范畴很广的我的项目,波及到应用多个 gpu 在大数据集上训练大型网络,能够说是一个非常乏味的技术挑战。此外,我以前已经实现过一个更小版本的 AlphaGo 国际象棋(https://medium.com/@peterkeun…),所以我认为这将是一个乏味的业余我的项目。
在这个博客中,我将解释这个我的项目的次要细节,以便你可能轻松了解我所做的工作。你能够在这里浏览我的代码:https://github.com/yukw777/le…
Leela Zero
第一步是找出 Leela Zero 神经网络的外部工作原理。我大量援用了 Leela Zero 的文档和它的 Tensorflow 训练管道。
神经网络构造
Leela Zero 的神经网络由一个残差塔(ResNet“tower”)组成,塔上有两个“head”,即 AlphaGo Zero 论文(https://deepmind.com/blog/art…)中形容的负责策略的“头”(policy head)和负责计算价值的“头”(value head)。就像论文所述,策略“头”和值“头”开始的那几个卷积滤波器都是 1 ×1,其余所有的卷积滤波器都是 3 ×3。游戏和棋盘特色被编码为 [批次大小,棋盘宽度,棋盘高度,特色数量] 形态的张量,首先通过残差塔输出。而后,塔提取出形象的特色,并通过每个“头”输出这些特色,以计算下一步棋的策略概率分布和游戏的价值,从而预测游戏的获胜者。
你能够在上面的代码片段中找到网络的实现细节。
权重格局
Leela Zero 应用一个简略的文本文件来保留和加载网络权重。文本文件中的每一行都有一系列数字,这些数字示意网络的每一层的权重。首先是残差塔,而后是策略头,而后是值头。
卷积层有 2 个权重行:
- 与 [output, input, filter size, filter size] 形态的卷积权值
- 通道的偏差
Batchnorm 层有 2 个权重行:
- Batchnorm 平均值
- Batchnorm 方差
内积 (齐全连贯) 层有 2 个权重行:
- 带有 [output, input] 形态的层权重
- 输入偏差
我编写了单元测试来确保我的权重文件是正确的。我应用的另一个简略的完整性检查是计算层的数量,在加载我的权值文件后,将其与 Leela Zero 进行比拟。层数公式为:
n_layers = 1 (version number) +
2 (input convolution) +
2 (input batch norm) +
n_res (number of residual blocks) *
8 (first conv + first batch norm +
second conv + second batch norm) +
2 (policy head convolution) +
2 (policy head batch norm) +
2 (policy head linear) +
2 (value head convolution) +
2 (value head batch norm) +
2 (value head first linear) +
2 (value head second linear)
到目前为止,这看起来很简略,然而你须要留神一个实现细节。Leela Zero 实际上应用卷积层的偏差来示意下一个归一化层(batch norm)的可学习参数 (gamma
和beta
)。这样做是为了使权值文件的格局 (只有一行表示层权值,另一行示意偏差) 在增加归一化层时不用更改。
目前,Leela Zero 只应用归一化层的 beta
项,将 gamma
设置为 1。那么,实际上咱们该如何应用卷积偏差,来产生与在归一化层中利用可学习参数雷同的后果呢? 咱们先来看看归一化层的方程:
y = gamma * (x — mean)/sqrt(var — eps) + beta
因为 Leela Zero 将 gamma
设为 1,则方程为:
y = (x — mean)/sqrt(var — eps) + beta
当初,设定 x_conv
是没有偏差的卷积层的输入。而后,咱们想给 x_conv
增加一些偏差,这样当你在没有 beta 的归一化层中运行它时,后果与在只有 beta
的归一化层方程中运行 x_conv
是一样的:
(x_conv + bias — mean)/sqrt(var — eps) =
(x_conv — mean)/sqrt(var — eps) + beta
x_conv + bias — mean =
x_conv — mean + beta * sqrt(var — eps)
bias = beta * sqrt(var — eps)
因而,如果咱们在权值文件中将卷积偏差设置为beta * sqrt(var - eps)
,咱们就会失去冀望的输入,这就是 LeelaZero 所做的。
那么,咱们如何实现它呢? 在 Tensorflow 中,你能够通过调用 tf.layers.batch_normalization(scale=False)
来通知归一化层要疏忽 gamma
项,而后应用它。
遗憾的是,在 PyTorch 中,你不能将归一化层设置为只疏忽 gamma
,你只能通过将仿射参数设置为False: BatchNorm2d(out_channels, affine=False)
,来疏忽gamma
和beta
。所以,我把归一化层设为两个都疏忽,而后简略地在前面加上一个张量,它示意 beta
。而后,应用公式bias = beta * sqrt(var - eps)
来计算权值文件的卷积偏差。
训练管道
在弄清了 Leela Zeros 的神经网络的细节之后,就到了解决训练管道的时候了。正如我提到的,我想练习应用两个工具:PyTorch Lightning 和 Hydra,来放慢编写训练管道和无效治理试验配置。让咱们来具体理解一下我是如何应用它们的。
PyTorch Lightning
编写训练管道是我钻研中最不喜爱的局部: 它波及大量反复的样板代码,而且很难调试。正因为如此,PyTorch Lightning 对我来说就像一股清流,它是一个轻量级的库,PyTorch 没有很多辅助形象,在编写训练管道时,它负责解决大部分样板代码。它容许你关注你的训练管道中更乏味的局部,比方模型架构,并使你的钻研代码更加模块化和可调试。此外,它还反对多 gpu 和 TPU 的开箱即用训练!
为了应用 PyTorch Lightning 作为我的训练管道,我须要做的最多的编码就是编写一个类,我称之为 NetworkLightningModule
,它继承自LightningModule
来指定训练管道的细节,并将其传递给训练器。无关如何编写本人的 LightningModule
的详细信息,能够参考 PyTorch Lightning
的官网文档。
Hydra
我始终在钻研的另一部分是试验治理。当你进行钻研的时候,你不可避免地要运行大量不同的试验来测试你的假如,所以,以一种可扩大的形式跟踪它们是十分重要的。到目前为止,我始终依赖于配置文件来治理我的试验版本,然而应用立体配置文件很快就变得难以治理。应用模板是这个问题的一个解决方案。然而,我发现模板最终也会变得凌乱,因为当你笼罩多个层的值文件来出现你的配置文件时,很难跟踪哪个值来自哪个值文件。
另一方面,Hydra 是一个基于组件的配置管理系统。与应用独自的模板和值文件来出现最终配置不同,你能够组合多个较小的配置文件来组成最终配置。它不如基于模板的配置管理系统灵便,但我发现基于组件的零碎在灵活性和可维护性之间获得了很好的均衡。Hydra 就是这样一个专门为钻研脚本量身定做的零碎。它的调用有点蠢笨,因为它要求你将它用作脚本的次要入口点,但实际上我认为有了这种设计,它很容易与你的训练脚本集成。此外,它容许你通过命令行手动笼罩配置,这在运行你的试验的不同版本时十分有用。我经常应用 Hydra 治理不同规模的网络架构和训练管道配置。
评估
为了评估我的训练网络,我应用 GoMill(https://github.com/mattheww/g…)来举办围棋比赛。它是一个运行在 Go Text Protocol (GTP)引擎上的较量的库,Leela Zero 就是其中之一。你能够在这里(https://github.com/yukw777/le…)找到我应用的较量配置。
论断
通过应用 PyTorch-Lightning 和 Hydra,可能极大地放慢编写训练管道的速度,并无效地治理试验配置文件。我心愿这个我的项目和博客文章也能对你的钻研有所帮忙。你能够在这里查看代码:https://github.com/yukw777/le…
原文链接:https://towardsdatascience.co…
欢送关注磐创 AI 博客站:
http://panchuang.net/
sklearn 机器学习中文官网文档:
http://sklearn123.com/
欢送关注磐创博客资源汇总站:
http://docs.panchuang.net/