生成式反抗网络模型综述

作者:张真源

GAN

GAN简介

生成式反抗网络(Generative adversarial networks,GANs)的核心思想源自于零和博弈,包含生成器和判断器两个局部。生成器接管随机变量并生成“假”样本,判断器则用于判断输出的样本是实在的还是合成的。两者通过互相反抗来取得彼此性能的晋升。判断器所作的其实就是一个二分类工作,咱们能够计算他的损失并进行反向流传求出梯度,从而进行参数更新。

GAN的优化指标能够写作:
$$\large\min_G\max_DV(D,G)= \mathbb{E}_{x\sim p_{data}}[\log D(x)]+\mathbb{E}_{z\sim p_z(z)}[log(1-D(G(z)))]$$
其中$$\log D(x)$$代表了判断器甄别实在样本的能力,而$$D(G(z))$$则代表了生成器坑骗判断器的能力。在理论的训练中,生成器和判断器采取交替训练,即先训练D,而后训练G,一直往返。

WGAN

在上一部分咱们给出了GAN的优化指标,这个指标的实质是在最小化生成样本与实在样本之间的JS间隔。然而在试验中发现,GAN的训练十分的不稳固,常常会陷入坍缩模式。这是因为,在高维空间中,并不是每个点都能够示意一个样本,而是存在着大量不代表实在信息的无用空间。当两个散布没有重叠时,JS间隔不能精确的提供两个散布之间的差别。这样的生成器,很难“捕获”到低维空间中的实在数据分布。因而,WGAN(Wasserstein GAN)的作者提出了Wasserstein间隔(推土机间隔)的概念,其公式能够进行如下示意:
$$\large W(\mathbb P_r,\mathbb P_g)=\inf_{\gamma\sim\prod{\mathbb P_r,\mathbb P_g}}\mathbb E_{(x,y)~\gamma}[\|x-y\|]$$
这里$$\prod{\mathbb P_r,\mathbb P_g}$$指的是实在散布$$\mathbb P_r$$和生成散布$$\mathbb P_g$$的联结散布所形成的汇合,$$(x,y)$$是从$$\gamma$$中获得的一个样本。枚举两者之间所有可能的联结散布,计算其中样本间的间隔$$\|x-y\|$$,并取其冀望。而Wasserstein间隔就是两个散布样本间隔冀望的下界值。这个简略的改良,使得生成样本在任意地位下都能给生成器带来适合的梯度,从而对参数进行优化。

DCGAN

卷积神经网络近年来获得了夺目的问题,展示了其在图像处理畛域独特的劣势。很天然的会想到,如果将卷积神经网络引入GAN中,是否能够带来成果上的晋升呢?DCGAN(Deep Convolutional GANs)在GAN的根底上优化了网络结构,用齐全的卷积代替了全连贯层,去掉池化层,并采纳批标准化(Batch Normalization,BN)等技术,使得网络更容易训练。

用DCGAN生成图像

为了更不便精确的阐明DCGAN的关键环节,这里用一个简化版的模型实例来阐明。代码基于pytorch深度学习框架,数据集采纳MNIST

import torchimport torch.nn as nnimport torchvisionfrom torchvision import transformsfrom torchvision.utils import save_imageimport os  #定义一些超参数nc = 1        #输出图像的通道数nz = 100          #输出噪声的维度num_epochs = 200  #迭代次数batch_size = 64      #批量大小sample_dir = 'gan_samples'# 后果的保留目录if not os.path.exists(sample_dir):    os.makedirs(sample_dir)# 加载MNIST数据集trans = transforms.Compose([                transforms.ToTensor(),transforms.Normalize([0.5], [0.5])])mnist = torchvision.datasets.MNIST(root=r'G:\VsCode\ml\mnist',                                   train=True,                                   transform=trans,                                   download=False)data_loader = torch.utils.data.DataLoader(dataset=mnist,                                          batch_size=batch_size,                                            shuffle=True)

判断器&生成器

判断器应用LeakyReLU作为激活函数,最初通过Sigmoid输入,用于虚实二分类
生成器应用ReLU作为激活函数,最初通过tanh将输入映射在[-1,1]之间

