关于人工智能:扩散模型课程第一单元第二部分扩散模型从零到一

77次阅读

共计 29173 个字符,预计需要花费 73 分钟才能阅读完成。

前言

于 11 月底正式开课的扩散模型课程正在炽热进行中,在中国社区成员们的帮忙下,咱们组织了「抱抱脸中文本地化志愿者小组」并实现了扩散模型课程的中文翻译,感激 @darcula1993、@XhrLeokk、@hoi2022、@SuSung-boy 对课程的翻译!

如果你还没有开始课程的学习,咱们倡议你从 第一单元:扩散模型简介 开始。

扩散模型从零到一

这个 Notebook 咱们将展现雷同的步骤(向数据增加噪声、创立模型、训练和采样),并尽可能简略地在 PyTorch 中从头开始实现。而后,咱们将这个「玩具示例」与 diffusers 版本进行比拟,并关注两者的区别以及改良之处。这里的指标是相熟不同的组件和其中的设计决策,以便在查看新的实现时可能疾速确定要害思维。

让咱们开始吧!

有时,只思考一些事务最简略的状况会有助于更好地了解其工作原理。咱们将在本笔记本中尝试这一点,从“玩具”扩散模型开始,看看不同的局部是如何工作的,而后再查看它们与更简单的实现有何不同。

你将追随本文的 Notebook 学习到

  • 损坏过程(向数据增加噪声)
  • 什么是 UNet,以及如何从零开始实现一个极小的 UNet
  • 扩散模型训练
  • 抽样实践

而后,咱们将比拟咱们的版本与 diffusers 库中的 DDPM 实现的区别

  • 对小型 UNet 的改良
  • DDPM 噪声打算
  • 训练指标的差别
  • timestep 调节
  • 抽样办法

这个笔记本相当深刻,如果你对从零开始的深入研究不感兴趣,能够释怀地跳过!

还值得注意的是,这里的大多数代码都是出于阐明的目标,我不倡议间接将其用于您本人的工作(除非您只是为了学习目标而尝试改良这里展现的示例)。

筹备环境与导入:

!pip install -q diffusers
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

数据

在这里,咱们将应用一个十分小的经典数据集 mnist 来进行测试。如果您想在不扭转任何其余内容的状况下给模型一个略微艰难一点的挑战,请应用 torchvision.dataset,FashionMNIST 应作为替代品。

dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');

该数据集中的每张图都是一个数字的 28×28 像素的灰度图,像素值的范畴是从 0 到 1。

损坏过程

假如你没有读过任何扩散模型的论文,但你晓得这个过程会减少噪声。你会怎么做?

咱们可能想要一个简略的办法来管制损坏的水平。那么,如果咱们要引入一个参数来管制输出的“噪声量”,那么咱们会这么做:

noise = torch.rand_like(x)

noisy_x = (1-amount)*x + amount*noise

如果 amount = 0,则返回输出而不做任何更改。如果 amount = 1,咱们将失去一个纯正的噪声。通过这种形式将输出与噪声混合,咱们将输入放弃在雷同的范畴(0 to 1)。

咱们能够很容易地实现这一点(然而要留神 tensor 的 shape,以防被播送 (broadcasting) 机制不正确的影响到):

def corrupt(x, amount):
  """Corrupt the input `x` by mixing it with noise according to `amount`"""
  noise = torch.rand_like(x)
  amount = amount.view(-1, 1, 1, 1) # Sort shape so broadcasting works
  return x*(1-amount) + noise*amount 

让咱们来可视化一下输入的后果,以理解是否合乎咱们的预期:

# Plotting the input data
fig, axs = plt.subplots(2, 1, figsize=(12, 5))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')

# Adding noise
amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
noised_x = corrupt(x, amount)

# Plottinf the noised version
axs[1].set_title('Corrupted data (-- amount increases -->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys');

当噪声量靠近 1 时,咱们的数据开始看起来像纯随机噪声。但对于大多数的噪声状况下,您还是能够很好地辨认出数字。你认为这是最佳的吗?

模型

咱们想要一个模型,它能够接管 28px 的噪声图像,并输入雷同形态的预测。一个比拟风行的抉择是一个叫做 UNet 的架构。最后被创造用于医学图像中的宰割工作,UNet 由一个“压缩门路”和一个“扩大门路”组成。“压缩门路”会使通过该门路的数据被压缩,而通过“扩大门路”会将数据扩大回原始维度(相似于主动编码器)。模型中的残差连贯也容许信息和梯度在不同层级之间流动。

