在过来的大半年里,以Stable Diffusion为代表的AI绘画是世界上最为炽热的AI方向之一。或者大家会有疑难,Stable Diffusion里的这个"Diffusion"是什么意思?其实,扩散模型(Diffusion Model)正是Stable Diffusion中负责生成图像的模型。想要了解Stable Diffusion的原理,就肯定绕不过扩散模型的学习。

在这篇文章里,我会由浅入深地对最根底的去噪扩散概率模型(Denoising Diffusion Probabilistic Models, DDPM)进行解说。我会先介绍扩散模型生成图像的基本原理,再用简略的数学语言对扩散模型建模,最初给出扩散模型的一份PyTorch实现。本文不会去堆砌过于简单的数学公式,哪怕你没有相干的数学背景,也可能轻松了解扩散模型的原理。

扩散模型与图像生成

在意识扩散模型之前,咱们先退一步,看看个别的神经网络模型是怎么生成图像的。显然,为了生成丰盛的图像,一个图像生成程序要依据随机数来生成图像。通常,这种随机数是一个满足规范正态分布的随机向量。这样,每次要生成新图像时,只须要从规范正态分布里随机生成一个向量并输出给程序就行了。

而在AI绘画程序中,负责生成图像的是一个神经网络模型。神经网络须要从数据中学习。对于图像生成工作,神经网络的训练数据个别是一些同类型的图片。比方一个绘制人脸的神经网络会用人脸照片来训练。也就是说,神经网络会学习如何把一个向量映射成一张图片,并确保这个图片和训练集的图片是一类图片。

可是,相比其余AI工作,图像生成工作对神经网络来说更加艰难一点——图像生成工作不足无效的领导。在其余AI工作中,训练集自身会给出一个「标准答案」,领导AI的输入向标准答案聚拢。比方对于图像分类工作,训练集会给出每一幅图像的类别;对于人脸验证工作,训练集会给出两张人脸照片是不是同一个人;对于指标检测工作,训练集会给出指标的具体位置。然而,图像生成工作是没有标准答案的。图像生成数据集里只有一些同类型图片,却没有领导AI如何画得更好的信息。

为了解决这一问题,人们专门设计了一些用于生成图像的神经网络架构。这些架构中比拟闻名的有生成反抗模型(GAN)和变分自编码器(VAE)。

GAN的想法是,既然不晓得一幅图片好不好,就罗唆再训练一个神经网络,用于分别某图片是不是和训练集里的图片长得一样。生成图像的神经网络叫做生成器,鉴定图像的神经网络叫做判断器。两个网络相互反抗,共同进步。

VAE则应用了逆向思维:用向量生成图像很艰难,那就同时学习怎么用图像生成向量。这样,把某图像变成向量,再用该向量生成图像,就应该失去一幅和原图像截然不同的图像。每一个向量的绘画后果有了一个标准答案,能够用个别的优化办法来领导网络的训练了。VAE中,把图像变成向量的网络叫做编码器,把向量转换回图像的网络叫做解码器。其中,解码器就是负责生成图像的模型。

始终以来,GAN的生成成果较好,但训练起来比VAE麻烦很多。有没有和GAN一样弱小,训练起来又不便的生成网络架构呢?扩散模型正是满足这些要求的生成网络架构。

扩散模型是一种非凡的VAE,其灵感来自于热力学:一个散布能够通过一直地增加噪声变成另一个散布。放到图像生成工作里,就是来自训练集的图像能够通过一直增加噪声变成符合标准正态分布的图像。从这个角度登程,咱们能够对VAE做以下批改:1)不再训练一个可学习的编码器,而是把编码过程固定成一直增加噪声的过程;2)不再把图像压缩成更短的向量,而是从头至尾都对一个等大的图像做操作。解码器仍然是一个可学习的神经网络,它的目标也同样是实现编码的逆操作。不过,既然当初编码过程变成了加噪,那么解码器就应该负责去噪。而对于神经网络来说,去噪工作学习起来会更加无效。因而,扩散模型既不会波及GAN中简单的反抗训练,又比VAE更弱小一点。

具体来说,扩散模型由正向过程反向过程这两局部组成,对应VAE中的编码和解码。在正向过程中,输出\(\mathbf{x}_0\)会一直混入高斯噪声。通过\(T\)次加噪声操作后,图像\(\mathbf{x}_T\)会变成一幅符合标准正态分布的纯噪声图像。而在反向过程中,咱们心愿训练出一个神经网络,该网络可能学会\(T\)个去噪声操作,把\(\mathbf{x}_T\)还原回\(\mathbf{x}_0\)。网络的学习指标是让\(T\)个去噪声操作正好能对消掉对应的加噪声操作。训练结束后,只须要从规范正态分布里随机采样出一个噪声,再利用反向过程里的神经网络把该噪声复原成一幅图像,就可能生成一幅图片了。

高斯噪声,就是一幅各处色彩值都满足高斯分布(正态分布)的噪声图像。

总结一下,图像生成网络会学习如何把一个向量映射成一幅图像。设计网络架构时,最重要的是设计学习指标,让网络生成的图像和给定数据集里的图像类似。VAE的做法是应用两个网络,一个学习把图像编码成向量,另一个学习把向量解码回图像,它们的指标是让还原图像和原图像尽可能类似。学习结束后,解码器就是图像生成网络。扩散模型是一种更具体的VAE。它把编码过程固定为加噪声,并让解码器学习怎么样打消之前增加的每一步噪声。

扩散模型的具体算法

上一节中,咱们只是大略理解扩散模型的整体思维。这一节,咱们来引入一些数学示意,来看一看扩散模型的训练算法和采样算法具体是什么。为了便于了解,这一节会呈现一些不是那么谨严的数学形容。更加具体的一些数学推导会放到下一节里介绍。

前向过程

在前向过程中,来自训练集的图像\(\mathbf{x}_0\)会被增加\(T\)次噪声,使得\(x_T\)为符合标准正态分布。精确来说,「加噪声」并不是给上一时刻的图像加上噪声值,而是从一个均值与上一时刻图像相干的正态分布里采样出一幅新图像。如上面的公式所示,\(\mathbf{x}_{t - 1}\)是上一时刻的图像,\(\mathbf{x}_{t}\)是这一时刻生成的图像,该图像是从一个均值与\(\mathbf{x}_{t - 1}\)无关的正态分布里采样进去的。

$$\mathbf{x}_t \sim \mathcal{N}(\mu_t(\mathbf{x}_{t - 1}),\sigma_t^2\mathbf{I})$$

少数文章会说前向过程是一个马尔可夫过程。其实,马尔可夫过程的意思就是以后时刻的状态只由上一时刻的状态决定,而不禁更早的状态决定。下面的公式表明,计算\(\mathbf{x}_t\),只须要用到\(\mathbf{x}_{t - 1}\),而不须要用到\(\mathbf{x}_{t - 2}, \mathbf{x}_{t - 3}...\),这合乎马尔可夫过程的定义。

绝大多数扩散模型会把这个正态分布设置成这个模式:

$$\mathbf{x}_t \sim \mathcal{N}(\sqrt{1 - \beta_t}\mathbf{x}_{t - 1},\beta_t\mathbf{I})$$

这个正态分布公式乍看起来很奇怪:\(\sqrt{1 - \beta_t}\)是哪里冒出来的?为什么会有这种奇怪的系数?别急,咱们先来看另一个问题:如果给定\(\mathbf{x}_{0}\),也就是从训练集里采样出一幅图片,该怎么计算任意一个时刻\(t\)的噪声图像\(\mathbf{x}_{t}\)呢?

咱们无妨依照公式,从\(\mathbf{x}_{t}\)开始倒推。\(\mathbf{x}_{t}\)其实能够通过一个规范正态分布的样本\(\epsilon_{t-1}\)算进去:

$$\mathbf{x}_t \sim \mathcal{N}(\sqrt{1 - \beta_t}\mathbf{x}_{t - 1},\beta_t\mathbf{I}) \\\Rightarrow \mathbf{x}_t = \sqrt{1 - \beta_t}\mathbf{x}_{t - 1} + \sqrt{\beta_t}\epsilon_{t-1}; \epsilon_{t-1} \sim \mathcal{N}(0, \mathbf{I})$$

