关于pytorch:一文使用PyTorch搭建GAN模型

46次阅读

共计 4705 个字符,预计需要花费 12 分钟才能阅读完成。

作者|Ta-Ying Cheng,牛津大学博士研究生,Medium 技术博主,多篇文章均被平台官网刊物 Towards Data Science 收录
翻译|颂贤

以往人们普遍认为生成图像是不可能实现的工作,因为依照传统的机器学习思路,咱们基本没有真值(ground truth)能够拿来测验生成的图像是否合格。

2014 年,Goodfellow 等人则提出生成 反抗网络(Generative Adversarial Network, GAN),可能让咱们齐全依附机器学习来生成极为真切的图片。GAN 的横空出世使得整个人工智能行业都为之触动,计算机视觉和图像生成畛域产生了巨变。

本文将带大家理解 GAN 的工作原理,并介绍如何 通过 PyTorch 简略上手 GAN

GAN 的原理

依照传统的办法,模型的预测后果能够间接与已有的真值进行比拟。然而,咱们却很难定义和掂量到底怎么才算作是“正确的”生成图像。

Goodfellow 等人则提出了一个乏味的解决办法:咱们能够先训练好一个分类工具,来主动辨别生成图像和实在图像。这样一来,咱们就能够用这个分类工具来训练一个生成网络,直到它可能输入齐全以假乱真的图像,连分类工具本人都没有方法评判虚实。

依照这一思路,咱们便有了 GAN:也就是一个 生成器(generator)和一个 判断器(discriminator)。生成器负责依据给定的数据集生成图像,判断器则负责辨别图像是真是假。GAN 的运作流程如上图所示。

损失函数

在 GAN 的运作流程中,咱们能够发现一个显著的矛盾:同时优化生成器和判断器是很艰难的。能够设想,这两个模型有着齐全相同的指标:生成器想要尽可能伪造出实在的货色,而判断器则必须要识破生成器生成的图像。

为了阐明这一点,咱们设 D(x)为判断器的输入,即 x 是实在图像的概率,并设 G(z)为生成器的输入。判断器相似于一种二进制的分类器,所以其指标是使该函数的后果最大化:
这一函数实质上是非负的二元穿插熵损失函数。另一方面,生成器的指标是最小化判断器做出正确判断的机率,因而它的指标是使上述函数的后果最小化。

因而,最终的损失函数将会是两个分类器之间的极小极大博弈,示意如下:

实践上来说,博弈的最终后果将是让判断器判断胜利的概率收敛到 0.5。然而在实践中,极大极小博弈通常会导致网络不收敛,因而认真调整模型训练的参数十分重要。

在训练 GAN 时,咱们尤其要留神学习率等超参数,学习率比拟小时能让 GAN 在输出乐音较多的状况下也能有较为对立的输入。

计算环境

本文将领导大家通过 PyTorch 搭建整个程序(包含 torchvision)。同时,咱们将会应用 Matplotlib 来让 GAN 的生成后果可视化。以下代码可能导入上述所有库:

"""
Import necessary libraries to create a generative adversarial network
The code is mainly developed using the PyTorch library
"""
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms
from model import discriminator, generator
import numpy as np
import matplotlib.pyplot as plt

数据集

数据集对于训练 GAN 来说十分重要,尤其思考到咱们在 GAN 中解决的通常是非结构化数据(个别是图片、视频等),任意一 class 都能够有数据的散布。这种数据分布恰好是 GAN 生成输入的根底。

为了更好地演示 GAN 的搭建流程,本文将带大家应用最简略的 MNIST 数据集,其中含有 6 万张手写阿拉伯数字的图片。

像 MNIST 这样高质量的非结构化数据集都能够在 格物钛 的公开数据集网站上找到。事实上,格物钛 Open Datasets 平台涵盖了很多优质的公开数据集,同时也能够实现 数据集托管及一站式搜寻的性能,这对 AI 开发者来说,是相当实用的社区平台。