一些 UNet 的设计在每个阶段都有简单的 blocks,但对于这个玩具 demo,咱们只会构建一个最简略的示例,它接管一个单通道图像,并通过上行门路上的三个卷积层(图和代码中的 down\_layers)和上行门路上的 3 个卷积层,在上行和上行层之间具备残差连贯。咱们将应用 max pooling 进行下采样和 nn.Upsample 用于上采样。某些比较复杂的 UNets 的设计会应用带有可学习参数的上采样和下采样 layer。上面的结构图大抵展现了每个 layer 的输入通道数:

代码实现如下:

class BasicUNet(nn.Module):
  """A minimal UNet implementation."""
  def __init__(self, in_channels=1, out_channels=1):
    super().__init__()
    self.down_layers = torch.nn.ModuleList([nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
      nn.Conv2d(32, 64, kernel_size=5, padding=2),
      nn.Conv2d(64, 64, kernel_size=5, padding=2),
    ])
    self.up_layers = torch.nn.ModuleList([nn.Conv2d(64, 64, kernel_size=5, padding=2),
      nn.Conv2d(64, 32, kernel_size=5, padding=2),
      nn.Conv2d(32, out_channels, kernel_size=5, padding=2), 
    ])
    self.act = nn.SiLU() # The activation function
    self.downscale = nn.MaxPool2d(2)
    self.upscale = nn.Upsample(scale_factor=2)

  def forward(self, x):
    h = []
    for i, l in enumerate(self.down_layers):
      x = self.act(l(x)) # Through the layer n the activation function
      if i < 2: # For all but the third (final) down layer:
        h.append(x) # Storing output for skip connection
        x = self.downscale(x) # Downscale ready for the next layer
              
    for i, l in enumerate(self.up_layers):
      if i > 0: # For all except the first up layer
        x = self.upscale(x) # Upscale
        x += h.pop() # Fetching stored output (skip connection)
        x = self.act(l(x)) # Through the layer n the activation function
            
    return x

咱们能够验证输入 shape 是否如咱们冀望的那样与输出雷同:

net = BasicUNet()
x = torch.rand(8, 1, 28, 28)
net(x).shape
torch.Size([8, 1, 28, 28])

该网络有 30 多万个参数:

sum([p.numel() for p in net.parameters()])
309057

您能够尝试更改每个 layer 中的通道数或尝试不同的结构设计。

训练模型

那么,模型到底应该做什么呢?同样,对这个问题有各种不同的认识,但对于这个演示,让咱们抉择一个简略的框架:给定一个损坏的输出 noisy_x,模型应该输入它对本来 x 的最佳猜想。咱们将通过均方误差将预测与实在值进行比拟。

咱们当初能够尝试训练网络了。

  • 获取一批数据
  • 增加随机噪声
  • 将数据输出模型
  • 将模型预测与洁净图像进行比拟,以计算 loss
  • 更新模型的参数

你能够自在进行批改来尝试取得更好的后果!

# Dataloader (you can mess with batch size)
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# How many runs through the data should we do?
n_epochs = 3

# Create the network
net = BasicUNet()
net.to(device)

# Our loss finction
loss_fn = nn.MSELoss()

# The optimizer
opt = torch.optim.Adam(net.parameters(), lr=1e-3) 

# Keeping a record of the losses for later viewing
losses = []

# The training loop
for epoch in range(n_epochs):

  for x, y in train_dataloader:
    # Get some data and prepare the corrupted version
    x = x.to(device) # Data on the GPU
    noise_amount = torch.rand(x.shape[0]).to(device) # Pick random noise amounts
    noisy_x = corrupt(x, noise_amount) # Create our noisy x

    # Get the model prediction
    pred = net(noisy_x)

    # Calculate the loss
    loss = loss_fn(pred, x) # How close is the output to the true 'clean' x?

    # Backprop and update the params:
    opt.zero_grad()
    loss.backward()
    opt.step()

    # Store the loss for later
    losses.append(loss.item())

    # Print our the average of the loss values for this epoch:
    avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')