# 构建判断器class Discriminator(nn.Module):    def __init__(self, in_channel=1, num_classes=1):        super(Discriminator, self).__init__()        self.conv = nn.Sequential(            # 28 -> 14            nn.Conv2d(nc, 512, 3, stride=2, padding=1, bias=False),            nn.BatchNorm2d(512),            nn.LeakyReLU(0.2),            # 14 -> 7            nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),            nn.BatchNorm2d(256),            nn.LeakyReLU(0.2),            # 7 -> 4            nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),            nn.BatchNorm2d(128),            nn.LeakyReLU(0.2),            nn.AvgPool2d(4),        )        self.fc = nn.Sequential(            # reshape input, 128 -> 1            nn.Linear(128, 1),            nn.Sigmoid(),        )       def forward(self, x, label=None):        y_ = self.conv(x)        y_ = y_.view(y_.size(0), -1)        y_ = self.fc(y_)        return y_# 构建生成器class Generator(nn.Module):    def __init__(self):        super(Generator, self).__init__()        self.fc = nn.Sequential(            nn.Linear(nz, 4*4*512),            nn.ReLU(),        )        self.conv = nn.Sequential(            # input: 4 by 4, output: 7 by 7            nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, bias=False),            nn.BatchNorm2d(256),            nn.ReLU(),            # input: 7 by 7, output: 14 by 14            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),            nn.BatchNorm2d(128),            nn.ReLU(),            # input: 14 by 14, output: 28 by 28            nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1, bias=False),            nn.Tanh(),        )           def forward(self, x, label=None):        x = x.view(x.size(0), -1)        y_ = self.fc(x)        y_ = y_.view(y_.size(0), 512, 4, 4)        y_ = self.conv(y_)        return y_

训练模型

# 应用GPUdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')D = Discriminator().to(device)G = Generator().to(device)# 损失函数及优化器criterion = nn.BCELoss()D_opt = torch.optim.Adam(D.parameters(), lr=0.001, betas=(0.5, 0.999))G_opt = torch.optim.Adam(G.parameters(), lr=0.001, betas=(0.5, 0.999))def denorm(x):    out = (x + 1) / 2    return out.clamp(0, 1)def reset_grad():    d_optimizer.zero_grad()    g_optimizer.zero_grad()for epoch in range(num_epochs):    for i, (images, labels) in enumerate(data_loader):        images = images.to(device)        real_labels = torch.ones(batch_size, 1).to(device)        fake_labels = torch.zeros(batch_size, 1).to(device)        #————————————————————训练判断器——————————————————————        #甄别实在样本        outputs = D(images)        d_loss_real = criterion(outputs, real_labels)        real_score = outputs        #甄别生成样本        z = torch.randn(batch_size, nz).to(device)        fake_images = G(z)        outputs = D(fake_images)        d_loss_fake = criterion(outputs, fake_labels)        fake_score = outputs              #计算梯度及更新        d_loss = d_loss_real + d_loss_fake              reset_grad()        d_loss.backward()        d_optimizer.step()        #————————————————————训练生成器——————————————————————        z = torch.randn(batch_size, nz).to(device)        fake_images = G(z)        outputs = D(fake_images)        g_loss = criterion(outputs, real_labels)        #计算梯度及更新        reset_grad()        g_loss.backward()        g_optimizer.step()               if (i+1) % 200 == 0:            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'                    .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),                            real_score.mean().item(), fake_score.mean().item()))    # 保留生成图片    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))# 保留模型torch.save(G.state_dict(), 'G.ckpt')torch.save(D.state_dict(), 'D.ckpt')

可视化后果

reconsPath = './gan_samples/fake_images-200.png'Image = mpimg.imread(reconsPath)plt.imshow(Image)plt.axis('off')plt.show()

cGAN

在之前介绍的几种模型中,咱们留神到生成器的输出都是一个随机的噪声。能够认为这个高维噪声向量提供了一些要害信息,而生成器依据本人的了解将这些信息进行补充,最终生成须要的图像。生成器生成图片的过程是齐全随机的。例如上述的MNIST数据集,咱们不能管制它生成的是哪一个数字。那么,有没有办法能够对其做肯定的限度束缚,来让生成器生成咱们想要的后果呢?cGAN(Conditional Generative Adversarial Nets)通过增一个额定的向量y对生成器进行束缚。以MNIST分类为例,限度信息y能够取10维的向量,对于类别进行one-hot编码,并与噪声进行拼接一起输出生成器。同样的,判断器也将原来的输出和y进行拼接。作者通过各种试验证实了这个简略的改良的确能够起到对生成器的束缚作用。

判断器&生成器

只须要在前向流传的过程中退出限度变量y,咱们很容易就能失去cGAN的生成器和判断器模型

