关于机器学习:NFResNet去掉BN归一化值得细读的网络信号分析-ICLR-2021

32次阅读

共计 6061 个字符,预计需要花费 16 分钟才能阅读完成。

论文提出 NF-ResNet,依据网络的理论信号传递进行剖析,模仿 BatchNorm 在均值和方差传递上的体现,进而代替 BatchNorm。论文试验和剖析非常足,进去的成果也很不错。一些初始化办法的实践成果是对的,但理论应用会有偏差,论文通过实际剖析发现了这一点进行补充,贯彻了实际出真知的情理

起源:晓飞的算法工程笔记 公众号

论文: Characterizing signal propagation to close the performance gap in unnormalized ResNets

  • 论文地址:https://arxiv.org/abs/2101.08692

Introduction


  BatchNorm 是深度学习中外围计算组件,大部分的 SOTA 图像模型都应用它,次要有以下几个长处:

  • 平滑损失曲线,可应用更大的学习率进行学习。
  • 依据 minibatch 计算的统计信息相当于为以后的 batch 引入噪声,有正则化作用,避免过拟合。
  • 在初始阶段,束缚残差分支的权值,保障深度残差网络有很好的信息传递,可训练超深的网络。

  然而,只管 BatchNorm 很好,但还是有以下毛病:

  • 性能受 batch size 影响大,batch size 小时体现很差。
  • 带来训练和推理时用法不统一的问题。
  • 减少内存耗费。
  • 实现模型时常见的谬误起源,特地是分布式训练。
  • 因为精度问题,难以在不同的硬件上复现训练后果。

  目前,很多钻研开始寻找代替 BatchNorm 的归一化层,但这些代替层要么体现不行,要么会带来新的问题,比方减少推理的计算耗费。而另外一些钻研则尝试去掉归一化层,比方初始化残差分支的权值,使其输入为零,保障训练初期大部分的信息通过 skip path 进行传递。尽管可能训练很深的网络,但应用简略的初始化办法的网络的准确率较差,而且这样的初始化很难用于更简单的网络中。
  因而,论文心愿找出一种无效地训练不含 BatchNorm 的深度残差网络的办法,而且测试集性能可能媲美以后的 SOTA,论文次要奉献如下:

  • 提出信号流传图(Signal Propagation Plots, SPPs),可辅助察看初始阶段的推理信号流传状况,确定如何设计无 BatchNorm 的 ResNet 来达到相似的信号流传成果。
  • 验证发现无 BatchNorm 的 ResNet 成果不好的关键在于非线性激活 (ReLU) 的应用,通过非线性激活的输入的均值总是负数,导致权值的均值随着网络深度的减少而急剧减少。于是提出 Scaled Weight Standardization,可能阻止信号均值的增长,大幅晋升性能。
  • 对 ResNet 进行 normalization-free 革新以及增加 Scaled Weight Standardization 训练,在 ImageNet 上与原版的 ResNet 有相当的性能,层数达到 288 层。
  • 对 RegNet 进行 normalization-free 革新,联合 EfficientNet 的混合缩放,结构了 NF-RegNet 系列,在不同的计算量上都达到与 EfficientNet 相当的性能。

Signal Propagation Plots


  许多钻研从实践上剖析 ResNet 的信号流传,却很少会在设计或魔改网络的时候实地验证不同层数的特色缩放状况。实际上,用任意输出进行前向推理,而后记录网络不同地位特色的统计信息,能够很直观地理解信息流传情况并尽快发现暗藏的问题,不必经验漫长的失败训练。于是,论文提出了信号流传图(Signal Propagation Plots,SPPs),输出随机高斯输出或实在训练样本,而后别离统计每个残差 block 输入的以下信息:

  • Average Channel Squared Mean,在 NHW 维计算均值的平方(均衡正负均值),而后在 C 维计算平均值,越靠近零是越好的。
  • Average Channel Variance,在 NHW 维计算方差,而后在 C 维计算平均值,用于掂量信号的幅度,能够看到信号是爆炸抑或是衰减。
  • Residual Average Channel Variance,仅计算残差分支输入,用于评估分支是否被正确初始化。

  论文对常见的 BN-ReLU-Conv 构造和不常见的 ReLU-BN-Conv 构造进行了试验统计,试验的网络为 600 层 ResNet,采纳 He 初始化,定义 residual block 为 $x_{l+1}=f_{l}(x_{l}) + x_{l}$,从 SPPs 能够发现了以下景象:

  • Average Channel Variance 随着网络深度线性增长,而后在 transition block 处重置为较低值。这是因为在训练初始阶段,residual block 的输入的方差为 $Var(x_{l+1})=Var(f_{l}(x_{l})) + Var(x_{l})$,一直累积 residual branch 和 skip path 的方差。而在 transition block 处,skip path 的输出被 BatchNorm 解决过,所以 block 的输入的方差间接被重置了。

  • BN-ReLU-Conv 的 Average Squared Channel Means 也是随着网络深度一直减少,尽管 BatchNorm 的输入是零均值的,但通过 ReLU 之后就变成了正均值,再与 skip path 相加就一直地减少直到 transition block 的呈现,这种景象可称为 mean-shift。

  • BN-ReLU 的 Residual Average Channel Variance 大概为 0.68,ReLU-BN 的则大概为 1。BN-ReLU 的方差变小次要因为 ReLU,前面会剖析到,但实践应该是 0.34 左右,而且这里每个 transition block 的残差分支输入却为 1,有点奇怪,如果晓得的读者麻烦评论或私信一下。

  如果间接去掉 BatchNorm,Average Squared Channel Means 和 Average Channel Variance 将会一直地减少,这也是深层网络难以训练的起因。所以要去掉 BatchNorm,必须设法模仿 BatchNorm 的信号传递成果。