# View the loss curve
plt.plot(losses)
plt.ylim(0, 0.1);
Finished epoch 0. Average loss for this epoch: 0.026736
Finished epoch 1. Average loss for this epoch: 0.020692
Finished epoch 2. Average loss for this epoch: 0.018887

咱们能够尝试通过抓取一批数据,以不同的数量损坏数据,而后喂进模型取得预测来察看后果:

#@markdown Visualizing model predictions on noisy inputs:

# Fetch some data
x, y = next(iter(train_dataloader))
x = x[:8] # Only using the first 8 for easy plotting

# Corrupt with a range of amounts
amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
noised_x = corrupt(x, amount)

# Get the model predictions
with torch.no_grad():
  preds = net(noised_x.to(device)).detach().cpu()

# Plot
fig, axs = plt.subplots(3, 1, figsize=(12, 7))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='Greys')
axs[1].set_title('Corrupted data')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap='Greys')
axs[2].set_title('Network Predictions')
axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap='Greys');

你能够看到,对于较低的噪声程度数量,预测的后果相当不错!然而,当噪声程度十分高时,模型可能取得的信息就开始逐步缩小。而当咱们达到 amount = 1 时,模型会输入一个含糊的预测,该预测会很靠近数据集的平均值。模型通过这样的形式来猜想原始输出。

取样(采样)

如果咱们在高噪声程度下的预测不是很好,咱们如何能力生成图像呢?

如果咱们从齐全随机的噪声开始,检查一下模型预测的后果,而后只朝着预测方向挪动一小部分,比如说 20%。当初咱们有一个噪声很多的图像,其中可能暗藏了一些对于输出数据的构造的提醒,咱们能够将其输出到模型中以取得新的预测。心愿这个新的预测比第一个略微好一点(因为咱们这一次的输出略微缩小了一点噪声),所以咱们能够用这个新的更好的预测再往前迈出一小步。

如果一切顺利的话,以上过程反复几次当前咱们就会失去一个新的图像!以下图例是迭代了五次当前的后果,左侧是每个阶段的模型输出的可视化,右侧则是预测的去噪图像。请留神,即便模型在第 1 步就预测了去噪图像,咱们也只是将输出向去噪图像变换了一小部分。反复几次当前,图像的构造开始逐步呈现并失去改善 , 直到取得咱们的最终后果为止。

#@markdown Sampling strategy: Break the process into 5 steps and move 1/5'th of the way there each time:
n_steps = 5
x = torch.rand(8, 1, 28, 28).to(device) # Start from random
step_history = [x.detach().cpu()]
pred_output_history = []

for i in range(n_steps):
  with torch.no_grad(): # No need to track gradients during inference
    pred = net(x) # Predict the denoised x0
  pred_output_history.append(pred.detach().cpu()) # Store model output for plotting
  mix_factor = 1/(n_steps - i) # How much we move towards the prediction
  x = x*(1-mix_factor) + pred*mix_factor # Move part of the way there
  step_history.append(x.detach().cpu()) # Store step for plotting

fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)
axs[0,0].set_title('x (model input)')
axs[0,1].set_title('model prediction')
for i in range(n_steps):
  axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap='Greys')
  axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0, 1), cmap='Greys')

咱们能够将流程分成更多步骤,并心愿通过这种形式取得更好的图像:

#@markdown Showing more results, using 40 sampling steps
n_steps = 40
x = torch.rand(64, 1, 28, 28).to(device)
for i in range(n_steps):
  noise_amount = torch.ones((x.shape[0], )).to(device) * (1-(i/n_steps)) # Starting high going low
  with torch.no_grad():
    pred = net(x)
  mix_factor = 1/(n_steps - i)
  x = x*(1-mix_factor) + pred*mix_factor
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap='Greys')
<matplotlib.image.AxesImage at 0x7f27567d8210>

后果并不是十分好,然而曾经呈现了一些能够被认出来的数字!您能够尝试训练更长时间(例如,10 或 20 个 epoch),并调整模型配置、学习率、优化器等。此外,如果您想尝试略微艰难一点的数据集,您能够尝试一下 fashionMNIST,只须要一行代码的替换就能够了。

与 DDPM 做比拟

在本节中,咱们将看看咱们的“玩具”实现与其余笔记本中应用的基于 DDPM 论文的办法有何不同: 扩散器简介 Notebook。