硬件需要

一般来说,尽管能够应用 CPU 来训练神经网络,但最佳抉择其实是 GPU,因为这样能够大幅晋升训练速度。咱们能够用上面的代码来测试本人的机器是否用 GPU 来训练:

"""Determine if any GPUs are available"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

实现

网络结构

因为数字是非常简单的信息,咱们能够将判断器和生成器这两层构造都组建成全连贯层(fully connected layers)。

咱们能够用以下代码在 PyTorch 中搭建判断器和生成器:

"""
Network Architectures
The following are the discriminator and generator architectures
"""

class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 1)
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return nn.Sigmoid()(x)


class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.fc1 = nn.Linear(128, 1024)
        self.fc2 = nn.Linear(1024, 2048)
        self.fc3 = nn.Linear(2048, 784)
        self.activation = nn.ReLU()

def forward(self, x):
    x = self.activation(self.fc1(x))
    x = self.activation(self.fc2(x))
    x = self.fc3(x)
    x = x.view(-1, 1, 28, 28)
    return nn.Tanh()(x)

训练

在训练 GAN 的时候,咱们须要一边优化判断器,一边改良生成器,因而每次迭代咱们都须要同时优化两个互相矛盾的损失函数。

对于生成器,咱们将输出一些随机乐音,让生成器来依据乐音的渺小扭转输入的图像:

"""
Network training procedure
Every step both the loss for disciminator and generator is updated
Discriminator aims to classify reals and fakes
Generator aims to generate images as realistic as possible
"""
for epoch in range(epochs):
    for idx, (imgs, _) in enumerate(train_loader):
        idx += 1

        # Training the discriminator
        # Real inputs are actual images of the MNIST dataset
        # Fake inputs are from the generator
        # Real inputs should be classified as 1 and fake as 0
        real_inputs = imgs.to(device)
        real_outputs = D(real_inputs)
        real_label = torch.ones(real_inputs.shape[0], 1).to(device)

        noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
        noise = noise.to(device)
        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)

        outputs = torch.cat((real_outputs, fake_outputs), 0)
        targets = torch.cat((real_label, fake_label), 0)

        D_loss = loss(outputs, targets)
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # Training the generator
        # For generator, goal is to make the discriminator believe everything is 1
        noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5
        noise = noise.to(device)

        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device)
        G_loss = loss(fake_outputs, fake_targets)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        if idx % 100 == 0 or idx == len(train_loader):
            print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item()))

    if (epoch+1) % 10 == 0:
        torch.save(G, 'Generator_epoch_{}.pth'.format(epoch))
        print('Model saved.')

后果

通过 100 个训练期间之后,咱们就能够对数据集进行可视化解决,间接看到模型从随机乐音生成的数字:

咱们能够看到,生成的后果和实在的数据十分相像。思考到咱们在这里只是搭建了一个非常简单的模型,理论的利用成果会有十分大的回升空间。

不仅是有样学样

GAN 和以往机器视觉专家提出的想法都不一样,而利用 GAN 进行的具体场景利用更是让许多人赞叹深度网络的有限后劲。上面咱们来看一下两个最为闻名的 GAN 延申利用。

CycleGAN

朱俊彦等人 2017 年发表的 CycleGAN 可能在没有配对图片的状况下将一张图片从 X 域间接转换到 Y 域,比方把马变成斑马、将热夏变成隆冬、把莫奈的画变成梵高的画等等。这些看似天方夜谭的转换 CycleGAN 都能轻松做到,并且后果十分精确。

GauGAN

英伟达则通过 GAN 让人们可能只须要寥寥数笔勾画出本人的想法,便能失去一张极为真切的实在场景图片。尽管这种利用须要的计算成本极为昂扬,然而 GauGAN 凭借它的转换能力摸索出了前所未有的钻研和应用领域。

结语

置信看到这里,你曾经晓得了 GAN 的大抵工作原理,并且可能本人入手简略搭建一个 GAN 了。

正文完
 0