关于机器学习:技术博客生成式对抗网络模型综述

3次阅读

共计 10387 个字符,预计需要花费 26 分钟才能阅读完成。

生成式反抗网络模型综述

作者:张真源

GAN

GAN 简介

生成式反抗网络 (Generative adversarial networks,GANs) 的核心思想源自于零和博弈,包含生成器和判断器两个局部。生成器接管随机变量并生成“假”样本,判断器则用于判断输出的样本是实在的还是合成的。两者通过互相反抗来取得彼此性能的晋升。判断器所作的其实就是一个二分类工作,咱们能够计算他的损失并进行反向流传求出梯度,从而进行参数更新。

GAN 的优化指标能够写作:
$$\large\min_G\max_DV(D,G)= \mathbb{E}_{x\sim p_{data}}[\log D(x)]+\mathbb{E}_{z\sim p_z(z)}[log(1-D(G(z)))]$$
其中 $$\log D(x)$$ 代表了判断器甄别实在样本的能力,而 $$D(G(z))$$ 则代表了生成器坑骗判断器的能力。在理论的训练中,生成器和判断器采取交替训练,即先训练 D,而后训练 G,一直往返。

WGAN

在上一部分咱们给出了 GAN 的优化指标,这个指标的实质是在最小化生成样本与实在样本之间的 JS 间隔。然而在试验中发现,GAN 的训练十分的不稳固,常常会陷入坍缩模式。这是因为,在高维空间中,并不是每个点都能够示意一个样本,而是存在着大量不代表实在信息的无用空间。当两个散布没有重叠时,JS 间隔不能精确的提供两个散布之间的差别。这样的生成器,很难“捕获”到低维空间中的实在数据分布。因而,WGAN(Wasserstein GAN)的作者提出了 Wasserstein 间隔 (推土机间隔) 的概念,其公式能够进行如下示意:
$$\large W(\mathbb P_r,\mathbb P_g)=\inf_{\gamma\sim\prod{\mathbb P_r,\mathbb P_g}}\mathbb E_{(x,y)~\gamma}[\|x-y\|]$$
这里 $$\prod{\mathbb P_r,\mathbb P_g}$$ 指的是实在散布 $$\mathbb P_r$$ 和生成散布 $$\mathbb P_g$$ 的联结散布所形成的汇合,$$(x,y)$$ 是从 $$\gamma$$ 中获得的一个样本。枚举两者之间所有可能的联结散布,计算其中样本间的间隔 $$\|x-y\|$$,并取其冀望。而 Wasserstein 间隔就是两个散布样本间隔冀望的下界值。这个简略的改良,使得生成样本在任意地位下都能给生成器带来适合的梯度,从而对参数进行优化。

DCGAN

卷积神经网络近年来获得了夺目的问题,展示了其在图像处理畛域独特的劣势。很天然的会想到,如果将卷积神经网络引入 GAN 中,是否能够带来成果上的晋升呢?DCGAN(Deep Convolutional GANs)在 GAN 的根底上优化了网络结构,用齐全的卷积代替了全连贯层,去掉池化层,并采纳批标准化 (Batch Normalization,BN) 等技术,使得网络更容易训练。

用 DCGAN 生成图像

为了更不便精确的阐明 DCGAN 的关键环节,这里用一个简化版的模型实例来阐明。代码基于 pytorch 深度学习框架,数据集采纳 MNIST

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
import os  
#定义一些超参数
nc = 1        #输出图像的通道数
nz = 100          #输出噪声的维度
num_epochs = 200  #迭代次数
batch_size = 64      #批量大小
sample_dir = 'gan_samples'
# 后果的保留目录
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
# 加载 MNIST 数据集
trans = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])])
mnist = torchvision.datasets.MNIST(root=r'G:\VsCode\ml\mnist',
                                   train=True,
                                   transform=trans,
                                   download=False)
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size,  
                                          shuffle=True)

判断器 & 生成器

判断器应用 LeakyReLU 作为激活函数,最初通过 Sigmoid 输入,用于虚实二分类
生成器应用 ReLU 作为激活函数,最初通过 tanh 将输入映射在 [-1,1] 之间