Normalizer-Free ResNets(NF-ResNets)


  依据后面的 SPPs,论文设计了新的 redsidual block$x_{l+1}=x_l+\alpha f_l(x_l/\beta_l)$,次要模仿 BatchNorm 在均值和方差上的体现,具体如下:

  • $f(\cdot)$ 为 residual branch 的计算函数,该函数须要非凡初始化,保障初期具备放弃方差的性能,即 $Var(f_l(z))=Var(z)$,这样的束缚可能帮忙更好地解释和剖析网络的信号增长。
  • $\beta_l=\sqrt{Var(x_l)}$ 为固定标量,值为输出特色的标准差,保障 $f_l(\cdot)$ 为单位方差。
  • $\alpha$ 为超参数,用于管制 block 间的方差增长速度。

  依据下面的设计,给定 $Var(x_0)=1$ 和 $\beta_l=\sqrt{Var(x_l)}$,可依据 $Var(x_l)=Var(x_{l-1})+\alpha^2$ 间接计算第 $l$ 个 residual block 的输入的方差。为了模仿 ResNet 中的累积方差在 transition block 处被重置,须要将 transition block 的 skip path 的输出放大为 $x_l/\beta_l$,保障每个 stage 结尾的 transition block 输入方差满足 $Var(x_{l+1})=1+\alpha^2$。将上述简略缩放策略利用到残差网络并去掉 BatchNorm 层,就失去了 Normalizer-Free ResNets(NF-ResNets)。

ReLU Activations Induce Mean Shifts

  论文对应用 He 初始化的 NF-ResNet 进行 SPPs 剖析,后果如图 2,发现了两个比拟意外的景象:

  • Average Channel Squared Mean 随着网络变深一直减少,值大到超过了方差,有 mean-shift 景象。
  • 跟 BN-ReLU-Conv 相似,残差分支输入的方差始终小于 1。

  为了验证上述景象,论文将网络的 ReLU 去掉再进行 SPPs 剖析。如图 7 所示,当去掉 ReLU 后,Average Channel Squared Mean 靠近于 0,而且残差分支输入的靠近 1,这表明是 ReLU 导致了 mean-shift 景象。
  论文也从实践的角度剖析了这一景象,首先定义转化 $z=Wg(x)$,$W$ 为任意且固定的矩阵,$g(\cdot)$ 为作用于独立同散布输出 $x$ 上的 elememt-wise 激活函数,所以 $g(x)$ 也是独立同散布的。假如每个维度 $i$ 都有 $\mathbb{E}(g(x_i))=\mu_g$ 以及 $Var(g(x_i))=\sigma^2_g$,则输入 $z_i=\sum^N_jW_{i,j}g(x_j)$ 的均值和方差为:

  其中,$\mu w_{i,.}$ 和 $\sigma w_{i,.}$ 为 $W$ 的 $i$ 行 (fan-in) 的均值和方差:

  当 $g(\cdot)$ 为 ReLU 激活函数时,则 $g(x)\ge 0$,意味着后续的线性层的输出都为正均值。如果 $x_i\sim\mathcal{N}(0,1)$,则 $\mu_g=1/\sqrt{2\pi}$。因为 $\mu_g>0$,如果 $\mu w_i$ 也是非零,则 $z_i$ 同样有非零均值。须要留神的是,即便 $W$ 从均值为零的散布中采样而来,其理论的矩阵均值必定不会为零,所以残差分支的任意维度的输入也不会为零,随着网络深度的减少,越来越难训练。