再往前推几步:

$$\begin{aligned}\mathbf{x}_t &= \sqrt{1 - \beta_t}\mathbf{x}_{t - 1} + \sqrt{\beta_t}\epsilon_{t-1}\\ &= \sqrt{1 - \beta_t}(\sqrt{1 - \beta_{t-1}}\mathbf{x}_{t - 2} + \sqrt{\beta_{t-1}}\epsilon_{t-2}) + \sqrt{\beta_t}\epsilon_{t-1} \\&= \sqrt{(1 - \beta_t)(1 - \beta_{t-1})}\mathbf{x}_{t - 2} + \sqrt{(1 - \beta_t)\beta_{t-1}}\epsilon_{t-2} + \sqrt{\beta_t}\epsilon_{t-1}\end{aligned}$$

由正态分布的性质可知,均值雷同的正态分布「加」在一起后,方差也会加到一起。也就是\(\mathcal{N}(0, \sigma_1^2 I)\)与\(\mathcal{N}(0, \sigma_2^2 I)\)合起来会失去\(\mathcal{N}(0, (\sigma_1^2+\sigma_2^2) I)\)。依据这一性质,下面的公式能够化简为:

$$\begin{aligned}&\sqrt{(1 - \beta_t)(1 - \beta_{t-1})}\mathbf{x}_{t - 2} + \sqrt{(1 - \beta_t)\beta_{t-1}}\epsilon_{t-2} + \sqrt{\beta_t}\epsilon_{t-1} \\= & \sqrt{(1 - \beta_t)(1 - \beta_{t-1})}\mathbf{x}_{t - 2} + \sqrt{(1 - \beta_t)\beta_{t-1} + \beta_t}\epsilon \\= & \sqrt{(1 - \beta_t)(1 - \beta_{t-1})}\mathbf{x}_{t - 2} + \sqrt{1-(1-\beta_t)(1-\beta_{t-1})}\epsilon\end{aligned}$$

再往前推一步的话,后果是:

$$\sqrt{(1 - \beta_t)(1 - \beta_{t-1})(1 - \beta_{t-2})}\mathbf{x}_{t - 3} + \sqrt{1-(1-\beta_t)(1-\beta_{t-1})(1 - \beta_{t-2})}\epsilon$$

咱们曾经可能猜出法则来了,能够始终把公式推到\(\mathbf{x}_{0}\)。令\(\alpha_t=1-\beta_t, \bar{\alpha}_t=\prod_{i=1}^t\alpha_i\),则:

$$\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_t}\epsilon$$

有了这个公式,咱们就能够探讨加噪声公式为什么是\(\mathbf{x}_t \sim \mathcal{N}(\sqrt{1 - \beta_t}\mathbf{x}_{t - 1},\beta_t\mathbf{I})\)了。这个公式里的\(\beta_t\)是一个小于1的常数。在DDPM论文中,\(\beta_t\)从\(\beta_1=10^{-4}\)到\(\beta_T=0.02\)线性增长。这样,\(\beta_t\)变大,\(\alpha_t\)也越小,\(\bar{\alpha}_t\)趋于0的速度越来越快。最初,\(\bar{\alpha}_T\)简直为0,代入\(\mathbf{x}_T = \sqrt{\bar{\alpha}_T}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_T}\epsilon\), \(\mathbf{x}_T\)就满足规范正态分布了,合乎咱们对扩散模型的要求。上述推断能够简略形容为:加噪声公式可能从慢到快地扭转原图像,让图像最终均值为0,方差为\(\mathbf{I}\)。

反向过程

在正向过程中,咱们人为设置了\(T\)步加噪声过程。而在反向过程中,咱们心愿可能倒过去勾销每一步加噪声操作,让一幅纯噪声图像变回数据集里的图像。这样,利用这个去噪声过程,咱们就能够把任意一个从规范正态分布里采样进去的噪声图像变成一幅和训练数据长得差不多的图像,从而起到图像生成的目标。

当初问题来了:去噪声操作的数学模式是怎么样的?怎么让神经网络来学习它呢?数学原理表明,当\(\beta_t\)足够小时,每一步加噪声的逆操作也满足正态分布。

$$\mathbf{x}_{t-1} \sim \mathcal{N}(\tilde{\mu}_t, \tilde{\beta}_t\mathbf{I})$$

因而,为了形容所有去噪声操作,神经网络应该依据以后的时刻\(t\)、以后的图像\(\mathbf{x}_{t}\),拟合以后时刻的加噪声逆操作的正态分布,也就是拟合以后的均值\(\tilde{\mu}_t\)和方差\(\tilde{\beta}_t\)。

不要被上文的「去噪声」、「加噪声逆操作」绕晕了哦。因为加噪声是固定的,加噪声的逆操作也是固定的。现实状况下,咱们心愿去噪操作就等于加噪声逆操作。然而,加噪声的逆操作不太可能从实践上求得,咱们只能用一个神经网络去拟合它。去噪声操作和加噪声逆操作的关系,就是神经网络的预测值和真值的关系。

当初问题来了:加噪声逆操作的均值和方差是什么?

间接计算所有数据的加噪声逆操作的散布是不太事实的。然而,如果给定了某个训练集输出\(\mathbf{x}_0\),多了一个限定条件后,该散布是能够用贝叶斯公式计算的(其中\(q\)示意概率分布):

$$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0) = q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)\frac{q(\mathbf{x}_{t-1} | \mathbf{x}_0)}{q(\mathbf{x}_{t} | \mathbf{x}_0)}$$

等式右边的\(q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t-1};\tilde{\mu}_t, \tilde{\beta}_t\mathbf{I})\)示意加噪声操作的逆操作,它的均值和方差都是待求的。左边的\(q(\mathbf{x}_{t} | \mathbf{x}_{t-1}, \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t};\sqrt{1 - \beta_t}\mathbf{x}_{t - 1},\beta_t\mathbf{I})\)是加噪声的散布。而因为\(\mathbf{x}_0\)已知,\(q(\mathbf{x}_{t-1} | \mathbf{x}_0)\)和\(q(\mathbf{x}_{t} | \mathbf{x}_0)\)两项能够依据后面的公式\(\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_t}\epsilon_t\)得来:

$$\begin{aligned}q(\mathbf{x}_{t} | \mathbf{x}_0)&=\mathcal{N}(\mathbf{x}_{t}; \sqrt{\bar{\alpha}_t}\mathbf{x}_{0}, (1-\bar{\alpha}_t)\mathbf{I}) \\q(\mathbf{x}_{t-1} | \mathbf{x}_0)&=\mathcal{N}(\mathbf{x}_{t-1}; \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_{0}, (1-\bar{\alpha}_{t-1})\mathbf{I})\end{aligned}$$

这样,等式左边的式子全副已知。咱们能够把公式套入,算出给定\(\mathbf{x}_0\)时的去噪声散布。经计算化简,散布的均值为:

$$\tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}}(\mathbf{x}_t -\frac{1 - \alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_t)$$

其中,\(\epsilon_t\)是用公式算\(\mathbf{x}_t\)时从规范正态分布采样出的样本,它来自公式

$$\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_t}\epsilon_t$$

散布的方差为:

$$\tilde{\beta}_t=\frac{1-\bar{\alpha}_{t-1}}{1 - \bar{\alpha}_{t}} \cdot \beta_t$$

留神,\(\beta_t\)是加噪声的方差,是一个常量。那么,加噪声逆操作的方差\(\tilde{\beta}_t\)也是一个常量,不与输出\(\mathbf{x}_0\)相干。这下就省事了,训练去噪网络时,神经网络只用拟合\(T\)个均值就行,不必再拟合方差了。

晓得了均值和方差的真值,训练神经网络只差最初的问题了:该怎么设置训练的损失函数?加噪声逆操作和去噪声操作都是正态分布,网络的训练指标应该是让每对正态分布更加靠近。那怎么用损失函数形容两个散布尽可能靠近呢?最直观的想法,必定是让两个正态分布的均值尽可能靠近,方差尽可能靠近。依据上文的剖析,方差是不必管制的,只用让均值尽可能靠近就能够了。