# 构建判断器
class Discriminator(nn.Module):
    def __init__(self, in_channel=1, num_classes=1):
        super(Discriminator, self).__init__()
        self.conv = nn.Sequential(
            # 28 -> 14
            nn.Conv2d(nc, 512, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            # 14 -> 7
            nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            # 7 -> 4
            nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.AvgPool2d(4),
        )
        self.fc = nn.Sequential(
            # reshape input, 128 -> 1
            nn.Linear(128, 1),
            nn.Sigmoid(),)
   
    def forward(self, x, label=None):
        y_ = self.conv(x)
        y_ = y_.view(y_.size(0), -1)
        y_ = self.fc(y_)
        return y_
# 构建生成器
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc = nn.Sequential(nn.Linear(nz, 4*4*512),
            nn.ReLU(),)
        self.conv = nn.Sequential(
            # input: 4 by 4, output: 7 by 7
            nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # input: 7 by 7, output: 14 by 14
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # input: 14 by 14, output: 28 by 28
            nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1, bias=False),
            nn.Tanh(),)
       
    def forward(self, x, label=None):
        x = x.view(x.size(0), -1)
        y_ = self.fc(x)
        y_ = y_.view(y_.size(0), 512, 4, 4)
        y_ = self.conv(y_)
        return y_

训练模型

# 应用 GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
D = Discriminator().to(device)
G = Generator().to(device)
# 损失函数及优化器
criterion = nn.BCELoss()
D_opt = torch.optim.Adam(D.parameters(), lr=0.001, betas=(0.5, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr=0.001, betas=(0.5, 0.999))
def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)
def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(data_loader):
        images = images.to(device)
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        #————————————————————训练判断器——————————————————————
        #甄别实在样本
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        #甄别生成样本
        z = torch.randn(batch_size, nz).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs        
      #计算梯度及更新
        d_loss = d_loss_real + d_loss_fake      
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
        #————————————————————训练生成器——————————————————————
        z = torch.randn(batch_size, nz).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)
        #计算梯度及更新
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
       
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'  
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),  
                          real_score.mean().item(), fake_score.mean().item()))
    # 保留生成图片
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
# 保留模型
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')

可视化后果

reconsPath = './gan_samples/fake_images-200.png'
Image = mpimg.imread(reconsPath)
plt.imshow(Image)
plt.axis('off')
plt.show()

cGAN

在之前介绍的几种模型中,咱们留神到生成器的输出都是一个随机的噪声。能够认为这个高维噪声向量提供了一些要害信息,而生成器依据本人的了解将这些信息进行补充,最终生成须要的图像。生成器生成图片的过程是齐全随机的。例如上述的 MNIST 数据集,咱们不能管制它生成的是哪一个数字。那么,有没有办法能够对其做肯定的限度束缚,来让生成器生成咱们想要的后果呢?cGAN(Conditional Generative Adversarial Nets)通过增一个额定的向量 y 对生成器进行束缚。以 MNIST 分类为例,限度信息 y 能够取 10 维的向量,对于类别进行 one-hot 编码,并与噪声进行拼接一起输出生成器。同样的,判断器也将原来的输出和 y 进行拼接。作者通过各种试验证实了这个简略的改良的确能够起到对生成器的束缚作用。

判断器 & 生成器

只须要在前向流传的过程中退出限度变量 y,咱们很容易就能失去 cGAN 的生成器和判断器模型

class Discriminator(nn.Module):
    ...
    def forward(self, x, label):
        label = label.unsqueeze(2).unsqueeze(3)
        label = label.repeat(1, 1, x.size(2), x.size(3))
        x = torch.cat(tensors=(x, label), dim=1)
        y_ = self.conv(x)
        ...
class Generator(nn.Module):
    ...
    def forward(self, x, label):
        x = x.unsqueeze(2).unsqueeze(3)
        label = label.unsqueeze(2).unsqueeze(3)
        x = torch.cat(tensors=(x, label), dim=1)
        y_ = self.fc(x)
        ...

Pix2Pix

在下面的 cGAN 例子中,咱们的管制信息取的是想要图像的标签,如果这个管制信息更加的丰盛,例如输出一整张图像,那么它是否实现一些更加高级的工作?Pix2Pix(Image-to-Image Translation with Conditional Adversarial Networks)将这一类问题演绎为图像到图像的翻译,其应用改良后的 U -net 作为生成器,并设计了新鲜的 Patch- D 判断器构造来输入高清的图像。Patch- D 是指,不论网络所应用的输出图像有多大,都将其切割成若干个固定大小的 Patch,判断器只需对这些 Patch 的虚实进行判断。因为 L1 损失曾经能够掂量生成图像和实在图像的全局差别,所以作者认为判断器只须要用 Patch- D 这样更关注于部分差别的构造即可。同时 Patch- D 的构造使得网络的输出变小,缩小了计算量并且增大了框架的扩展性。