class Discriminator(nn.Module):    ...    def forward(self, x, label):        label = label.unsqueeze(2).unsqueeze(3)        label = label.repeat(1, 1, x.size(2), x.size(3))        x = torch.cat(tensors=(x, label), dim=1)        y_ = self.conv(x)        ...class Generator(nn.Module):    ...    def forward(self, x, label):        x = x.unsqueeze(2).unsqueeze(3)        label = label.unsqueeze(2).unsqueeze(3)        x = torch.cat(tensors=(x, label), dim=1)        y_ = self.fc(x)        ...

Pix2Pix

在下面的cGAN例子中,咱们的管制信息取的是想要图像的标签,如果这个管制信息更加的丰盛,例如输出一整张图像,那么它是否实现一些更加高级的工作?Pix2Pix(Image-to-Image Translation with Conditional Adversarial Networks)将这一类问题演绎为图像到图像的翻译,其应用改良后的U-net作为生成器,并设计了新鲜的Patch-D判断器构造来输入高清的图像。Patch-D是指,不论网络所应用的输出图像有多大,都将其切割成若干个固定大小的Patch,判断器只需对这些Patch的虚实进行判断。因为L1损失曾经能够掂量生成图像和实在图像的全局差别,所以作者认为判断器只须要用Patch-D这样更关注于部分差别的构造即可。同时Patch-D的构造使得网络的输出变小,缩小了计算量并且增大了框架的扩展性。

CycleGAN

Pix2Pix尽管能够生成高清的图像,但其存在一个致命的毛病:须要互相配对的图片x与y。在现实生活中,这样成对的图片很难或者基本不可能收集到,这就大大的限度了Pix2Pix的利用。对此,CycleGAN(Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks)提出了不须要配对的图像翻译办法。

CycleGAN其实就是一个X->Y的单向GAN上再加一个Y->X的单向GAN,形成一个“循环”。网络的构造和单次训练过程如下(图片来自于量子位):


除了经典的根底GAN损失之外,CycleGAN还引入了Consistency loss的概念。循环统一损失使得X->Y转变的过程中必须保留有X的局部个性。循环损失的公式如下:
$$\large L_{cyc}(G,F)=\mathbb E_{x\sim p_{data}(x)}[\|F(G(x))-x\|_1]+\mathbb E_{y\sim p_{data}(y)}[\|G(F(x))-y\|_1]$$
两个判断器的损失示意如下:
$$\large \textit{L}_{GAN}(G,D_Y,X,Y)=\mathbb E_{y\sim p_{data}(y)}[logD_Y(y)]+\mathbb E_{x\sim p_{data}(x)}[log(1-D_Y(G(x)))]$$$$\large \textit{L}_{GAN}(F,D_X,Y,X)=\mathbb E_{x\sim p_{data}(x)}[logD_X(x)]+\mathbb E_{y\sim p_{data}(y)}[log(1-D_X(F(y)))]$$
最初网络的优化指标能够示意为
$$\large \min _{G_{X\rightarrow Y},G_{Y\rightarrow X}}\max_{D_X,D_Y} L(G,F,D_x,D_y)=L_{GAN}(G,D_Y,X,Y)+L_{GAN}(F,D_X,Y,X)+\lambda L_{cyc}(G,F)$$
Pix2Pix以及CycleGAN的官网复现入口:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

StarGAN

Pix2Pix解决了有配对图像的翻译问题,CycleGAN解决了无配对图像的翻译问题,然而他们所作的图像到图像翻译,都是一对一。假如当初须要将人脸转换为喜怒哀乐四个表情,那么他们就须要进行4次不同的训练,这无疑会消耗微小的计算资源。针对于这个问题,StarGAN(StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation)借助cGAN的思维,在网络输出中退出一个域的管制信息。对于判断器,其不仅须要甄别样本是否实在,还须要判断输出的图像来自哪个域。StarGAN的训练过程如下:

  1. 将原始图片c和指标生成域c进行拼接后丢入生成器失去生成图像G(x,c)
  2. 将生成图像G(x,c)和实在图像y别离丢入判断器D,判断器除了须要判断输出图像的真伪之外,还须要判断它来自哪个域
  3. 将生成图像G(x,c)和原始生成域c'丢入生成器生成重构图片(为了对生成器生成的图像做进一步的限度,与CycleGAN的重构损失相似)


理解了StarGAN的训练过程,咱们很容易失去其损失函数各项的表达形式
首先是GAN的个别损失,这里作者采纳了前文所述的WGAN的损失模式:
$$\large L_{adv}=\mathbb E_x[D_{src}(x)]-\mathbb E_{x,c}[D_{src}(G(x,c)))]-\lambda_{gp}\mathbb E_{\hat x}[(\|\nabla _\hat xD_{src}(\hat x)\|_2-1)^2]$$
对于判断器,咱们须要激励其将输出图像正确的分类到指标域c‘(原始生成域):
$$\large L_{src}^r=\mathbb E_{x,c'}[-logD_{cls}(c'|x)]$$