那怎么用数学公式表白让均值更靠近呢?再察看一下指标均值的公式:

$$\tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}}(\mathbf{x}_t -\frac{1 - \alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_t)$$

神经网络拟合均值时,\(\mathbf{x}_{t}\)是已知的(别忘了,图像是一步一步倒着去噪的)。式子里惟一不确定的只有\(\epsilon_t\)。既然如此,神经网络罗唆也别预测均值了,间接预测一个噪声\(\epsilon_\theta(\mathbf{x}_{t}, t)\)(其中\(\theta\)为可学习参数),让它和生成\(\mathbf{x}_{t}\)的噪声\(\epsilon_t\)的均方误差最小就行了。对于一轮训练,最终的误差函数能够写成

$$L=||\epsilon_t - \epsilon_\theta(\mathbf{x}_{t}, t)||^2$$

这样,咱们就意识了反向过程的所有内容。总结一下,反向过程中,神经网络应该让\(T\)个去噪声操作拟合对应的\(T\)个加噪声逆操作。每步加噪声逆操作合乎正态分布,且在给定某个输出时,该正态分布的均值和方差是能够用解析式表达出来的。神经网络的学习指标就是让其输入的散布和实践计算的散布统一。通过数学计算上的一些化简,问题被转换成了拟合生成\(\mathbf{x}_{t}\)时用到的随机噪声\(\epsilon_t\)。

训练算法与采样算法

了解了前向过程和反向过程后,训练神经网络的算法和采样图片(生成图片)的算法就跃然纸上了。

以下是DDPM论文中的训练算法:

让咱们来逐行了解一下这个算法。第二行是指从训练集里取一个数据\(\mathbf{x}_{0}\)。第三行是指随机从\(1, ..., T\)里取一个时刻用来训练。咱们尽管要求神经网络拟合\(T\)个正态分布,但理论训练时,不必一轮预测\(T\)个后果,只须要随机预测\(T\)个时刻中某一个时刻的后果就行。第四行指随机生成一个噪声\(\epsilon\),该噪声是用于执行前向过程生成\(\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_t}\epsilon\)的。之后,咱们把\(\mathbf{x}_t\)和\(t\)传给神经网络\(\epsilon_\theta(\mathbf{x}_{t}, t)\),让神经网络预测随机噪声。训练的损失函数是预测噪声和理论噪声之间的均方误差,对此损失函数采纳梯度降落即可优化网络。

DDPM并没有规定神经网络的构造。依据工作的难易水平,咱们能够本人定义简略或简单的网络结构。这里只须要把\(\epsilon_\theta(\mathbf{x}_{t}, t)\)当成一个一般的映射即可。

训练好了网络后,咱们能够执行反向过程,对任意一幅噪声图像去噪,以实现图像生成。这个算法如下:

第一行的\(\mathbf{x}_{t}\)就是从规范正态分布里随机采样的输出噪声。要生成不同的图像,只须要更换这个噪声。前面的过程就是扩散模型的反向过程。令时刻从\(T\)到\(1\),计算这一时刻去噪声操作的均值和方差,并采样出\(\mathbf{x}_{t-1}\)。均值是用之前提到的公式计算的:

$$\mu_{\theta}(\mathbf{x}_{t}, t) = \frac{1}{\sqrt{\alpha_t}}(\mathbf{x}_t -\frac{1 - \alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(\mathbf{x}_{t}, t))\\$$

而方差\(\sigma_t^2\)的公式有两种抉择,两个公式都能产生差不多的后果。试验表明,当\(\mathbf{x}_{0}\)是特定的某个数据时,用上一节推导进去的方差最好。

$$\sigma_t^2=\frac{1-\bar{\alpha}_{t-1}}{1 - \bar{\alpha}_{t}} \cdot \beta_t$$

而当\(\mathbf{x}_{0} \sim \mathcal{N}(0, \mathbf{I})\)时,只须要令方差和加噪声时的方差一样即可。

$$\sigma_t^2= \beta_t$$

循环执行去噪声操作。最初生成的\(\mathbf{x}_{0}\)就是生成进去的图像。

特地地,最初一步去噪声是不必加方差项的。为什么呢,察看公式\(\sigma_t^2=\frac{1-\bar{\alpha}_{t-1}}{1 - \bar{\alpha}_{t}} \cdot \beta_t\)。当\(t=1\)时,分子会呈现\(\bar{\alpha}_{t-1}=\bar{\alpha}_0\)这一项。\(\bar{\alpha}_t\)是一个连乘,实践上\(t\)是从\(1\)开始的,在\(t=0\)时没有定义。但咱们能够特地地令\(\bar{\alpha}_0=1\)。这样,\(t=1\)时方差项的分子\(1-\bar{\alpha}_{t-1}\)为\(0\),不必算这一项了。

当然,这一解释从数学上来说是不谨严的。据论文说,这部分的解释能够参见朗之万动力学。

数学推导的补充 (选读)

了解了训练算法和采样算法,咱们就算是搞懂了扩散模型,能够去编写代码了。不过,上文的形容省略了一些数学推导的细节。如果对扩散模型更深的原理感兴趣,能够浏览一下本节。

加噪声逆操作均值和方差的推导

上一节,咱们依据上面几个式子

$$\begin{aligned}q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0) = q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)\frac{q(\mathbf{x}_{t-1} | \mathbf{x}_0)}{q(\mathbf{x}_{t} | \mathbf{x}_0)} \\q(\mathbf{x}_{t} | \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t}; \sqrt{\bar{\alpha}_t}\mathbf{x}_{0}, (1-\bar{\alpha}_t)\mathbf{I})\\q(\mathbf{x}_{t} | \mathbf{x}_{t-1}, \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t};\sqrt{1 - \beta_t}\mathbf{x}_{t - 1},\beta_t\mathbf{I})\end{aligned}$$

一步就给出了\(q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t-1}; \tilde{\mu}_t, \tilde{\beta}_t\mathbf{I})\)的均值和方差。

$$\tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}}(\mathbf{x}_t -\frac{1 - \alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_t)$$

$$\tilde{\beta}_t=\frac{1-\bar{\alpha}_{t-1}}{1 - \bar{\alpha}_{t}} \cdot \beta_t$$

当初咱们来看一下推导均值和方差的思路。

首先,把其余几个式子带入贝叶斯公式的等式左边。

$$\begin{aligned}q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0) = &\frac{1}{\beta_t\sqrt{2\pi}}exp(-\frac{(\mathbf{x}_{t}-\sqrt{1 - \beta_t}\mathbf{x}_{t - 1})^2}{2\beta_t}) \cdot \\&\frac{1}{(1-\bar{\alpha}_{t-1})\sqrt{2\pi}} exp(-\frac{(\mathbf{x}_{t-1}-\sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_{0})^2}{2(1-\bar{\alpha}_{t-1})})\cdot \\&(\frac{1}{(1-\bar{\alpha}_t)\sqrt{2\pi}} exp(-\frac{(\mathbf{x}_{t}-\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0})^2}{2(1-\bar{\alpha}_{t})}))^{-1}\end{aligned}$$

因为多个正态分布的乘积还是一个正态分布,咱们晓得\(q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)\)也能够用一个正态分布公式\(\mathcal{N}(\mathbf{x}_{t-1}; \tilde{\mu}_t, \tilde{\beta}_t\mathbf{I})\)表白,它最初肯定能写成这种模式:

$$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0) = \frac{1}{\tilde{\beta}_t\sqrt{2\pi}}exp(-\frac{(\mathbf{x}_{t-1}-\tilde{\mu}_t)^2}{2\tilde{\beta}_t})$$

问题就变成了怎么把开始那个很长的式子化简,算出\(\tilde{\mu}_t\)和\(\tilde{\beta}_t\)。

方差\(\tilde{\beta}_t\)能够从指数函数的系数得来,比拟好求。系数为