CycleGAN

Pix2Pix 尽管能够生成高清的图像,但其存在一个致命的毛病:须要互相配对的图片 x 与 y。在现实生活中,这样成对的图片很难或者基本不可能收集到,这就大大的限度了 Pix2Pix 的利用。对此,CycleGAN(Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks)提出了不须要配对的图像翻译办法。

CycleGAN 其实就是一个 X ->Y 的单向 GAN 上再加一个 Y ->X 的单向 GAN,形成一个“循环”。网络的构造和单次训练过程如下(图片来自于量子位):


除了经典的根底 GAN 损失之外,CycleGAN 还引入了 Consistency loss 的概念。循环统一损失使得 X ->Y 转变的过程中必须保留有 X 的局部个性。循环损失的公式如下:
$$\large L_{cyc}(G,F)=\mathbb E_{x\sim p_{data}(x)}[\|F(G(x))-x\|_1]+\mathbb E_{y\sim p_{data}(y)}[\|G(F(x))-y\|_1]$$
两个判断器的损失示意如下:
$$\large \textit{L}_{GAN}(G,D_Y,X,Y)=\mathbb E_{y\sim p_{data}(y)}[logD_Y(y)]+\mathbb E_{x\sim p_{data}(x)}[log(1-D_Y(G(x)))]$$$$\large \textit{L}_{GAN}(F,D_X,Y,X)=\mathbb E_{x\sim p_{data}(x)}[logD_X(x)]+\mathbb E_{y\sim p_{data}(y)}[log(1-D_X(F(y)))]$$
最初网络的优化指标能够示意为
$$\large \min _{G_{X\rightarrow Y},G_{Y\rightarrow X}}\max_{D_X,D_Y} L(G,F,D_x,D_y)=L_{GAN}(G,D_Y,X,Y)+L_{GAN}(F,D_X,Y,X)+\lambda L_{cyc}(G,F)$$
Pix2Pix 以及 CycleGAN 的官网复现入口:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

StarGAN

Pix2Pix 解决了有配对图像的翻译问题,CycleGAN 解决了无配对图像的翻译问题,然而他们所作的图像到图像翻译,都是一对一。假如当初须要将人脸转换为喜怒哀乐四个表情,那么他们就须要进行 4 次不同的训练,这无疑会消耗微小的计算资源。针对于这个问题,StarGAN(StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation)借助 cGAN 的思维,在网络输出中退出一个域的管制信息。对于判断器,其不仅须要甄别样本是否实在,还须要判断输出的图像来自哪个域。StarGAN 的训练过程如下:

  1. 将原始图片 c 和指标生成域 c 进行拼接后丢入生成器失去生成图像 G(x,c)
  2. 将生成图像 G(x,c)和实在图像 y 别离丢入判断器 D,判断器除了须要判断输出图像的真伪之外,还须要判断它来自哪个域
  3. 将生成图像 G(x,c)和原始生成域 c ’ 丢入生成器生成重构图片(为了对生成器生成的图像做进一步的限度,与 CycleGAN 的重构损失相似)


理解了 StarGAN 的训练过程,咱们很容易失去其损失函数各项的表达形式
首先是 GAN 的个别损失,这里作者采纳了前文所述的 WGAN 的损失模式:
$$\large L_{adv}=\mathbb E_x[D_{src}(x)]-\mathbb E_{x,c}[D_{src}(G(x,c)))]-\lambda_{gp}\mathbb E_{\hat x}[(\|\nabla _\hat xD_{src}(\hat x)\|_2-1)^2]$$
对于判断器,咱们须要激励其将输出图像正确的分类到指标域 c‘(原始生成域):
$$\large L_{src}^r=\mathbb E_{x,c’}[-logD_{cls}(c’|x)]$$