Scaled Weight Standardization

  为了打消 mean-shift 景象以及保障残差分支 $f_l(\cdot)$ 具备方差不变的个性,论文借鉴了 Weight Standardization 和 Centered Weight Standardization,提出 Scaled Weight Standardization(Scaled WS)办法,该办法对卷积层的权值从新进行如下的初始化:

  $\mu$ 和 $\sigma$ 为卷积核的 fan-in 的均值和方差,权值 $W$ 初始为高斯权值,$\gamma$ 为固定常量。代入公式 1 能够得出,对于 $z=\hat{W}g(x)$,有 $\mathbb{E}(z_i)=0$,去除了 mean-shift 景象。另外,方差变为 $Var(z_i)=\gamma^2\sigma^2_g$,$\gamma$ 值由应用的激活函数决定,可放弃方差不变。
  Scaled WS 训练时减少的开销很少,而且与 batch 数据无关,在推理的时候更是无额定开销的。另外,训练和测试时的计算逻辑保持一致,对分布式训练也很敌对。从图 2 的 SPPs 曲线能够看出,退出 Scaled WS 的 NF-ResNet-600 的体现跟 ReLU-BN-Conv 十分相似。

Determining Nonlinerity-Specific Constants

  最初的因素是 $\gamma$ 值的确定,保障残差分支输入的方差在初始阶段靠近 1。$\gamma$ 值由网络应用的非线性激活类型决定,假如非线性的输出 $x\sim\mathcal{N}(0,1)$,则 ReLU 输入 $g(x)=max(x,0)$ 相当于从方差为 $\sigma^2_g=(1/2)(1-(1/\pi))$ 的高斯分布采样而来。因为 $Var(\hat{W}g(x))=\gamma^2\sigma^2_g$,可设置 $\gamma=1/\sigma_g=\frac{\sqrt{2}}{\sqrt{1-\frac{1}{\pi}}}$ 来保障 $Var(\hat{W}g(x))=1$。尽管实在的输出不是完全符合 $x\sim \mathcal{N}(0,1)$,在实践中上述的 $\gamma$ 设定仍然有不错的体现。
  对于其余简单的非线性激活,如 SiLU 和 Swish,公式推导会波及简单的积分,甚至推出不进去。在这种状况下,可应用数值近似的办法。先从高斯分布中采样多个 $N$ 维向量 $x$,计算每个向量的激活输入的理论方差 $Var(g(x))$,再取理论方差均值的平方根即可。

Other Building Block and Relaxed Constraints

  本文的外围在于放弃正确的信息传递,所以许多常见的网络结构都要进行批改。如同抉择 $\gamma$ 值一样,可通过剖析或实际判断必要的批改。比方 SE 模块 $y=sigmoid(MLP(pool(h)))*h$,输入须要与 $[0,1]$ 的权值进行相乘,导致信息传递削弱,网络变得不稳固。应用下面提到的数值近似进行独自剖析,发现冀望方差为 0.5,这意味着输入须要乘以 2 来复原正确的信息传递。
  实际上,有时绝对简略的网络结构批改就能够放弃很好的信息传递,而有时候即使网络结构不批改,网络自身也可能对网络结构导致的信息衰减有很好的鲁棒性。因而,论文也尝试在维持稳固训练的前提下,测试 Scaled WS 层的束缚的最大放松水平。比方,为 Scaled WS 层复原一些卷积的表达能力,退出可学习的缩放因子和偏置,别离用于权值相乘和非线性输入相加。当这些可学习参数没有任何束缚时,训练的稳定性没有受到很大的影响,反而对大于 150 层的网络训练有肯定的帮忙。所以,NF-ResNet 间接放松了束缚,退出两个可学习参数。
  论文的附录有具体的网络实现细节,有趣味的能够去看看。

Summary

  总结一下,Normalizer-Free ResNet 的外围有以下几点:

  • 计算前向流传的冀望方差 $\beta^2_l$,每通过一个残差 block 稳固减少 $\alpha^2$,残差分支的输出须要放大 $\beta_l$ 倍。
  • 将 transition block 中 skip path 的卷积输出放大 $\beta_l$ 倍,并在 transition block 后将方差重置为 $\beta_{l+1}=1+\alpha^2$。
  • 对所有的卷积层应用 Scaled Weight Standardization 初始化,基于 $x\sim\mathcal{N}(0,1)$ 计算激活函数 $g(x)$ 对应的 $\gamma$ 值,为激活函数输入的冀望标准差的倒数 $\frac{1}{\sqrt{Var(g(x))}}$。

Experiments


  比照 RegNet 的 Normalizer-Free 变种与其余办法的比照,绝对于 EfficientNet 还是差点,但曾经非常靠近了。

Conclusion


  论文提出 NF-ResNet,依据网络的理论信号传递进行剖析,模仿 BatchNorm 在均值和方差传递上的体现,进而代替 BatchNorm。论文试验和剖析非常足,进去的成果也很不错。一些初始化办法的实践成果是对的,但理论应用会有偏差,论文通过实际剖析发现了这一点进行补充,贯彻了实际出真知的情理。



如果本文对你有帮忙,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】

正文完
 0