$$\begin{aligned}&\frac{1}{\beta_t\sqrt{2\pi}} \cdot \frac{1}{(1-\bar{\alpha}_{t-1})\sqrt{2\pi}} \cdot (\frac{1}{(1-\bar{\alpha}_t)\sqrt{2\pi}})^{-1} \\=&\frac{(1-\bar{\alpha}_t)}{\beta_t(1-\bar{\alpha}_{t-1})\sqrt{2\pi}}\end{aligned}$$

所以,方差为:

$$\tilde{\beta}_t=\frac{1-\bar{\alpha}_{t-1}}{1 - \bar{\alpha}_{t}} \cdot \beta_t$$

接下来只有关注指数函数的指数局部。指数局部肯定是一个对于的\(\mathbf{x}_{t-1}\)的二次函数,只有化简成\((\mathbf{x}_{t-1}-C)^2\)的模式,再除以一下\(-2\)倍方差,就能够失去均值了。

指数局部为:

$$-\frac{1}{2}(\frac{(\mathbf{x}_{t}-\sqrt{1 - \beta_t}\mathbf{x}_{t - 1})^2}{\beta_t}+\frac{(\mathbf{x}_{t-1}-\sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_{0})^2}{1-\bar{\alpha}_{t-1}} - \frac{(\mathbf{x}_{t}-\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0})^2}{1-\bar{\alpha}_{t}})$$

\(\mathbf{x}_{t-1}\)只在前两项里有。把和\(\mathbf{x}_{t-1}\)无关的项计算化简,能够计算出均值:

$$\tilde{\mu}_t = \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1 - \bar{\alpha}_{t}}\mathbf{x}_{t}+\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1 - \bar{\alpha}_{t}}\mathbf{x}_{0}$$

回忆一下,在去噪声中,神经网络的输出是\(\mathbf{x}_{t}\)和\(t\)。也就是说,上式中\(\mathbf{x}_{t}\)已知,只有\(\mathbf{x}_{0}\)一个未知量。要算均值,还须要算出\(\mathbf{x}_{0}\)。\(\mathbf{x}_{0}\)和\(\mathbf{x}_{t}\)之间是有肯定分割的。\(\mathbf{x}_{t}\)是\(\mathbf{x}_{0}\)在正向过程中第\(t\)步加噪声的后果。而依据正向过程的公式:

$$\begin{aligned}\mathbf{x}_t &= \sqrt{\bar{\alpha}_t}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_t}\epsilon_t \\\mathbf{x}_0 &= \frac{\mathbf{x}_t - \sqrt{1-\bar{\alpha}_t}\epsilon_t}{\sqrt{\bar{\alpha}_t}}\end{aligned}$$

把这个\(\mathbf{x}_{0}\)带入均值公式,均值最初会化简成咱们相熟的模式。

$$\tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}}(\mathbf{x}_t -\frac{1 - \alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_t)$$

优化指标

上一节,咱们只是简略地说神经网络的优化指标是让加噪声和去噪声的均值靠近。而让均值靠近,就是让生成\(\mathbf{x}_t\)的噪声\(\epsilon_t\)更靠近。实际上,这个优化指标是通过简化得来的。扩散模型最早的优化指标是有肯定的数学意义的。

扩散模型,全称为扩散概率模型(Diffusion Probabilistic Model)。最简略的一类扩散模型,是去噪扩散概率模型(Denoising Diffusion Probabilistic Model),也就是常说的DDPM。DDPM的框架次要是由两篇论文建设起来的。第一篇论文是首次提出扩散模型思维的Deep Unsupervised Learning using Nonequilibrium Thermodynamics。在此基础上,Denoising Diffusion Probabilistic Models对最早的扩散模型做出了肯定的简化,让图像生成成果大幅晋升,促成了扩散模型的宽泛应用。咱们上一节看到的公式,全副是简化后的后果。

扩散概率模型的名字之所以有「概率」二字,是因为这个模型是在形容一个零碎的概率。精确来说,扩散模型是在形容经反向过程生成出某一项数据的概率。也就是说,扩散模型\(p_{\theta}(\mathbf{x}_0)\)是一个有着可训练参数\(\theta\)的模型,它形容了反向过程生成出数据\(\mathbf{x}_0\)的概率。\(p_{\theta}(\mathbf{x}_0)\)满足\(p_{\theta}(\mathbf{x}_0)=\int p_{\theta}(\mathbf{x}_{0:T})d\mathbf{x}_{1:T}\),其中\(p_{\theta}(\mathbf{x}_{0:T})\)就是咱们相熟的反向过程,只不过它是以概率计算的模式表白:

$$p_{\theta}(\mathbf{x}_{0:T})=p(\mathbf{x}_T)\prod_{t-1}^Tp_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t})$$

$$p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t}) = \mathcal{N}(\mathbf{x}_{t-1};\mu_{\theta}(\mathbf{x}_{t}, t), \Sigma_\theta(\mathbf{x}_{t}, t))$$

咱们上一节里见到的优化指标,是让去噪声操作\(p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t})\)和加噪声操作的逆操作\(q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)\)尽可能类似。然而,这个对形容并不确切。扩散模型本来的指标,是最大化\(p_{\theta}(\mathbf{x}_0)\)这个概率,其中\(\mathbf{x}_0\)是来自训练集的数据。换个角度说,给定一个训练集的数据\(\mathbf{x}_0\),通过前向过程和反向过程,扩散模型要让还原出\(\mathbf{x}_0\)的概率尽可能大。这也是咱们在本文结尾意识VAE时见到的优化指标。

最大化\(p_{\theta}(\mathbf{x}_0)\),个别会写成最小化其负对数值,即最小化\(-log \ p_{\theta}(\mathbf{x}_0)\)。应用和VAE相似的变分推理,能够把优化指标转换成优化一个叫做变分下界(variational lower bound, VLB)的量。它最终能够写成:

$$L_{VLB}=\mathbb{E}[D_{KL}(q(\mathbf{x}_T|\mathbf{x}_0) || p_\theta(\mathbf{x}_T))+\sum_{t=2}^{T}D_{KL}(q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0) || p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t})) - logp_\theta(\mathbf{x}_{0}|\mathbf{x}_{1})]$$

这里的\(D_{KL}(P||Q)\)示意散布P和Q之间的KL散度。KL散度是掂量两个散布类似度的指标。如果\(P, Q\)都是正态分布,则它们的KL散度能够由一个简略的公式给出。对于KL散度的常识能够参见我之前的文章:从零了解熵、穿插熵、KL散度。

其中,第一项\(D_{KL}(q(\mathbf{x}_T|\mathbf{x}_0) || p_\theta(\mathbf{x}_T))\)和可学习参数\(\theta\)无关(因为可学习参数只形容了每一步去噪声操作,也就是只形容了\(p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t})\)),能够不去管它。那么这个优化指标就由两局部组成:

  1. 最小化\(D_{KL}(q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0) || p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t}))\)示意的是最大化每一个去噪声操作和加噪声逆操作的类似度。
  2. 最小化\(- logp_\theta(\mathbf{x}_{0}|\mathbf{x}_{1})\)就是已知\(\mathbf{x}_{1}\)时,让最初还原原图\(\mathbf{x}_{0}\)概率更高。

咱们别离看这两局部是怎么计算的。

对于第一局部,咱们先回顾一下正态分布之间的KL散度公式。设一维正态分布\(P, Q\)的公式如下:

$$\begin{aligned}P(x) = \frac{1}{\sqrt{2\pi}\sigma_1}exp(-\frac{(x - \mu_1)^2}{2\sigma_1^2}) \\Q(x) = \frac{1}{\sqrt{2\pi}\sigma_2}exp(-\frac{(x - \mu_2)^2}{2\sigma_2^2})\end{aligned}$$

$$D_{KL}(P||Q) = log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$$

而对于\(D_{KL}(q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0) || p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t}))\),依据前文的剖析,咱们晓得,待求方差\(\Sigma_\theta(\mathbf{x}_{t}, t)\)能够间接由计算失去。

$$\Sigma_\theta(\mathbf{x}_{t}, t) = \tilde{\beta}_t\mathbf{I}=\frac{1-\bar{\alpha}_{t-1}}{1 - \bar{\alpha}_{t}} \cdot \beta_t\mathbf{I}$$

二者的比值是常量。所以,在计算KL散度时,不必管方差那一项了,只须要管均值那一项。