对于生成器,咱们须要激励其胜利坑骗判断器将图片分类到指标域 c(指标生成域),此外,生成器还须要在以生成图像和原始生成域 c ’ 的输出下胜利将图像还原回去,这两局部的损失示意如下:
$$\largeL_{src}^f=\mathbb R_{x,c}[-logD_{cls}(c|G(x,c))]$$
$$\large L_{rec}=\mathbb E_{x,c,c’}[\|x-G(G(x,c),c’)\|_1]$$
各局部损失乘上本人的权重加总后就形成了判断器和生成器的总损失:
$$\largeL_D=-L_{adv}+\lambda_{cls}L_{cls}^{r} $$
$$\large L_G=L_{adv}+\lambda_{cls}L_{clas}^f+\lambda_{rec}L_{rec}$$
此外,为了更具备通用性,作者还退出了 mask vector 来适应不同的数据集之间的训练。

总结

名称 翻新点
DCGAN 首次将卷积神经网络引入 GAN 中
cGAN 通过拼接标签信息来管制生成器的输入
Pix2Pix 提出了一种图像到图像翻译的通用办法
CycleGAN 解决了 Pix2Pix 须要图像配对的问题
StarGAN 提出了一种一对多的图像到图像的翻译办法
InfoGAN 基于 cGAN 改良,提出一种无监督的生成办法, 实用于不晓得图像标签的状况
LSGAN 用最小二乘损失函数代替原始 GAN 的损失函数, 缓解了训练不稳固、生成图像不足多样性的问题
ProGAN 在训练期间逐渐增加新的高分辨率层, 能够生成高分辨率的图像
SAGAN 将注意力机制引入 GAN 当中, 简洁高效利用了全局信息

本文列举了生成式反抗网络在倒退过程中一些具备代表性的网络结构。GANs 现在已广泛应用于图像生成、图像去噪、超分辨、文本到图像的翻译等各个领域,且在近几年的钻研中涌现了很多优良的论文。感兴趣的同学能够从上面的链接中 pick 本人想要理解的 GAN~

  • THE-GAN-ZOO:汇总了各种 GAN 的论文及代码地址。
  • GAN Timeline:依照工夫线对不同的 GAN 进行了排序。
  • Browse state-of-the-art:将 ArXiv 上的最新论文与 GitHub 代码相关联,并做了比拟排序,波及了深度学习的各个方面。

参考文献

  1. Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[C]//Advances in neural information processing systems. 2014: 2672-2680.
  2. Arjovsky M, Chintala S, Bottou L. Wasserstein gan[J]. arXiv preprint arXiv:1701.07875, 2017.
  3. Radford A, Metz L, Chintala S. Unsupervised representation learning with deep convolutional generative adversarial networks[J]. arXiv preprint arXiv:1511.06434, 2015.
  4. Mirza M, Osindero S. Conditional generative adversarial nets[J]. arXiv preprint arXiv:1411.1784, 2014.
  5. Isola P, Zhu J Y, Zhou T, et al. Image-to-image translation with conditional adversarial networks[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2017: 1125-1134.
  6. Zhu J Y, Park T, Isola P, et al. Unpaired image-to-image translation using cycle-consistent adversarial networks[C]//Proceedings of the IEEE international conference on computer vision. 2017: 2223-2232.
  7. Choi Y, Choi M, Kim M, et al. Stargan: Unified generative adversarial networks for multi-domain image-to-image translation[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 8789-8797.
  8. Mao X, Li Q, Xie H, et al. Least squares generative adversarial networks[C]//Proceedings of the IEEE international conference on computer vision. 2017: 2794-2802.
  9. Karras T, Aila T, Laine S, et al. Progressive growing of gans for improved quality, stability, and variation[J]. arXiv preprint arXiv:1710.10196, 2017.
  10. Chen X, Duan Y, Houthooft R, et al. Infogan: Interpretable representation learning by information maximizing generative adversarial nets[C]//Advances in neural information processing systems. 2016: 2172-2180.
  11. Zhang H, Goodfellow I, Metaxas D, et al. Self-attention generative adversarial networks[C]//International Conference on Machine Learning. 2019: 7354-7363.

对于咱们

Mo(网址:https://momodel.cn)是一个反对 Python 的人工智能在线建模平台,能帮忙你疾速开发、训练并部署模型。

近期 Mo 也在继续进行机器学习相干的入门课程和论文分享流动,欢送大家关注咱们的公众号(MomodelAI)获取最新资讯!

正文完
 0