作者:京东批发 刘岩
扩散模型解说
前沿
人工智能生成内容(AI Generated Content,AIGC)近年来成为了十分前沿的一个钻研方向,生成模型目前有四个流派,别离是生成反抗网络(Generative Adversarial Models,GAN),变分自编码器(Variance Auto-Encoder,VAE),标准化流模型(Normalization Flow,NF)以及这里要介绍的扩散模型(Diffusion Models,DM)。扩散模型是受到热力学中的一个分支,它的思维起源是非均衡热力学(Non-equilibrium thermodynamics)。扩散模型的算法实践根底是通过变分推断(Variational Inference)训练参数化的马尔可夫链(Markov Chain),它在许多工作上展示了超过 GAN 等其它生成模型的成果,例如最近十分炽热的 OpenAI 的 DALL-E 2,Stability.ai 的 Stable Diffusion 等。这些成果惊艳的模型扩散模型的实践根底便是咱们这里要介绍的提出扩散模型的文章 [1] 和十分重要的 DDPM[2],扩散模型的实现并不简单,但其背地的数学原理却十分丰盛。在这里我会介绍这些重要的数学原理,但省去了这些公式的推导计算,如果你对这些推导感兴趣,能够学习参考文献 [4,5,11] 的相干内容。我在这里次要以一个绝对简略的角度来解说扩散模型,帮忙你疾速入门这个十分重要的生成算法。
1. 背景常识: 生成模型
目前生成模型次要有图 1 所示的四类。其中 GAN 的原理是通过判断器和生成器的相互博弈来让生成器生成足以以假乱真的图像。VAE 的原理是通过一个编码器将输出图像编码成特征向量,它用来学习高斯分布的均值和方差,而解码器则能够将特征向量转化为生成图像,它侧重于学习生成能力。流模型是从一个简略的散布开始,通过一系列可逆的转换函数将散布转化成指标散布。扩散模型先通过正向过程将噪声逐步退出到数据中,而后通过反向过程预测每一步退出的噪声,通过将噪声去掉的形式逐步还原失去无噪声的图像,扩散模型实质上是一个马尔可夫架构,只是其中训练过程用到了深度学习的 BP,但它更属于数学层面的翻新。这也就是为什么很多计算机的同学看扩散模型相干的论文会如此费劲。
图 1:生成模型的四种类型 [4]
扩散模型中最重要的思维根基是马尔可夫链,它的一个要害性质是平稳性。即如果一个概率随工夫变动,那么再马尔可夫链的作用下,它会趋向于某种安稳散布,工夫越长,散布越安稳。如图 2 所示,当你向一滴水中滴入一滴颜料时,无论你滴在什么地位,只有工夫足够长,最终颜料都会平均的散布在水溶液中。这也就是扩散模型的前向过程。
图 2:颜料分子在水溶液中的扩散过程
如果咱们可能在扩散的过程颜料分子的地位、挪动速度、方向等挪动属性。那么也能够依据正向过程的保留的挪动属性从一杯被溶解了颜料的水中反推颜料的滴入地位。这边是扩散模型的反向过程。记录挪动属性的快照便是咱们要训练的模型。
2. 扩散模型
在这一部分咱们将集中介绍扩散模型的数学原理以及推导的几个重要性质,因为推导过程波及大量的数学知识然而对了解扩散模型自身思维并无太大帮忙,所以这里我会省去推导的过程而间接给出论断。然而我也会给出推导过程的出处,对其中的推导过程比拟感兴趣的请自行查看。
2.1 计算原理
扩散模型简略的讲就是通过神经网络学习从纯噪声数据逐步对数据进行去噪的过程,它蕴含两个步骤,如图 3:
图 3:DDPM 的前向加噪和后向去噪过程
2.1.1 前向过程
2.1.2 后向过程
2.1.3 指标函数
那么问题来了,咱们到底应用什么样的优化指标能力比拟好的预测高斯噪声的散布呢?一个比较复杂的形式是应用变分自编码器的最大化证据下界(Evidence Lower Bound, ELBO)的思维来推导,如式 (6),推导具体过程见论文[11] 的式 (47) 到式(58),这里次要用到了贝叶斯定理和琴生不等式。
式 (6) 的推导细节并不重要,咱们须要重点关注的是它的最终等式的三个组成部分,上面咱们别离介绍它们:
图 4:扩散模型的去噪匹配项在每一步都要拟合乐音的实在后验散布和预计散布
实在后验散布能够应用贝叶斯定理进行推导,最终后果如式 (8),推导过程见论文[11] 的式 (71) 到式(84)。
\(p{\boldsymbol{\theta}}\left(\boldsymbol{x}{t-1} \mid \boldsymbol{x}t\right) = \mathcal N(\boldsymbol x{t-1}; \mu\theta(\boldsymbol x\_t, t), \Sigma\_q(t)) \tag9\)
2.1.4 模型训练
尽管下面咱们介绍了很多内容,并给出了大量公式,但得益于推导出的几个重要性质,扩散模型的训练并不简单,它的训练伪代码见算法 1。
2.1.5 样本生成
2.2 算法实现
2.2.1 模型构造
DDPM 在预测施加的噪声时,它的输出是施加噪声之后的图像,预测内容是和输出图像雷同尺寸的噪声,所以它能够看做一个 Img2Img 的工作。DDPM 抉择了 U -Net[9]作为噪声预测的模型构造。U-Net 是一个 U 形的网络结构,它由编码器,解码器以及编码器和解码器之间的跨层连贯(残差连贯)组成。其中编码器将图像降采样成一个特色,解码器将这个特色上采样为指标噪声,跨层连贯用于拼接编码器和解码器之间的特色。
图 5:U-Net 的网络结构
上面咱们介绍 DDPM 的模型构造的重要组件。首先在 U -Net 的卷积局部,DDPM 应用了宽残差网络(Wide Residual Network,WRN)[12]作为外围构造,WRN 是一个比规范残差网络层数更少,然而通道数更多的网络结构。也有作者复现发现 ConvNeXT 作为根底构造会获得十分显著的成果晋升 [13,14]。这里咱们能够依据训练资源灵便的调整卷积构造以及具体的层数等超参。因为咱们在扩散过程的整个流程中都共享同一套参数,为了辨别不同的工夫片,作者借鉴了 Transformer [15] 的地位编码的思维,采纳了正弦地位嵌入对工夫 $t$ 进行了编码,这使得模型在预测噪声时晓得它预测的是批次中别离是哪个工夫片增加的噪声。在卷积层之间,DDPM 增加了一个注意力层。这里咱们能够应用 Transformer 中提出的自注意力机制或是多头自注意力机制。[13]则提出了一个线性注意力机制的模块,它的特点是耗费的工夫以及占用的内存和序列长度是线性相关的,比照传统注意力机制的平方相干要高效很多。在进行归一化时,DDPM 抉择了组归一化(Group Normalization,GN)[16]。最初,对于 U -Net 中的降采样和上采样操作,DDPM 别离抉择了步长为 2 的卷积以及反卷积。
确定了这些组件,咱们便能够搭建用于 DDPM 的 U -Net 的模型了。从第 2.1 节的介绍咱们晓得,模型的输出为形态为 (batch\_size, num\_channels, height, width) 的噪声图像和形态为 (batch\_size,1) 的噪声程度,返回的是形态为 (batch\_size, num_channels, height, width) 的预测噪声,咱们搭建的用于噪声预测的模型构造如下:
- 首先在噪声图像 \(\boldsymbol x_0\)上利用卷积层,并为噪声程度 $t$ 计算工夫嵌入;
- 接下来是降采样阶段。采纳的模型构造顺次是两个卷积(WRNS 或是 ConvNeXT)+GN+Attention+ 降采样层;
- 在网络的最两头,顺次是卷积层 +Attention+ 卷积层;
- 接下来是上采样阶段。它首先会应用 Short-cut 拼接来自降采样中同样尺寸的卷积,再之后是两个卷积 +GN+Attention+ 上采样层。
- 最初是应用 WRNS 或是 ConvNeXT 作为输入层的卷积。
U-Net 类的 forword 函数如上面代码片段所示,残缺的实现代码参照[3]。
def forward(self, x, time):
x = self.init_conv(x)
t = self.time_mlp(time) if exists(self.time_mlp) else None
h = []
# downsample
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
x = block2(x, t)
x = attn(x)
h.append(x)
x = downsample(x)
# bottleneck
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
# upsample
for block1, block2, attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = block1(x, t)
x = block2(x, t)
x = attn(x)
x = upsample(x)
return self.final_conv(x)
2.2.2 前向加噪
图 6:一张图顺次通过 0 次,50 次,100 次,150 次以及 199 次加噪后的效果图
依据式 (14) 咱们晓得,扩散模型的损失函数计算的是两张图像的相似性,因而咱们能够抉择应用回归算法的所有损失函数,以 MSE 为例,前向过程的外围代码如上面代码片段。
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
# 1. 依据时刻 t 计算随机噪声散布,并对图像 x_start 进行加噪
x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
# 2. 依据噪声图像以及时刻 t,预测增加的噪声
predicted_noise = denoise_model(x_noisy, t)
# 3. 比照增加的噪声和预测的噪声的相似性
loss = F.mse_loss(noise, predicted_noise)
return loss
2.2.3 样本生成
依据 2.1.5 节介绍的样本生成流程,它的外围代码片段所示,对于这段代码的解说我通过正文增加到了代码片段中。
@torch.no_grad()
def p_sample(model, x, t, t_index):
betas_t = extract(betas, t, x.shape)
sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
# 应用式 (13) 计算模型的均值
model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
if t_index == 0:
return model_mean
else:
# 获取保留的方差
posterior_variance_t = extract(posterior_variance, t, x.shape)
noise = torch.randn_like(x)
# 算法 2 的第 4 行
return model_mean + torch.sqrt(posterior_variance_t) * noise
# 算法 2 的流程,然而咱们保留了所有两头样本
@torch.no_grad()
def p_sample_loop(model, shape):
device = next(model.parameters()).device
b = shape[0]
# start from pure noise (for each example in the batch)
img = torch.randn(shape, device=device)
imgs = []
for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
imgs.append(img.cpu().numpy())
return imgs
最初咱们看下在人脸图像数据集下训练的模型,一批随机噪声通过逐步去噪变成人脸图像的示例。
图 7:扩散模型由随机噪声通过去噪逐步生成人脸图像
3. 总结
这里咱们以 DDPM 为例介绍了另一个派别的生成算法:扩散模型。扩散模型是一个基于马尔可夫链的数学模型,它通过预测每个工夫片增加的噪声来进行模型的训练。作为近日来引发热烈探讨的 ControlNet,Stable Diffusion 等模型的底层算法,咱们非常有必要对其有所理解。DDPM 的实现并不简单,这得益于大量数学界大佬通过大量的数学推导将整个扩散过程和反向去噪过程进行了精彩的化简,这才有了 DDPM 的大道至简的实现。DDPM 作为一个扩散模型的基石算法,它有着很多晚期算法的独特问题:
- 采样速度慢:DDPM 的去噪是从时刻 $T$ 到时刻 $1$ 的一个残缺的马尔可夫链的计算,尤其是 DDPM 还须要一个比拟大的 $T$ 能力保障比拟好的成果,这就导致了 DDPM 的采样过程注定是十分慢的;
- 生成成果差:DDPM 的成果并不能说是十分好,尤其是对于高分辨率图像的生成。这一方面是因为它的计算速度限制了它扩大到更大的模型;另一方面它的设计还有一些问题,例如逐像素的计算损失并应用雷同权值而疏忽图像中的主体并不是十分好的策略。
- 内容不可控:咱们能够看出,DDPM 生成的内容齐全还是取决于它的训练集。它并没有引入一些先验条件,因而并不能通过管制图像中的细节来生成咱们制订的内容。
咱们当初曾经晓得,DDPM 的这些问题已大幅失去改善,当初基于扩散模型生成的图像曾经达到甚至超过人类少数的画师的成果,我也会在之后逐步给出这些优化计划的解说。
Reference
[1] Sohl-Dickstein, Jascha, et al. “Deep unsupervised learning using nonequilibrium thermodynamics.” _International Conference on Machine Learning_. PMLR, 2015.
[2] Ho, Jonathan, Ajay Jain, and Pieter Abbeel. “Denoising diffusion probabilistic models.” Advances in Neural Information Processing Systems 33 (2020): 6840-6851.
[3] https://huggingface.co/blog/annotated-diffusion
[4] https://lilianweng.github.io/posts/2021-07-11-diffusion-models/#simplification
[5] https://openai.com/blog/generative-models/
[6] Nichol, Alexander Quinn, and Prafulla Dhariwal. “Improved denoising diffusion probabilistic models.” _International Conference on Machine Learning_. PMLR, 2021.
[7] Kingma, Diederik P., and Max Welling. “Auto-encoding variational bayes.” arXiv preprint arXiv:1312.6114 (2013).
[8] Hinton, Geoffrey E., and Ruslan R. Salakhutdinov. “Reducing the dimensionality of data with neural networks.” science 313.5786 (2006): 504-507.
[9] Ronneberger O, Fischer P, Brox T. U-net: Convolutional networks for biomedical image segmentation[C]//International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015: 234-241.
[10] Long, Jonathan, Evan Shelhamer, and Trevor Darrell. “Fully convolutional networks for semantic segmentation.” _Proceedings of the IEEE conference on computer vision and pattern recognition_. 2015.
[11] Luo, Calvin. “Understanding diffusion models: A unified perspective.” arXiv preprint arXiv:2208.11970 (2022).
[12] Zagoruyko, Sergey, and Nikos Komodakis. “Wide residual networks.” arXiv preprint arXiv:1605.07146 (2016).
[13] https://github.com/lucidrains/denoising-diffusion-pytorch
[14] Liu, Zhuang, et al. “A convnet for the 2020s.” _Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition_. 2022.
[15] Vaswani, Ashish, et al. “Attention is all you need.” Advances in neural information processing systems 30 (2017).
[16] Wu, Yuxin, and Kaiming He. “Group normalization.” _Proceedings of the European conference on computer vision (ECCV)_. 2018.