$$D_{KL}(q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0) || p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t}))\to\frac{1}{2\tilde{\beta}_t^2}||\mu_{\theta}(\mathbf{x}_{t}, t)-\tilde{\mu}_{t}(\mathbf{x}_{t}, t)||^2$$

由依据之前的均值公式

$$\tilde{\mu}_t(\mathbf{x}_{t}, t) = \frac{1}{\sqrt{\alpha_t}}(\mathbf{x}_t -\frac{1 - \alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_t)$$

$$\mu_{\theta}(\mathbf{x}_{t}, t) = \frac{1}{\sqrt{\alpha_t}}(\mathbf{x}_t -\frac{1 - \alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(\mathbf{x}_{t}, t))\\$$

这一部分的优化指标能够化简成

$$\frac{(1 - \alpha_t)^2}{2\alpha_t(1-\bar{\alpha}_t)\tilde{\beta}_t^2}||\epsilon_t-\epsilon_{\theta}(\mathbf{x}_{t}, t)||^2$$

DDPM论文指出,如果把后面的系数全副丢掉的话,模型的成果更好。最终,咱们就能失去一个非常简单的优化指标:

$$||\epsilon_t-\epsilon_{\theta}(\mathbf{x}_{t}, t)||^2$$

这就是咱们上一节见到的优化指标。

当然,还没完,别忘了优化指标里还有\(- logp_\theta(\mathbf{x}_{0}|\mathbf{x}_{1})\)这一项。它的模式为:

$$- logp_\theta(\mathbf{x}_{0}|\mathbf{x}_{1})=-log\frac{1}{\sqrt{2\pi}\tilde{\beta}_1^2}+\frac{||\mathbf{x}_{0} - \mu_{\theta}(\mathbf{x}_{1}, 1)||^2}{2\tilde{\beta}_1^2}$$

只管前面有\(\theta\)的那一项(留神,\(\alpha_1=\bar{\alpha}_1=1-\beta_1\)):

$$\begin{aligned}\frac{(\mathbf{x}_{0} - \mu_{\theta}(\mathbf{x}_{1}, 1))^2}{2\tilde{\beta}_1^2} &= \frac{1}{2\tilde{\beta}_1^2}||\mathbf{x}_{0}-\frac{1}{\sqrt{\alpha_1}}(\mathbf{x}_1 -\frac{1 - \alpha_1}{\sqrt{1-\bar{\alpha}_1}}\epsilon_\theta(\mathbf{x}_{1}, 1))||^2 \\&=\frac{1}{2\tilde{\beta}_1^2}||\mathbf{x}_{0}-\frac{1}{\sqrt{\alpha_1}}(\sqrt{\bar{\alpha}_1}\mathbf{x}_{0}+ \sqrt{1-\bar{\alpha}_1}\epsilon_1-\frac{1 - \alpha_1}{\sqrt{1-\bar{\alpha}_1}}\epsilon_\theta(\mathbf{x}_{1}, 1))||^2 \\&=\frac{1}{2\tilde{\beta}_1^2\alpha_1}|| \sqrt{1-\bar{\alpha}_1}\epsilon_1-\frac{1 - \alpha_1}{\sqrt{1-\bar{\alpha}_1}}\epsilon_\theta(\mathbf{x}_{1}, 1)||^2 \\&=\frac{1-\bar{\alpha}_1}{2\tilde{\beta}_1^2\alpha_1}|| \epsilon_1-\epsilon_\theta(\mathbf{x}_{1}, 1)||^2 \\\end{aligned}$$

这和那些KL散度项\(t=1\)时的模式雷同,咱们能够用雷同的形式简化优化指标,只保留\(|| \epsilon_1-\epsilon_\theta(\mathbf{x}_{1}, 1)||^2\)。这样,损失函数的模式全都是\(||\epsilon_t-\epsilon_{\theta}(\mathbf{x}_{t}, t)||^2\)了。

DDPM论文里写\(- logp_\theta(\mathbf{x}_{0}|\mathbf{x}_{1})\)这一项能够间接满足简化后的公式\(t=1\)时的状况,而没有去掉系数的过程。我在网上没找到文章解释这一点,只好按本人的了解来推导这个误差项了。不论如何,推导的过程不是那么重要,重要的是最初的简化模式。

总结

图像生成工作就是把随机生成的向量(噪声)映射成和训练图像相似的图像。为此,扩散模型把这个过程看成是对纯噪声图像的去噪过程。通过学习把图像逐渐变成纯噪声的逆操作,扩散模型能够把任何一个纯噪声图像变成有意义的图像,也就是实现图像生成。

对于不同水平的读者,应该对本文有不同的意识。

对于只想理解扩散模型大略原理的读者,只须要浏览第一节,并大略理解:

  • 图像生成工作的通常做法
  • 图像生成工作须要监督
  • VAE通过把图像编码再解码来训练一个解码器
  • 扩散模型是一类非凡的VAE,它的编码固定为加噪声,解码固定为去噪声

对于想认真学习扩散模型的读者,只需读懂第二节的次要内容:

  • 扩散模型的优化指标:让反向过程尽可能成为正向过程的逆操作
  • 正向过程的公式
  • 反向过程的做法(采样算法)
  • 加噪声逆操作的均值和方差在给定\(\mathbf{x}_{0}\)时能够求进去的,加噪声逆操作的均值就是去噪声的学习指标
  • 简化后的损失函数与训练算法

对有学有余力对数学感兴趣的读者,能够看一看第三节的内容:

  • 加噪声逆操作均值和方差的推导
  • 扩散模型最早的优化指标与DDPM论文是如何简化优化指标的

我集体认为,因为扩散模型的优化指标曾经被大幅度简化,除非你的钻研指标是改良扩散模型自身,否则没必要花过多的工夫钻研数学原理。在学习时,倡议快点看懂扩散模型的整体思维,搞懂最外围的训练算法和采样算法,跑通代码。之后就能够去看较新的论文了。

在附录中,我给出了一份DDPM的简略实现。欢送大家参考,并本人入手复现一遍DDPM。

参考资料与学习倡议

网上绝大多数的中英文教程都是照搬 https://lilianweng.github.io/posts/2021-07-11-diffusion-models/ 这篇文章的。这篇文章像教科书一样谨严,适宜有肯定数学根底的人浏览,但不适宜给初学者学习。倡议在弄懂扩散模型的大略原理后再来浏览这篇文章补充细节。

少数介绍扩散模型的文章对没学过相干数学知识的人来说很不敌对,我在浏览此类文章时碰到了大量的问题:为什么前向公式里有个\(\sqrt{1-\beta}\)?为什么忽然冒出一个疾速算\(\mathbf{x}_{t}\)的公式?为什么反向过程里来了个贝叶斯公式?优化指标是什么?\(-log \ p_{\theta}(\mathbf{x}_0)\)是什么?为什么优化指标里一大堆项,每一项的意义又是什么?为什么最初莫名其妙算一个\(\epsilon\)?为什么采样时\(t=0\)就不必加方差项了?好不容易,我才把这些问题缓缓搞懂,并在本文做出了解释。心愿我的解答可能帮忙到同样有这些困惑的读者。想逐渐学习扩散模型,能够先看懂我这篇文章的大略解说,再去其余文章里学懂一些细节。无论是教,还是学,最重要的都是搞懂整体思路,晓得动机,最初再去强调细节。

这里还有篇文章给出了扩散模型中数学公式的具体推导,并补充了变分推理的背景介绍,适宜从头学起:https://arxiv.org/abs/2208.11970

想深刻学习DDPM,能够看一看最重要的两篇论文:Deep Unsupervised Learning using Nonequilibrium ThermodynamicsDenoising Diffusion Probabilistic Models。当然,后者更重要一些,外面的一些试验后果仍有浏览价值。

我在代码复现时参考了这篇文章。绝对于网上的其余开源DDPM实现,这份代码比拟简短易懂,更适宜学习。不过,这份代码有一点问题。它的神经网络不够弱小,采样后果会有一点问题。

附录:代码复现

在这个我的项目中,咱们要用PyTorch实现一个基于U-Net的DDPM,并在MNIST数据集(经典的手写数字数据集)上训练它。模型几分钟就能训练完,咱们能够不便地做各种各样的试验。

后续解说只会给出代码片段,残缺的代码请参见 https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/ddpm 。git clone 仓库并装置后,能够间接运行目录里的main.py训练模型并采样。

获取数据集

PyTorch的torchvision提供了获取了MNIST的接口,咱们只须要用上面的函数就能够生成MNIST的Dataset实例。参数中,root为数据集的下载门路,download为是否主动下载数据集。令download=True的话,第一次调用该函数时会主动下载数据集,而第二次之后就不必下载了,函数会读取存储在root里的数据。

mnist = torchvision.datasets.MNIST(root='./data/mnist', download=True)

咱们能够用上面的代码来下载MNIST并输入该数据集的一些信息:

import torchvisionfrom torchvision.transforms import ToTensordef download_dataset():    mnist = torchvision.datasets.MNIST(root='./data/mnist', download=True)    print('length of MNIST', len(mnist))    id = 4    img, label = mnist[id]    print(img)    print(label)    # On computer with monitor    # img.show()    img.save('work_dirs/tmp.jpg')    tensor = ToTensor()(img)    print(tensor.shape)    print(tensor.max())    print(tensor.min())if __name__ == '__main__':    download_dataset()

执行这段代码,输入大抵为:

length of MNIST 60000<PIL.Image.Image image mode=L size=28x28 at 0x7FB3F09CCE50>9torch.Size([1, 28, 28])tensor(1.)tensor(0.)

第一行输入表明,MNIST数据集里有60000张图片。而从第二行和第三行输入中,咱们发现每一项数据由图片和标签组成,图片是大小为28x28的PIL格局的图片,标签表明该图片是哪个数字。咱们能够用torchvision里的ToTensor()把PIL图片转成PyTorch张量,进一步查看图片的信息。最初三行输入表明,每一张图片都是单通道图片(灰度图),色彩值的取值范畴是0~1。

咱们能够查看一下每张图片的样子。如果你是在用带显示器的电脑,能够去掉img.show那一行的正文,间接查看图片;如果你是在用服务器,能够去img.save的门路里查看图片。该图片的应该长这个样子:

咱们能够用上面的代码预处理数据并创立DataLoader。因为DDPM会把图像和正态分布关联起来,咱们更心愿图像色彩值的取值范畴是[-1, 1]。为此,咱们能够对图像做一个线性变换,减0.5再乘2。

def get_dataloader(batch_size: int):    transform = Compose([ToTensor(), Lambda(lambda x: (x - 0.5) * 2)])    dataset = torchvision.datasets.MNIST(root='./data/mnist',                                         transform=transform)    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

DDPM 类

在代码中,咱们要实现一个DDPM类。它保护了扩散过程中的一些常量(比方\(\alpha\)),并且能够计算正向过程和反向过程的后果。

先来实现一下DDPM类的初始化函数。一开始,咱们听从论文的配置,用torch.linspace(min_beta, max_beta, n_steps)min_betamax_beta线性地生成n_steps个时刻的\(\beta\)。接着,咱们依据公式\(\alpha_t=1-\beta_t, \bar{\alpha}_t=\prod_{i=1}^t\alpha_i\),计算每个时刻的alphaalpha_bar。留神,为了不便实现,咱们让t的取值从0开始,要比论文里的\(t\)少1。

import torchclass DDPM():    # n_steps 就是论文里的 T    def __init__(self,                 device,                 n_steps: int,                 min_beta: float = 0.0001,                 max_beta: float = 0.02):        betas = torch.linspace(min_beta, max_beta, n_steps).to(device)        alphas = 1 - betas        alpha_bars = torch.empty_like(alphas)        product = 1        for i, alpha in enumerate(alphas):            product *= alpha            alpha_bars[i] = product        self.betas = betas        self.n_steps = n_steps        self.alphas = alphas        self.alpha_bars = alpha_bars
局部实现会让 DDPM 继承torch.nn.Module,但我认为这样不好。DDPM自身不是一个神经网络,它只是形容了前向过程和后向过程的一些计算。只有波及可学习参数的神经网络类才应该继承 torch.nn.Module

筹备好了变量后,咱们能够来实现DDPM类的其余办法。先实现正向过程办法,该办法会依据公式\(\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_t}\epsilon_t\)计算正向过程中的\(\mathbf{x}_t\)。

def sample_forward(self, x, t, eps=None):    alpha_bar = self.alpha_bars[t].reshape(-1, 1, 1, 1)    if eps is None:        eps = torch.randn_like(x)    res = eps * torch.sqrt(1 - alpha_bar) + torch.sqrt(alpha_bar) * x    return res

这里要解释一些PyTorch编程上的细节。这份代码中,self.alpha_bars是一个一维Tensor。而在并行训练中,咱们个别会令t为一个形态为(batch_size, )Tensor。PyTorch容许咱们间接用self.alpha_bars[t]self.alpha_bars里取出batch_size个数,就像用一个一般的整型索引来从数组中取出一个数一样。有些实现会用torch.gatherself.alpha_bars里取数,其作用是一样的。

咱们能够随机从训练集取图片做测试,看看它们在前向过程中是怎么逐渐变成噪声的。

接下来实现反向过程。在反向过程中,DDPM会用神经网络预测每一轮去噪的均值,把\(\mathbf{x}_t\)还原回\(\mathbf{x}_0\),以实现图像生成。反向过程即对应论文中的采样算法。

其实现如下:

def sample_backward(self, img_shape, net, device, simple_var=True):    x = torch.randn(img_shape).to(device)    net = net.to(device)    for t in range(self.n_steps - 1, -1, -1):        x = self.sample_backward_step(x, t, net, simple_var)    return xdef sample_backward_step(self, x_t, t, net, simple_var=True):    n = x_t.shape[0]    t_tensor = torch.tensor([t] * n,                            dtype=torch.long).to(x_t.device).unsqueeze(1)    eps = net(x_t, t_tensor)    if t == 0:        noise = 0    else:        if simple_var:            var = self.betas[t]        else:            var = (1 - self.alpha_bars[t - 1]) / (                1 - self.alpha_bars[t]) * self.betas[t]        noise = torch.randn_like(x_t)        noise *= torch.sqrt(var)    mean = (x_t -            (1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t]) *            eps) / torch.sqrt(self.alphas[t])    x_t = mean + noise    return x_t