咱们将会看到的

  • 模型的体现受限于随迭代周期 (timesteps) 变动的管制条件,在前向传导中工夫步 (t) 是作为一个参数被传入的
  • 有很多不同的取样策略可抉择,可能会比咱们下面所应用的最简略的版本更好
  • diffusers UNet2DModel 比咱们的 BasicUNet 更先进
  • 损坏过程的解决形式不同
  • 训练指标不同,包含预测噪声而不是去噪图像
  • 该模型通过调节 timestep 来调节噪声程度 , 其中 t 作为一个附加参数传入前向过程中。
  • 有许多不同的采样策略可供选择,它们应该比咱们下面简略的版本更无效。

自 DDPM 论文发表以来,曾经有人提出了许多改良倡议,但这个例子对于不同的可用设计决策具备指导意义。读完这篇文章后,你可能会想要深刻理解这篇论文《Elucidating the Design Space of Diffusion-Based Generative Models》,它对所有这些组件进行了具体的探讨,并就如何获得最佳性能提出了新的倡议。

如果你感觉这些内容对你来说太过深奥了,请不要放心!你能够随便跳过本笔记本的其余部分或将其保留以备不时之需。

UNet

diffusers 中的 UNet2DModel 模型比上述根本 UNet 模型有许多改良:

  • GroupNorm 层对每个 blocks 的输出进行了组标准化(group normalization)
  • Dropout 层能使训练更平滑
  • 每个块有多个 resnet 层(如果 layers\_per\_block 未设置为 1)
  • 留神机制(通常仅用于输出分辨率较低的 blocks)
  • timestep 的调节。
  • 具备可学习参数的下采样和上采样块

让咱们来创立并认真钻研一下 UNet2DModel:

