论文提出NF-ResNet,依据网络的理论信号传递进行剖析,模仿BatchNorm在均值和方差传递上的体现,进而代替BatchNorm。论文试验和剖析非常足,进去的成果也很不错。一些初始化办法的实践成果是对的,但理论应用会有偏差,论文通过实际剖析发现了这一点进行补充,贯彻了实际出真知的情理
起源:晓飞的算法工程笔记 公众号
论文: Characterizing signal propagation to close the performance gap in unnormalized ResNets
- 论文地址:https://arxiv.org/abs/2101.08692
Introduction
BatchNorm是深度学习中外围计算组件,大部分的SOTA图像模型都应用它,次要有以下几个长处:
- 平滑损失曲线,可应用更大的学习率进行学习。
- 依据minibatch计算的统计信息相当于为以后的batch引入噪声,有正则化作用,避免过拟合。
- 在初始阶段,束缚残差分支的权值,保障深度残差网络有很好的信息传递,可训练超深的网络。
然而,只管BatchNorm很好,但还是有以下毛病:
- 性能受batch size影响大,batch size小时体现很差。
- 带来训练和推理时用法不统一的问题。
- 减少内存耗费。
- 实现模型时常见的谬误起源,特地是分布式训练。
- 因为精度问题,难以在不同的硬件上复现训练后果。
目前,很多钻研开始寻找代替BatchNorm的归一化层,但这些代替层要么体现不行,要么会带来新的问题,比方减少推理的计算耗费。而另外一些钻研则尝试去掉归一化层,比方初始化残差分支的权值,使其输入为零,保障训练初期大部分的信息通过skip path进行传递。尽管可能训练很深的网络,但应用简略的初始化办法的网络的准确率较差,而且这样的初始化很难用于更简单的网络中。
因而,论文心愿找出一种无效地训练不含BatchNorm的深度残差网络的办法,而且测试集性能可能媲美以后的SOTA,论文次要奉献如下:
- 提出信号流传图(Signal Propagation Plots, SPPs),可辅助察看初始阶段的推理信号流传状况,确定如何设计无BatchNorm的ResNet来达到相似的信号流传成果。
- 验证发现无BatchNorm的ResNet成果不好的关键在于非线性激活(ReLU)的应用,通过非线性激活的输入的均值总是负数,导致权值的均值随着网络深度的减少而急剧减少。于是提出Scaled Weight Standardization,可能阻止信号均值的增长,大幅晋升性能。
- 对ResNet进行normalization-free革新以及增加Scaled Weight Standardization训练,在ImageNet上与原版的ResNet有相当的性能,层数达到288层。
- 对RegNet进行normalization-free革新,联合EfficientNet的混合缩放,结构了NF-RegNet系列,在不同的计算量上都达到与EfficientNet相当的性能。
Signal Propagation Plots
许多钻研从实践上剖析ResNet的信号流传,却很少会在设计或魔改网络的时候实地验证不同层数的特色缩放状况。实际上,用任意输出进行前向推理,而后记录网络不同地位特色的统计信息,能够很直观地理解信息流传情况并尽快发现暗藏的问题,不必经验漫长的失败训练。于是,论文提出了信号流传图(Signal Propagation Plots,SPPs),输出随机高斯输出或实在训练样本,而后别离统计每个残差block输入的以下信息:
- Average Channel Squared Mean,在NHW维计算均值的平方(均衡正负均值),而后在C维计算平均值,越靠近零是越好的。
- Average Channel Variance,在NHW维计算方差,而后在C维计算平均值,用于掂量信号的幅度,能够看到信号是爆炸抑或是衰减。
- Residual Average Channel Variance,仅计算残差分支输入,用于评估分支是否被正确初始化。
论文对常见的BN-ReLU-Conv构造和不常见的ReLU-BN-Conv构造进行了试验统计,试验的网络为600层ResNet,采纳He初始化,定义residual block为$x_{l+1}=f_{l}(x_{l}) + x_{l}$,从SPPs能够发现了以下景象:
- Average Channel Variance随着网络深度线性增长,而后在transition block处重置为较低值。这是因为在训练初始阶段,residual block的输入的方差为$Var(x_{l+1})=Var(f_{l}(x_{l})) + Var(x_{l})$,一直累积residual branch和skip path的方差。而在transition block处,skip path的输出被BatchNorm解决过,所以block的输入的方差间接被重置了。
- BN-ReLU-Conv的Average Squared Channel Means也是随着网络深度一直减少,尽管BatchNorm的输入是零均值的,但通过ReLU之后就变成了正均值,再与skip path相加就一直地减少直到transition block的呈现,这种景象可称为mean-shift。
- BN-ReLU的Residual Average Channel Variance大概为0.68,ReLU-BN的则大概为1。BN-ReLU的方差变小次要因为ReLU,前面会剖析到,但实践应该是0.34左右,而且这里每个transition block的残差分支输入却为1,有点奇怪,如果晓得的读者麻烦评论或私信一下。
如果间接去掉BatchNorm,Average Squared Channel Means和Average Channel Variance将会一直地减少,这也是深层网络难以训练的起因。所以要去掉BatchNorm,必须设法模仿BatchNorm的信号传递成果。
Normalizer-Free ResNets(NF-ResNets)
依据后面的SPPs,论文设计了新的redsidual block$x_{l+1}=x_l+\alpha f_l(x_l/\beta_l)$,次要模仿BatchNorm在均值和方差上的体现,具体如下:
- $f(\cdot)$为residual branch的计算函数,该函数须要非凡初始化,保障初期具备放弃方差的性能,即$Var(f_l(z))=Var(z)$,这样的束缚可能帮忙更好地解释和剖析网络的信号增长。
- $\beta_l=\sqrt{Var(x_l)}$为固定标量,值为输出特色的标准差,保障$f_l(\cdot)$为单位方差。
- $\alpha$为超参数,用于管制block间的方差增长速度。
依据下面的设计,给定$Var(x_0)=1$和$\beta_l=\sqrt{Var(x_l)}$,可依据$Var(x_l)=Var(x_{l-1})+\alpha^2$间接计算第$l$个residual block的输入的方差。为了模仿ResNet中的累积方差在transition block处被重置,须要将transition block的skip path的输出放大为$x_l/\beta_l$,保障每个stage结尾的transition block输入方差满足$Var(x_{l+1})=1+\alpha^2$。将上述简略缩放策略利用到残差网络并去掉BatchNorm层,就失去了Normalizer-Free ResNets(NF-ResNets)。
ReLU Activations Induce Mean Shifts
论文对应用He初始化的NF-ResNet进行SPPs剖析,后果如图2,发现了两个比拟意外的景象:
- Average Channel Squared Mean随着网络变深一直减少,值大到超过了方差,有mean-shift景象。
- 跟BN-ReLU-Conv相似,残差分支输入的方差始终小于1。
为了验证上述景象,论文将网络的ReLU去掉再进行SPPs剖析。如图7所示,当去掉ReLU后,Average Channel Squared Mean靠近于0,而且残差分支输入的靠近1,这表明是ReLU导致了mean-shift景象。
论文也从实践的角度剖析了这一景象,首先定义转化$z=Wg(x)$,$W$为任意且固定的矩阵,$g(\cdot)$为作用于独立同散布输出$x$上的elememt-wise激活函数,所以$g(x)$也是独立同散布的。假如每个维度$i$都有$\mathbb{E}(g(x_i))=\mu_g$以及$Var(g(x_i))=\sigma^2_g$,则输入$z_i=\sum^N_jW_{i,j}g(x_j)$的均值和方差为:
其中,$\mu w_{i,.}$和$\sigma w_{i,.}$为$W$的$i$行(fan-in)的均值和方差:
当$g(\cdot)$为ReLU激活函数时,则$g(x)\ge 0$,意味着后续的线性层的输出都为正均值。如果$x_i\sim\mathcal{N}(0,1)$,则$\mu_g=1/\sqrt{2\pi}$。因为$\mu_g>0$,如果$\mu w_i$也是非零,则$z_i$同样有非零均值。须要留神的是,即便$W$从均值为零的散布中采样而来,其理论的矩阵均值必定不会为零,所以残差分支的任意维度的输入也不会为零,随着网络深度的减少,越来越难训练。
Scaled Weight Standardization
为了打消mean-shift景象以及保障残差分支$f_l(\cdot)$具备方差不变的个性,论文借鉴了Weight Standardization和Centered Weight Standardization,提出Scaled Weight Standardization(Scaled WS)办法,该办法对卷积层的权值从新进行如下的初始化:
$\mu$和$\sigma$为卷积核的fan-in的均值和方差,权值$W$初始为高斯权值,$\gamma$为固定常量。代入公式1能够得出,对于$z=\hat{W}g(x)$,有$\mathbb{E}(z_i)=0$,去除了mean-shift景象。另外,方差变为$Var(z_i)=\gamma^2\sigma^2_g$,$\gamma$值由应用的激活函数决定,可放弃方差不变。
Scaled WS训练时减少的开销很少,而且与batch数据无关,在推理的时候更是无额定开销的。另外,训练和测试时的计算逻辑保持一致,对分布式训练也很敌对。从图2的SPPs曲线能够看出,退出Scaled WS的NF-ResNet-600的体现跟ReLU-BN-Conv十分相似。
Determining Nonlinerity-Specific Constants
最初的因素是$\gamma$值的确定,保障残差分支输入的方差在初始阶段靠近1。$\gamma$值由网络应用的非线性激活类型决定,假如非线性的输出$x\sim\mathcal{N}(0,1)$,则ReLU输入$g(x)=max(x,0)$相当于从方差为$\sigma^2_g=(1/2)(1-(1/\pi))$的高斯分布采样而来。因为$Var(\hat{W}g(x))=\gamma^2\sigma^2_g$,可设置$\gamma=1/\sigma_g=\frac{\sqrt{2}}{\sqrt{1-\frac{1}{\pi}}}$来保障$Var(\hat{W}g(x))=1$。尽管实在的输出不是完全符合$x\sim \mathcal{N}(0,1)$,在实践中上述的$\gamma$设定仍然有不错的体现。
对于其余简单的非线性激活,如SiLU和Swish,公式推导会波及简单的积分,甚至推出不进去。在这种状况下,可应用数值近似的办法。先从高斯分布中采样多个$N$维向量$x$,计算每个向量的激活输入的理论方差$Var(g(x))$,再取理论方差均值的平方根即可。
Other Building Block and Relaxed Constraints
本文的外围在于放弃正确的信息传递,所以许多常见的网络结构都要进行批改。如同抉择$\gamma$值一样,可通过剖析或实际判断必要的批改。比方SE模块$y=sigmoid(MLP(pool(h)))*h$,输入须要与$[0,1]$的权值进行相乘,导致信息传递削弱,网络变得不稳固。应用下面提到的数值近似进行独自剖析,发现冀望方差为0.5,这意味着输入须要乘以2来复原正确的信息传递。
实际上,有时绝对简略的网络结构批改就能够放弃很好的信息传递,而有时候即使网络结构不批改,网络自身也可能对网络结构导致的信息衰减有很好的鲁棒性。因而,论文也尝试在维持稳固训练的前提下,测试Scaled WS层的束缚的最大放松水平。比方,为Scaled WS层复原一些卷积的表达能力,退出可学习的缩放因子和偏置,别离用于权值相乘和非线性输入相加。当这些可学习参数没有任何束缚时,训练的稳定性没有受到很大的影响,反而对大于150层的网络训练有肯定的帮忙。所以,NF-ResNet间接放松了束缚,退出两个可学习参数。
论文的附录有具体的网络实现细节,有趣味的能够去看看。
Summary
总结一下,Normalizer-Free ResNet的外围有以下几点:
- 计算前向流传的冀望方差$\beta^2_l$,每通过一个残差block稳固减少$\alpha^2$,残差分支的输出须要放大$\beta_l$倍。
- 将transition block中skip path的卷积输出放大$\beta_l$倍,并在transition block后将方差重置为$\beta_{l+1}=1+\alpha^2$。
- 对所有的卷积层应用Scaled Weight Standardization初始化,基于$x\sim\mathcal{N}(0,1)$计算激活函数$g(x)$对应的$\gamma$值,为激活函数输入的冀望标准差的倒数$\frac{1}{\sqrt{Var(g(x))}}$。
Experiments
比照RegNet的Normalizer-Free变种与其余办法的比照,绝对于EfficientNet还是差点,但曾经非常靠近了。
Conclusion
论文提出NF-ResNet,依据网络的理论信号传递进行剖析,模仿BatchNorm在均值和方差传递上的体现,进而代替BatchNorm。论文试验和剖析非常足,进去的成果也很不错。一些初始化办法的实践成果是对的,但理论应用会有偏差,论文通过实际剖析发现了这一点进行补充,贯彻了实际出真知的情理。
如果本文对你有帮忙,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】