其中,sample_backward是用来给内部调用的办法,而sample_backward_step是执行一步反向过程的办法。

sample_backward会随机生成纯噪声x(对应\(\mathbf{x}_T\)),再令tn_steps - 10,调用sample_backward_step

def sample_backward(self, img_shape, net, device, simple_var=True):    x = torch.randn(img_shape).to(device)    net = net.to(device)    for t in range(self.n_steps - 1, -1, -1):        x = self.sample_backward_step(x, t, net, simple_var)    return x

sample_backward_step中,咱们先筹备好这一步的神经网络输入eps。为此,咱们要把整型的t转换成一个格局正确的Tensor。思考到输出里可能有多个batch,咱们先获取batch size n,再依据它来生成t_tensor

def sample_backward_step(self, x_t, t, net, simple_var=True):    n = x_t.shape[0]    t_tensor = torch.tensor([t] * n,                            dtype=torch.long).to(x_t.device).unsqueeze(1)    eps = net(x_t, t_tensor)

之后,咱们来解决反向过程公式中的方差项。依据伪代码,咱们仅在t非零的时候算方差项。方差项用到的方差有两种取值,成果差不多,咱们用simple_var来管制选哪种取值形式。获取方差后,咱们再随机采样一个噪声,依据公式,失去方差项。

if t == 0:    noise = 0else:    if simple_var:        var = self.betas[t]    else:        var = (1 - self.alpha_bars[t - 1]) / (            1 - self.alpha_bars[t]) * self.betas[t]    noise = torch.randn_like(x_t)    noise *= torch.sqrt(var)

最初,咱们把eps和方差项套入公式,失去这一步更新过后的图像x_t

mean = (x_t -        (1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t]) *        eps) / torch.sqrt(self.alphas[t])x_t = mean + noisereturn x_t