对于生成器,咱们须要激励其胜利坑骗判断器将图片分类到指标域c(指标生成域),此外,生成器还须要在以生成图像和原始生成域c'的输出下胜利将图像还原回去,这两局部的损失示意如下:
$$\largeL_{src}^f=\mathbb R_{x,c}[-logD_{cls}(c|G(x,c))]$$
$$\large L_{rec}=\mathbb E_{x,c,c'}[\|x-G(G(x,c),c')\|_1]$$
各局部损失乘上本人的权重加总后就形成了判断器和生成器的总损失:
$$\largeL_D=-L_{adv}+\lambda_{cls}L_{cls}^{r} $$
$$\large L_G=L_{adv}+\lambda_{cls}L_{clas}^f+\lambda_{rec}L_{rec}$$
此外,为了更具备通用性,作者还退出了mask vector来适应不同的数据集之间的训练。

总结

名称翻新点
DCGAN首次将卷积神经网络引入GAN中
cGAN通过拼接标签信息来管制生成器的输入
Pix2Pix提出了一种图像到图像翻译的通用办法
CycleGAN解决了Pix2Pix须要图像配对的问题
StarGAN提出了一种一对多的图像到图像的翻译办法
InfoGAN基于cGAN改良,提出一种无监督的生成办法,实用于不晓得图像标签的状况
LSGAN用最小二乘损失函数代替原始GAN的损失函数,缓解了训练不稳固、生成图像不足多样性的问题
ProGAN在训练期间逐渐增加新的高分辨率层,能够生成高分辨率的图像
SAGAN将注意力机制引入GAN当中,简洁高效利用了全局信息

本文列举了生成式反抗网络在倒退过程中一些具备代表性的网络结构。GANs现在已广泛应用于图像生成、图像去噪、超分辨、文本到图像的翻译等各个领域,且在近几年的钻研中涌现了很多优良的论文。感兴趣的同学能够从上面的链接中pick本人想要理解的GAN~

  • THE-GAN-ZOO:汇总了各种GAN的论文及代码地址。
  • GAN Timeline:依照工夫线对不同的GAN进行了排序。
  • Browse state-of-the-art:将ArXiv上的最新论文与GitHub代码相关联,并做了比拟排序,波及了深度学习的各个方面。

参考文献

  1. Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[C]//Advances in neural information processing systems. 2014: 2672-2680.
  2. Arjovsky M, Chintala S, Bottou L. Wasserstein gan[J]. arXiv preprint arXiv:1701.07875, 2017.
  3. Radford A, Metz L, Chintala S. Unsupervised representation learning with deep convolutional generative adversarial networks[J]. arXiv preprint arXiv:1511.06434, 2015.
  4. Mirza M, Osindero S. Conditional generative adversarial nets[J]. arXiv preprint arXiv:1411.1784, 2014.
  5. Isola P, Zhu J Y, Zhou T, et al. Image-to-image translation with conditional adversarial networks[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2017: 1125-1134.
  6. Zhu J Y, Park T, Isola P, et al. Unpaired image-to-image translation using cycle-consistent adversarial networks[C]//Proceedings of the IEEE international conference on computer vision. 2017: 2223-2232.
  7. Choi Y, Choi M, Kim M, et al. Stargan: Unified generative adversarial networks for multi-domain image-to-image translation[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 8789-8797.
  8. Mao X, Li Q, Xie H, et al. Least squares generative adversarial networks[C]//Proceedings of the IEEE international conference on computer vision. 2017: 2794-2802.
  9. Karras T, Aila T, Laine S, et al. Progressive growing of gans for improved quality, stability, and variation[J]. arXiv preprint arXiv:1710.10196, 2017.
  10. Chen X, Duan Y, Houthooft R, et al. Infogan: Interpretable representation learning by information maximizing generative adversarial nets[C]//Advances in neural information processing systems. 2016: 2172-2180.
  11. Zhang H, Goodfellow I, Metaxas D, et al. Self-attention generative adversarial networks[C]//International Conference on Machine Learning. 2019: 7354-7363.

对于咱们

Mo(网址:https://momodel.cn) 是一个反对 Python的人工智能在线建模平台,能帮忙你疾速开发、训练并部署模型。

近期 Mo 也在继续进行机器学习相干的入门课程和论文分享流动,欢送大家关注咱们的公众号(MomodelAI)获取最新资讯!