共计 5733 个字符,预计需要花费 15 分钟才能阅读完成。
手把手带你疾速入门超过 GAN 的 Normalizing Flow
作者:Aryansh Omray,微软数据迷信工程师,Medium 技术博主
机器学习畛域的一个根本问题就是如何学习简单数据的表征是机器学习。
这项工作的重要性在于,现存的大量非结构化和无标签的数据,只有通过无监督式学习能力了解。密度估计、异样检测、文本总结、数据聚类、生物信息学、DNA 建模等各方面的利用均须要实现这项工作。
多年来,钻研人员创造了许多办法来学习大型数据集的概率分布,包含生成反抗网络(GAN)、变分自编码器(VAE)和 Normalizing Flow 等。
本文即向大家介绍 Normalizing Flow 这一为了克服 GAN 和 VAE 的有余而提出的办法。
Glow 模型的输入样例 (Source)
GAN 和 VAE 的能力本已非常惊人,它们都能通过简略的推理方法学习十分复杂的数据分布。
然而,GAN 和 VAE 都不足对概率分布的准确评估和推理,这往往导致 VAE 中的含糊后果品质不高,GAN 训练也面临着如模式解体和后置解体等挑战。
因而,Normalizing Flow 应运而生,试图通过应用可逆函数来解决目前 GAN 和 VAE 存在的许多问题。
Normalizing Flow
简略地说,Normalizing Flow 就是一系列的可逆函数,或者说这些函数的解析逆是能够计算的。例如,f(x)=x+ 2 是一个可逆函数,因为每个输出都有且仅有一个惟一的输入,并且反之亦然,而 f(x)=x²则不是一个可逆函数。这样的函数也被称为双射函数。
图源作者
从上图能够看出,Normalizing Flow 能够将简单的数据点(如 MNIST 中的图像)转化为简略的高斯分布,反之亦然。和 GAN 十分不一样的中央是,GAN 输出的是一个随机向量,而输入的是一个图像,基于流 (Flow) 的模型则是将数据点转化为简略散布。在上图的 MNIST 一例中,咱们从高斯分布中抽取随机样本,均可从新取得其对应的 MNIST 图像。
基于流的模型应用负对数可能性损失函数进行训练,其中 p(z)是概率函数。上面的损失函数就是应用统计学中的变量变动公式失去的。
(Source)
Normalizing Flow 的劣势
与 GAN 和 VAE 相比,Normalizing Flow 具备各种劣势,包含:
- Normalizing Flow 模型不须要在输入中放入噪声,因而能够有更弱小的部分方差模型(local variance model);
- 与 GAN 相比,基于流的模型训练过程十分稳固,GAN 则须要认真调整生成器和判断器的超参数;
- 与 GAN 和 VAE 相比,Normalizing Flow 更容易收敛。
Normalizing Flow 的有余
尽管基于流的模型有其劣势,但它们也有一些毛病:
- 基于流的模型在密度估计等工作上的体现不尽如人意;
- 基于流的模型要求保留变换的体积(volume preservation over transformations),这往往会产生十分高维的潜在空间,通常会导致解释性变差;
- 基于流的模型产生的样本通常没有 GAN 和 VAE 的好。
为了更好地了解 Normalizing Flow,咱们以 Glow 架构为例进行解释。Glow 是 OpenAI 在 2018 年提出的一个基于流的模型。下图展现了 Glow 的架构。
Glow 的架构(Source)
Glow 架构由多个表层(superficial layers)组合而成。首先咱们来看看 Glow 模型的多尺度框架。Glow 模型由一系列的反复层(命名为尺度)组成。每个尺度包含一个挤压函数和一个流步骤,每个流步骤蕴含 ActNorm、1×1 Convolution 和 Coupling Layer,流步骤后是宰割函数。宰割函数在通道维度上将输出分成两个相等的局部。其中一半进入之后的层,另一半则进入损失函数。宰割是为了缩小梯度隐没的影响,梯度隐没会在模型以端到端形式(end-to-end)训练时呈现。
如下图所示,挤压函数 (squeeze function)通过横向重塑张量,将大小为[c, h, w] 的输出张量转换为大小为 [4c, h/2, w/2] 的张量。此外,在测试阶段能够采纳重塑函数,将输出的 [4c, h/2, w/2] 重塑为大小为 [c, h, w] 的张量。
(Source)
其余层,如 ActNorm、1×1 Convolution 和 Affine Coupling 层,能够从下表了解。该表展现了每层的性能(包含正向和反向)。
(Source)
实现
在理解了 Normalizing Flow 和 Glow 模型的基础知识后,咱们将介绍如何应用 PyTorch 实现该模型,并在 MNIST 数据集上进行训练。
Glow 模型
首先,咱们将应用 PyTorch 和 nflows 实现 Glow 架构。为了节省时间,咱们应用 nflows 蕴含所有层的实现。
import torch
import torch.nn as nn
import torch.nn.functional as F
from nflows import transforms
import numpy as np
from torchvision.transforms.functional import resize
from nflows.transforms.base import Transform
class Net(nn.Module):
def __init__(self, in_channel, out_channels):
super().__init__()
self.net = nn.Sequential(nn.Conv2d(in_channel, 64, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 1),
nn.ReLU(inplace=True),
ZeroConv2d(64, out_channels),
)
def forward(self, inp, context=None):
return self.net(inp)
def getGlowStep(num_channels, crop_size, i):
mask = [1] * num_channels
if i % 2 == 0:
mask[::2] = [-1] * (len(mask[::2]))
else:
mask[1::2] = [-1] * (len(mask[1::2]))
def getNet(in_channel, out_channels):
return Net(in_channel, out_channels)
return transforms.CompositeTransform([transforms.ActNorm(num_channels),
transforms.OneByOneConvolution(num_channels),
transforms.coupling.AffineCouplingTransform(mask, getNet)
])
def getGlowScale(num_channels, num_flow, crop_size):
z = [getGlowStep(num_channels, crop_size, i) for i in range(num_flow)]
return transforms.CompositeTransform([transforms.SqueezeTransform(),
*z
])
def getGLOW():
num_channels = 1 * 4
num_flow = 32
num_scale = 3
crop_size = 28 // 2
transform = transforms.MultiscaleCompositeTransform(num_scale)
for i in range(num_scale):
next_input = transform.add_transform(getGlowScale(num_channels, num_flow, crop_size),
[num_channels, crop_size, crop_size])
num_channels *= 2
crop_size //= 2
return transform
Glow_model = getGLOW()
咱们能够用各种数据集来训练 Glow 模型,如 MNIST、CIFAR-10、ImageNet 等。本文为了演示不便,应用的是 MNIST 数据集。
像 MNIST 这样的数据集能够很容易地从 格物钛公开数据集平台 获取,该平台蕴含了机器学习中所有罕用的凋谢数据集,如分类、密度估计、物体检测和基于文本的分类数据集等。
要拜访数据集,咱们只须要在格物钛的平台上创立账户,就能够间接 fork 想要的数据集,能够间接下载或者应用格物钛提供的 pipeline 导入数据集。根本的代码和相干文档可在 TensorBay 的反对网页上取得。
联合格物钛 TensorBay 的 Python SDK,咱们能够很不便地导入 MNIST 数据集到 PyTorch 中:
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tensorbay import GAS
from tensorbay.dataset import Dataset as TensorBayDataset
class MNISTSegment(Dataset):
def __init__(self, gas, segment_name, transform):
super().__init__()
self.dataset = TensorBayDataset("MNIST", gas)
self.segment = self.dataset[segment_name]
self.category_to_index = self.dataset.catalog.classification.get_category_to_index()
self.transform = transform
def __len__(self):
return len(self.segment)
def __getitem__(self, idx):
data = self.segment[idx]
with data.open() as fp:
image_tensor = self.transform(Image.open(fp))
return image_tensor, self.category_to_index[data.label.classification.category]
模型训练
模型训练能够通过上面的代码简略开始。该代码应用格物钛 TensorBay 提供的 Pipeline 创立数据加载器,其中的 ACCESS_KEY 能够在 TensorBay 的账户设置中取得。
from nflows.distributions import normal
ACCESS_KEY = "Accesskey-*****"
EPOCH = 100
to_tensor = transforms.ToTensor()
normalization = transforms.Normalize(mean=[0.485], std=[0.229])
my_transforms = transforms.Compose([to_tensor, normalization])
train_segment = MNISTSegment(GAS(ACCESS_KEY), segment_name="train", transform=my_transforms)
train_dataloader = DataLoader(train_segment, batch_size=4, shuffle=True, num_workers=4)
optimizer = torch.optim.Adam(Glow_model.parameters(), 1e-3)
for epoch in range(EPOCH):
for index, (image, label) in enumerate(train_dataloader):
if index == 0:
image_size = image.shaape[2]
channels = image.shape[1]
image = image.cuda()
output, logabsdet = Glow_model._transform(image)
shape = output.shape[1:]
log_z = normal.StandardNormal(shape=shape).log_prob(output)
loss = log_z + logabsdet
loss = -loss.mean()/(image_size * image_size * channels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch:{epoch+1}/{EPOCH} Loss:{loss}")
下面代码用的是 MNIST 数据集,要想应用其余数据集咱们能够间接替换该数据集的数据加载器。
样例生成
模型训练实现之后,咱们能够通过上面的代码来生成样例:
samples = Glow_model.sample(25)
display(samples)
应用 nflows 库之后,咱们只须要用一行代码就能够生成样例,而 display 函数则能在一个网格中显示生成的样本。
用 MNIST 训练模型之后生成的样例
结语
本文向大家介绍了 Normalizing Flow 的基本知识,并与 GAN 和 VAE 进行了比拟,同时向大家展现了 Glow 模型的根本工作形式。咱们还解说了如何简略实现 Glow 模型,并应用 MNIST 数据集进行训练。在格物钛公开数据集平台的帮忙下,数据集拜访变得非常便捷。
【对于格物钛】
格物钛智能科技 专一打造人工智能新型基础设施,通过非结构化数据平台和公开数据集社区,帮忙机器学习团队和集体更好地开释非结构化数据后劲,让 AI 利用开发更快、性能体现更优,继续为人工智能赋能千行百业、驱动产业降级、推动科技普惠打造坚实基础。目前已取得红杉、云启、真格、风和、耀途资本以及奇绩创坛的千万美金投资。