稍后实现了训练后,咱们再来看反向过程的输入后果。

训练算法

接下来,咱们先跳过神经网络的实现,间接实现论文里的训练算法。

再回顾一遍伪代码。首先,咱们要随机选取训练图片\(\mathbf{x}_{0}\),随机生成以后要训练的时刻\(t\),以及随机生成一个生成\(\mathbf{x}_{t}\)的高斯噪声。之后,咱们把\(\mathbf{x}_{t}\)和\(t\)输出进神经网络,尝试预测噪声。最初,咱们以预测噪声和理论噪声的均方误差为损失函数做梯度降落。

为此,咱们能够用上面的代码实现训练。

import torchimport torch.nn as nnfrom dldemos.ddpm.dataset import get_dataloader, get_img_shapefrom dldemos.ddpm.ddpm import DDPMimport cv2import numpy as npimport einopsbatch_size = 512n_epochs = 100def train(ddpm: DDPM, net, device, ckpt_path):    # n_steps 就是公式里的 T    # net 是某个继承自 torch.nn.Module 的神经网络    n_steps = ddpm.n_steps    dataloader = get_dataloader(batch_size)    net = net.to(device)    loss_fn = nn.MSELoss()    optimizer = torch.optim.Adam(net.parameters(), 1e-3)    for e in range(n_epochs):        for x, _ in dataloader:            current_batch_size = x.shape[0]            x = x.to(device)            t = torch.randint(0, n_steps, (current_batch_size, )).to(device)            eps = torch.randn_like(x).to(device)            x_t = ddpm.sample_forward(x, t, eps)            eps_theta = net(x_t, t.reshape(current_batch_size, 1))            loss = loss_fn(eps_theta, eps)            optimizer.zero_grad()            loss.backward()            optimizer.step()    torch.save(net.state_dict(), ckpt_path)

代码的次要逻辑都在循环里。首先是实现训练数据\(\mathbf{x}_{0}\)、\(t\)、噪声的采样。采样\(\mathbf{x}_{0}\)的工作能够交给PyTorch的DataLoader实现,每轮遍历失去的x就是训练数据。\(t\)的采样能够用torch.randint函数随机从[0, n_steps - 1]取数。采样高斯噪声能够间接用torch.randn_like(x)生成一个和训练图片x形态一样的符合标准正态分布的图像。

for x, _ in dataloader:    current_batch_size = x.shape[0]    x = x.to(device)    t = torch.randint(0, n_steps, (current_batch_size, )).to(device)    eps = torch.randn_like(x).to(device)

之后计算\(\mathbf{x}_{t}\)并将其和\(t\)输出进神经网络net。计算\(\mathbf{x}_{t}\)的工作会由DDPM类的sample_forward办法实现,咱们在上文曾经实现了它。

x_t = ddpm.sample_forward(x, t, eps)eps_theta = net(x_t, t.reshape(current_batch_size, 1))

失去了预测的噪声eps_theta,咱们调用PyTorch的API,算均方误差并调用优化器即可。

loss_fn = nn.MSELoss()optimizer = torch.optim.Adam(net.parameters(), 1e-3)...        loss = loss_fn(eps_theta, eps)        optimizer.zero_grad()        loss.backward()        optimizer.step()

去噪神经网络

在DDPM中,实践上咱们能够用任意一种神经网络架构。但因为DDPM工作非常靠近图像去噪工作,而U-Net又是去噪工作中最常见的网络架构,因而绝大多数DDPM都会应用基于U-Net的神经网络。

我始终想训练一个尽可能简略的模型。通过屡次试验,我发现DDPM的神经网络很难训练。哪怕是对于比较简单的MNIST数据集,构造差一点的网络(比方纯ResNet)都不太行,只有带了残差块和时序编码的U-Net能力较好地实现去噪。注意力模块倒是能够不必加上。

因为神经网络构造并不是DDPM学习的重点,我这里就不对U-Net的写法做讲解,而是间接贴上代码了。代码中大部分内容都和一般的U-Net无异。惟一要留神的中央就是时序编码。去噪网络的输出除了图像外,还有一个工夫戳t。咱们要思考怎么把t的信息和输出图像信息交融起来。大部分人的做法是对t进行Transformer中的地位编码,把该编码加到图像的每一处上。