model = UNet2DModel(
  sample_size=28,           # the target image resolution
  in_channels=1,            # the number of input channels, 3 for RGB images
  out_channels=1,           # the number of output channels
  layers_per_block=2,       # how many ResNet layers to use per UNet block
  block_out_channels=(32, 64, 64), # Roughly matching our basic unet example
  down_block_types=( 
    "DownBlock2D",        # a regular ResNet downsampling block
    "AttnDownBlock2D",    # a ResNet downsampling block w/ spatial self-attention
    "AttnDownBlock2D",
  ), 
  up_block_types=(
    "AttnUpBlock2D", 
    "AttnUpBlock2D",      # a ResNet upsampling block with spatial self-attention
    "UpBlock2D",          # a regular ResNet upsampling block
  ),
)
print(model)
UNet2DModel((conv_in): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding((linear_1): Linear(in_features=32, out_features=128, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=128, out_features=128, bias=True)
  )
  (down_blocks): ModuleList((0): DownBlock2D((resnets): ModuleList((0): ResnetBlock2D((norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
          (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
          (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU())
        (1): ResnetBlock2D((norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
          (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
          (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU())
      )
      (downsamplers): ModuleList((0): Downsample2D((conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
      )
    )
    (1): AttnDownBlock2D((attentions): ModuleList((0): AttentionBlock((group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
        (1): AttentionBlock((group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
      )
      (resnets): ModuleList((0): ResnetBlock2D((norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
          (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): ResnetBlock2D((norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU())
      )
      (downsamplers): ModuleList((0): Downsample2D((conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
      )
    )
    (2): AttnDownBlock2D((attentions): ModuleList((0): AttentionBlock((group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
        (1): AttentionBlock((group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
      )
      (resnets): ModuleList((0): ResnetBlock2D((norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU())
        (1): ResnetBlock2D((norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU())
      )
    )
  )
  (up_blocks): ModuleList((0): AttnUpBlock2D((attentions): ModuleList((0): AttentionBlock((group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
        (1): AttentionBlock((group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
        (2): AttentionBlock((group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
      )
      (resnets): ModuleList((0): ResnetBlock2D((norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
          (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): ResnetBlock2D((norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
          (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
        )
        (2): ResnetBlock2D((norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
          (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (upsamplers): ModuleList((0): Upsample2D((conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
    )
    (1): AttnUpBlock2D((attentions): ModuleList((0): AttentionBlock((group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
        (1): AttentionBlock((group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
        (2): AttentionBlock((group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (proj_attn): Linear(in_features=64, out_features=64, bias=True)
        )
      )
      (resnets): ModuleList((0): ResnetBlock2D((norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
          (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): ResnetBlock2D((norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
          (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
        )
        (2): ResnetBlock2D((norm1): GroupNorm(32, 96, eps=1e-05, affine=True)
          (conv1): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
          (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (upsamplers): ModuleList((0): Upsample2D((conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
    )
    (2): UpBlock2D((resnets): ModuleList((0): ResnetBlock2D((norm1): GroupNorm(32, 96, eps=1e-05, affine=True)
          (conv1): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
          (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): ResnetBlock2D((norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
          (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
        )
        (2): ResnetBlock2D((norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
          (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
          (conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
  )
  (mid_block): UNetMidBlock2D((attentions): ModuleList((0): AttentionBlock((group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (key): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (proj_attn): Linear(in_features=64, out_features=64, bias=True)
      )
    )
    (resnets): ModuleList((0): ResnetBlock2D((norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
        (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (nonlinearity): SiLU())
      (1): ResnetBlock2D((norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
        (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (nonlinearity): SiLU())
    )
  )
  (conv_norm_out): GroupNorm(32, 32, eps=1e-05, affine=True)
  (conv_act): SiLU()
  (conv_out): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

正如你所看到的,还有更多!它比咱们的 BasicUNet 有多得多的参数量:

sum([p.numel() for p in model.parameters()]) # 1.7M vs the ~309k parameters of the BasicUNet
1707009

咱们能够用这个模型代替原来的模型来反复一遍下面展现的训练过程。咱们须要将 x 和 timestep 传递给模型(这里我会传递 t = 0,以表明它在没有 timestep 条件的状况下工作,并放弃采样代码简略,但您也能够尝试输出 (amount*1000),使 timestep 与噪声程度相当)。如果要查看代码,更改的即将显示为“#<<<

#@markdown Trying UNet2DModel instead of BasicUNet:

# Dataloader (you can mess with batch size)
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# How many runs through the data should we do?
n_epochs = 3

# Create the network
net = UNet2DModel(
  sample_size=28,  # the target image resolution
  in_channels=1,  # the number of input channels, 3 for RGB images
  out_channels=1,  # the number of output channels
  layers_per_block=2,  # how many ResNet layers to use per UNet block
  block_out_channels=(32, 64, 64),  # Roughly matching our basic unet example
  down_block_types=( 
    "DownBlock2D",  # a regular ResNet downsampling block
    "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
    "AttnDownBlock2D",
  ), 
  up_block_types=(
    "AttnUpBlock2D", 
    "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
    "UpBlock2D",   # a regular ResNet upsampling block
  ),
) #<<<
net.to(device)

# Our loss finction
loss_fn = nn.MSELoss()

# The optimizer
opt = torch.optim.Adam(net.parameters(), lr=1e-3) 

# Keeping a record of the losses for later viewing
losses = []

# The training loop
for epoch in range(n_epochs):

  for x, y in train_dataloader:

    # Get some data and prepare the corrupted version
    x = x.to(device) # Data on the GPU
    noise_amount = torch.rand(x.shape[0]).to(device) # Pick random noise amounts
    noisy_x = corrupt(x, noise_amount) # Create our noisy x

    # Get the model prediction
    pred = net(noisy_x, 0).sample #<<< Using timestep 0 always, adding .sample

    # Calculate the loss
    loss = loss_fn(pred, x) # How close is the output to the true 'clean' x?

    # Backprop and update the params:
    opt.zero_grad()
    loss.backward()
    opt.step()

    # Store the loss for later
    losses.append(loss.item())

    # Print our the average of the loss values for this epoch:
    avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')

# Plot losses and some samples
fig, axs = plt.subplots(1, 2, figsize=(12, 5))

# Losses
axs[0].plot(losses)
axs[0].set_ylim(0, 0.1)
axs[0].set_title('Loss over time')

# Samples
n_steps = 40
x = torch.rand(64, 1, 28, 28).to(device)
for i in range(n_steps):
  noise_amount = torch.ones((x.shape[0], )).to(device) * (1-(i/n_steps)) # Starting high going low
  with torch.no_grad():
    pred = net(x, 0).sample
  mix_factor = 1/(n_steps - i)
  x = x*(1-mix_factor) + pred*mix_factor

axs[1].imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap='Greys')
axs[1].set_title('Generated Samples');
Finished epoch 0. Average loss for this epoch: 0.018925
Finished epoch 1. Average loss for this epoch: 0.012785
Finished epoch 2. Average loss for this epoch: 0.011694

这看起来比咱们的第一组后果好多了!您能够尝试调整 UNet 配置或更长时间的训练,以取得更好的性能。

损坏过程

DDPM 论文形容了一个为每个“timestep”增加大量噪声的损坏过程。为某些 timestep 给定 , 咱们能够失去一个噪声稍稍减少的 :

这就是说,咱们取 , 给他一个 的系数,而后加上带有 系数的噪声。这里 是依据一些管理器来为每一个 t 设定的,来决定每一个迭代周期中增加多少噪声。当初,咱们不想把这个推演进行 500 次来失去,所以咱们用另一个公式来依据给出的 计算失去任意 t 时刻的 :

数学符号看起来总是很吓人!侥幸的是,调度器为咱们解决了所有这些(勾销下一个单元格的正文以查看代码)。咱们能够画出 (标记为 sqrt_alpha_prod) 和 (标记为 sqrt_one_minus_alpha_prod) 来看一下输出 (x) 与噪声是如何在不同迭代周期中量化和叠加的 :

#??noise_scheduler.add_noise
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${\sqrt{\bar{\alpha}_t}}$")
plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5, label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")
plt.legend(fontsize="x-large");

一开始 , 噪声 x 里绝大部分都是 x 本身的值  (sqrt\_alpha\_prod ~= 1),然而随着工夫的推移,x 的成分逐步升高而噪声的成分逐步减少。与咱们依据 amount 对 x 和噪声进行线性混合不同,这个噪声的减少绝对较快。咱们能够在一些数据上看到这一点:

#@markdown visualize the DDPM noising process for different timesteps:

# Noise a batch of images to view the effect
fig, axs = plt.subplots(3, 1, figsize=(16, 10))
xb, yb = next(iter(train_dataloader))
xb = xb.to(device)[:8]
xb = xb * 2. - 1. # Map to (-1, 1)
print('X shape', xb.shape)

# Show clean inputs
axs[0].imshow(torchvision.utils.make_grid(xb[:8])[0].detach().cpu(), cmap='Greys')
axs[0].set_title('Clean X')

# Add noise with scheduler
timesteps = torch.linspace(0, 999, 8).long().to(device)
noise = torch.randn_like(xb) # << NB: randn not rand
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
print('Noisy X shape', noisy_xb.shape)

# Show noisy version (with and without clipping)
axs[1].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu().clip(-1, 1),  cmap='Greys')
axs[1].set_title('Noisy X (clipped to (-1, 1)')
axs[2].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu(),  cmap='Greys')
axs[2].set_title('Noisy X');
X shape torch.Size([8, 1, 28, 28])
Noisy X shape torch.Size([8, 1, 28, 28])

在运行中的另一个变动:在 DDPM 版本中,退出的噪声是取自一个高斯分布(来自均值 0 方差 1 的 torch.randn),而不是在咱们原始 corrupt 函数中应用的 0-1 之间的均匀分布(torch.rand),当然对训练数据做正则化也能够了解。在另一篇笔记中,你会看到 Normalize(0.5, 0.5) 函数在变动列表中,它把图片数据从 (0, 1) 区间映射到 (-1, 1),对咱们的指标来说也‘足够用了’。咱们在此篇笔记中没应用这个办法,但在下面的可视化中为了更好的展现增加了这种做法。

训练指标

在咱们的玩具示例中,咱们让模型尝试预测去噪图像。在 DDPM 和许多其余扩散模型实现中,模型则会预测损坏过程中应用的噪声(在缩放之前,因而是单位方差噪声)。在代码中,它看起来像是这样:

noise = torch.randn_like(xb) # << NB: randn not rand
noisy_x = noise_scheduler.add_noise(x, noise, timesteps)
model_prediction = model(noisy_x, timesteps).sample
loss = mse_loss(model_prediction, noise) # noise as the target

你可能认为预测噪声(咱们能够从中得进来噪图像的样子)等同于间接预测去噪图像。那么,为什么要这么做呢?这仅仅是为了数学上的不便吗?

这里其实还有另一些精妙之处。咱们在训练过程中,会计算不同(随机抉择)timestep 的 loss。这些不同的指标将导致这些 loss 的不同的“隐含权重”,其中预测噪声会将更多的权重放在较低的噪声程度上。你能够抉择更简单的指标来扭转这种“隐性损失权重”。或者,您抉择的噪声管理器将在较高的噪声程度下产生更多的示例。兴许你让模型设计成预测“velocity”v,咱们将其定义为由噪声程度影响的图像和噪声组合(请参阅“扩散模型疾速采样的渐进蒸馏”- ‘PROGRESSIVE DISTILLATION FOR FAST SAMPLING OF DIFFUSION MODELS’)。兴许你将模型设计成预测噪声,而后基于某些因子来对 loss 进行缩放:比方有些实践指出能够参考噪声程度(参见“扩散模型的感知优先训练”-‘Perception Prioritized Training of Diffusion Models’),或者基于一些摸索模型最佳噪声程度的试验(参见“基于扩散的生成模型的设计空间阐明”-‘Elucidating the Design Space of Diffusion-Based Generative Models’)。

一句话解释:抉择指标对模型性能有影响,当初有许多研究者正在摸索“最佳”选项是什么。目前,预测噪声(epsilon 或 eps)是最风行的办法,但随着工夫的推移,咱们很可能会看到库中反对的其余指标,并在不同的状况下应用。

迭代周期(Timestep)调节

UNet2DModel 以 x 和 timestep 为输出。后者被转化为一个嵌入(embedding),并在多个中央被输出到模型中。

这背地的实践反对是这样的:通过向模型提供无关噪声程度的信息,它能够更好地执行工作。尽管在没有这种 timestep 条件的状况下也能够训练模型,但在某些状况下,它仿佛的确有助于性能,目前来说绝大多数的模型实现都包含了这一输出。

取样(采样)

有一个模型能够用来预测在带噪样本中的噪声(或者说能预测其去噪版本),咱们怎么用它来生成图像呢?

咱们能够给入纯噪声,而后就心愿模型能一步就输入一个不带噪声的好图像。然而,就咱们下面所见到的来看,这通常行不通。所以,咱们在模型预测的根底上应用足够多的小步,迭代着来每次去除一点点噪声。

具体咱们怎么走这些小步,取决于应用下面取样办法。咱们不会去深刻探讨太多的实践细节,然而一些顶层想法是这样:

  • 每一步你想走多大?也就是说,你遵循什么样的“噪声打算(噪声治理)”?
  • 你只应用模型以后步的预测后果来领导下一步的更新方向吗(像 DDPM,DDIM 或是其余的什么那样)?你是否要应用模型来多预测几次来预计一个更高阶的梯度来更新一步更大更精确的后果(更高阶的办法和一些离散 ODE 处理器)?或者保留历史预测值来尝试更好的领导以后步的更新(线性多步或遗传取样器)?
  • 你是否会在取样过程中额定再加一些随机噪声,或你齐全已知的(deterministic)来增加噪声?许多取样器通过参数(如 DDIM 中的 ‘eta’)来供用户抉择。

对于扩散模型取样器的钻研演进的很快,随之开发出了越来越多能够应用更少步就找到好后果的办法。怯懦和有好奇心的人可能会在浏览 diffusers library 中不同部署办法时感到十分有意思,能够查看 Schedulers 代码 或看看 Schedulers 文档,这里常常有一些相干的论文。

结语

心愿这能够从一些不同的角度来扫视扩散模型提供一些帮忙。这篇笔记是 Jonathan Whitaker 为 Hugging Face 课程所写的,如果你对从噪声和束缚分类来生成样本的例子感兴趣。问题与 bug 能够通过 GitHub issues 或 Discord 来交换。

致谢第一单元第二局部社区贡献者

感激社区成员们对本课程的奉献:

@darcula1993、@XhrLeokk:魔都强人工智能孵化者,二里街调参记录放弃人,所有趣味使然的 AIGC 色图创作家的包庇者,图灵神在五角场的惟一指定路上行走。

感激茶叶蛋蛋对本文奉献设计素材!

欢送通过链接退出咱们的本地化小组与大家独特交换:
https://bit.ly/3G40j6U

正文完
 0