import torchimport torch.nn as nnimport torch.nn.functional as Ffrom dldemos.ddpm.dataset import get_img_shapeclass PositionalEncoding(nn.Module):    def __init__(self, max_seq_len: int, d_model: int):        super().__init__()        # Assume d_model is an even number for convenience        assert d_model % 2 == 0        pe = torch.zeros(max_seq_len, d_model)        i_seq = torch.linspace(0, max_seq_len - 1, max_seq_len)        j_seq = torch.linspace(0, d_model - 2, d_model // 2)        pos, two_i = torch.meshgrid(i_seq, j_seq)        pe_2i = torch.sin(pos / 10000**(two_i / d_model))        pe_2i_1 = torch.cos(pos / 10000**(two_i / d_model))        pe = torch.stack((pe_2i, pe_2i_1), 2).reshape(max_seq_len, d_model)        self.embedding = nn.Embedding(max_seq_len, d_model)        self.embedding.weight.data = pe        self.embedding.requires_grad_(False)    def forward(self, t):        return self.embedding(t)class ResidualBlock(nn.Module):    def __init__(self, in_c: int, out_c: int):        super().__init__()        self.conv1 = nn.Conv2d(in_c, out_c, 3, 1, 1)        self.bn1 = nn.BatchNorm2d(out_c)        self.actvation1 = nn.ReLU()        self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1)        self.bn2 = nn.BatchNorm2d(out_c)        self.actvation2 = nn.ReLU()        if in_c != out_c:            self.shortcut = nn.Sequential(nn.Conv2d(in_c, out_c, 1),                                          nn.BatchNorm2d(out_c))        else:            self.shortcut = nn.Identity()    def forward(self, input):        x = self.conv1(input)        x = self.bn1(x)        x = self.actvation1(x)        x = self.conv2(x)        x = self.bn2(x)        x += self.shortcut(input)        x = self.actvation2(x)        return xclass ConvNet(nn.Module):    def __init__(self,                 n_steps,                 intermediate_channels=[10, 20, 40],                 pe_dim=10,                 insert_t_to_all_layers=False):        super().__init__()        C, H, W = get_img_shape()  # 1, 28, 28        self.pe = PositionalEncoding(n_steps, pe_dim)        self.pe_linears = nn.ModuleList()        self.all_t = insert_t_to_all_layers        if not insert_t_to_all_layers:            self.pe_linears.append(nn.Linear(pe_dim, C))        self.residual_blocks = nn.ModuleList()        prev_channel = C        for channel in intermediate_channels:            self.residual_blocks.append(ResidualBlock(prev_channel, channel))            if insert_t_to_all_layers:                self.pe_linears.append(nn.Linear(pe_dim, prev_channel))            else:                self.pe_linears.append(None)            prev_channel = channel        self.output_layer = nn.Conv2d(prev_channel, C, 3, 1, 1)    def forward(self, x, t):        n = t.shape[0]        t = self.pe(t)        for m_x, m_t in zip(self.residual_blocks, self.pe_linears):            if m_t is not None:                pe = m_t(t).reshape(n, -1, 1, 1)                x = x + pe            x = m_x(x)        x = self.output_layer(x)        return xclass UnetBlock(nn.Module):    def __init__(self, shape, in_c, out_c, residual=False):        super().__init__()        self.ln = nn.LayerNorm(shape)        self.conv1 = nn.Conv2d(in_c, out_c, 3, 1, 1)        self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1)        self.activation = nn.ReLU()        self.residual = residual        if residual:            if in_c == out_c:                self.residual_conv = nn.Identity()            else:                self.residual_conv = nn.Conv2d(in_c, out_c, 1)    def forward(self, x):        out = self.ln(x)        out = self.conv1(out)        out = self.activation(out)        out = self.conv2(out)        if self.residual:            out += self.residual_conv(x)        out = self.activation(out)        return outclass UNet(nn.Module):    def __init__(self,                 n_steps,                 channels=[10, 20, 40, 80],                 pe_dim=10,                 residual=False) -> None:        super().__init__()        C, H, W = get_img_shape()        layers = len(channels)        Hs = [H]        Ws = [W]        cH = H        cW = W        for _ in range(layers - 1):            cH //= 2            cW //= 2            Hs.append(cH)            Ws.append(cW)        self.pe = PositionalEncoding(n_steps, pe_dim)        self.encoders = nn.ModuleList()        self.decoders = nn.ModuleList()        self.pe_linears_en = nn.ModuleList()        self.pe_linears_de = nn.ModuleList()        self.downs = nn.ModuleList()        self.ups = nn.ModuleList()        prev_channel = C        for channel, cH, cW in zip(channels[0:-1], Hs[0:-1], Ws[0:-1]):            self.pe_linears_en.append(                nn.Sequential(nn.Linear(pe_dim, prev_channel), nn.ReLU(),                              nn.Linear(prev_channel, prev_channel)))            self.encoders.append(                nn.Sequential(                    UnetBlock((prev_channel, cH, cW),                              prev_channel,                              channel,                              residual=residual),                    UnetBlock((channel, cH, cW),                              channel,                              channel,                              residual=residual)))            self.downs.append(nn.Conv2d(channel, channel, 2, 2))            prev_channel = channel        self.pe_mid = nn.Linear(pe_dim, prev_channel)        channel = channels[-1]        self.mid = nn.Sequential(            UnetBlock((prev_channel, Hs[-1], Ws[-1]),                      prev_channel,                      channel,                      residual=residual),            UnetBlock((channel, Hs[-1], Ws[-1]),                      channel,                      channel,                      residual=residual),        )        prev_channel = channel        for channel, cH, cW in zip(channels[-2::-1], Hs[-2::-1], Ws[-2::-1]):            self.pe_linears_de.append(nn.Linear(pe_dim, prev_channel))            self.ups.append(nn.ConvTranspose2d(prev_channel, channel, 2, 2))            self.decoders.append(                nn.Sequential(                    UnetBlock((channel * 2, cH, cW),                              channel * 2,                              channel,                              residual=residual),                    UnetBlock((channel, cH, cW),                              channel,                              channel,                              residual=residual)))            prev_channel = channel        self.conv_out = nn.Conv2d(prev_channel, C, 3, 1, 1)    def forward(self, x, t):        n = t.shape[0]        t = self.pe(t)        encoder_outs = []        for pe_linear, encoder, down in zip(self.pe_linears_en, self.encoders,                                            self.downs):            pe = pe_linear(t).reshape(n, -1, 1, 1)            x = encoder(x + pe)            encoder_outs.append(x)            x = down(x)        pe = self.pe_mid(t).reshape(n, -1, 1, 1)        x = self.mid(x + pe)        for pe_linear, decoder, up, encoder_out in zip(self.pe_linears_de,                                                       self.decoders, self.ups,                                                       encoder_outs[::-1]):            pe = pe_linear(t).reshape(n, -1, 1, 1)            x = up(x)            pad_x = encoder_out.shape[2] - x.shape[2]            pad_y = encoder_out.shape[3] - x.shape[3]            x = F.pad(x, (pad_x // 2, pad_x - pad_x // 2, pad_y // 2,                          pad_y - pad_y // 2))            x = torch.cat((encoder_out, x), dim=1)            x = decoder(x + pe)        x = self.conv_out(x)        return xconvnet_small_cfg = {    'type': 'ConvNet',    'intermediate_channels': [10, 20],    'pe_dim': 128}convnet_medium_cfg = {    'type': 'ConvNet',    'intermediate_channels': [10, 10, 20, 20, 40, 40, 80, 80],    'pe_dim': 256,    'insert_t_to_all_layers': True}convnet_big_cfg = {    'type': 'ConvNet',    'intermediate_channels': [20, 20, 40, 40, 80, 80, 160, 160],    'pe_dim': 256,    'insert_t_to_all_layers': True}unet_1_cfg = {'type': 'UNet', 'channels': [10, 20, 40, 80], 'pe_dim': 128}unet_res_cfg = {    'type': 'UNet',    'channels': [10, 20, 40, 80],    'pe_dim': 128,    'residual': True}def build_network(config: dict, n_steps):    network_type = config.pop('type')    if network_type == 'ConvNet':        network_cls = ConvNet    elif network_type == 'UNet':        network_cls = UNet    network = network_cls(n_steps, **config)    return network

试验后果与采样

把之前的所有代码综合一下,咱们以带残差块的U-Net为去噪网络,执行训练。

if __name__ == '__main__':    n_steps = 1000    config_id = 4    device = 'cuda'    model_path = 'dldemos/ddpm/model_unet_res.pth'    config = unet_res_cfg    net = build_network(config, n_steps)    ddpm = DDPM(device, n_steps)    train(ddpm, net, device=device, ckpt_path=model_path)

依照默认训练配置,在3090上花5分钟不到,训练30~40个epoch即可让网络根本收敛。最终收敛时loss在0.023~0.024左右。

batch size: 512epoch 0 loss: 0.23103461712201437 elapsed 7.01sepoch 1 loss: 0.0627968365987142 elapsed 13.66sepoch 2 loss: 0.04828845852613449 elapsed 20.25sepoch 3 loss: 0.04148937337398529 elapsed 26.80sepoch 4 loss: 0.03801360730528831 elapsed 33.37sepoch 5 loss: 0.03604260584712028 elapsed 39.96sepoch 6 loss: 0.03357676289876302 elapsed 46.57sepoch 7 loss: 0.0335664684087038 elapsed 53.15s...epoch 30 loss: 0.026149748386939366 elapsed 204.64sepoch 31 loss: 0.025854381563266117 elapsed 211.24sepoch 32 loss: 0.02589433005253474 elapsed 217.84sepoch 33 loss: 0.026276464049021404 elapsed 224.41s...epoch 96 loss: 0.023299352884292603 elapsed 640.25sepoch 97 loss: 0.023460942271351815 elapsed 646.90sepoch 98 loss: 0.023584651704629263 elapsed 653.54sepoch 99 loss: 0.02364126600921154 elapsed 660.22s

训练这个网络时,并没有特地好的测试指标,咱们只能通过观察采样图像来评估网络的体现。咱们能够用上面的代码调用DDPM的反向流传办法,生成多幅图像并保留下来。

def sample_imgs(ddpm,                net,                output_path,                n_sample=81,                device='cuda',                simple_var=True):    net = net.to(device)    net = net.eval()    with torch.no_grad():        shape = (n_sample, *get_img_shape())  # 1, 3, 28, 28        imgs = ddpm.sample_backward(shape,                                    net,                                    device=device,                                    simple_var=simple_var).detach().cpu()        imgs = (imgs + 1) / 2 * 255        imgs = imgs.clamp(0, 255)        imgs = einops.rearrange(imgs,                                '(b1 b2) c h w -> (b1 h) (b2 w) c',                                b1=int(n_sample**0.5))        imgs = imgs.numpy().astype(np.uint8)        cv2.imwrite(output_path, imgs)

一切顺利的话,咱们能够失去一些不错的生成后果。下图是我失去的一些生成图片:

大部分生成的图片都对应一个阿拉伯数字,它们和训练集MNIST里的图片十分靠近。这算是一个不错的生成后果。

如果神经网络的拟合能力较弱,生成后果就会差很多。下图是我训练一个简略的ResNet后失去的采样后果:

能够看出,每幅图片都很乱,根本对应不上一个数字。这就是一个较差的训练后果。

如果网络再差一点,可能会生成纯黑或者纯白的图片。这是因为网络的预测后果不准,在反向过程中,图像的均值一直偏移,偏移到远大于1或者远小于-1的值了。

总结一下,在复现DDPM时,最次要是要学习DDPM论文的两个算法,即训练算法和采样算法。两个算法很简略,能够轻松地把它们翻译成代码。而为了胜利实现复现,还须要花一点心理在编写U-Net上,尤其是留神解决工夫戳的局部。

本文参加了SegmentFault 思否写作流动,欢送正在浏览的你也退出。