共计 5758 个字符,预计需要花费 15 分钟才能阅读完成。
作者 |GUEST
编译 |VK
起源 |Analytics Vidhya
介绍
SimCLR 论文 (http://cse.iitkgp.ac.in/~aras…,并且如果有足够的计算能力,能够产生与监督模型相似的后果。
然而这些需要使得框架的计算量相当大。如果咱们能够领有这个框架的简略性和弱小性能,并且有更少的计算需要,这样每个人都能够拜访它,这不是很好吗?Moco-v2 前来救济。
留神:在之前的一篇博文中,咱们在 PyTorch 中实现了 SimCLR 框架,它是在一个蕴含 5 个类别的简略数据集上实现的,总共只有 1250 个训练图像。
数据集
这次咱们将在 Pytorch 中在更大的数据集上实现 Moco-v2,并在 Google Colab 上训练咱们的模型。这次咱们将应用 Imagenette 和 Imagewoof 数据集
来自 Imagenette 数据集的一些图像
这些数据集的疾速摘要(更多信息在这里:https://github.com/fastai/ima…):
- Imagenette 由 Imagenet 的 10 个容易分类的类组成,总共有 9479 个训练图像和 3935 个验证集图像。
- Imagewoof 是一个由 Imagenet 提供的 10 个难分类组成的数据集,因为所有的类都是狗的种类。总共有 9035 个训练图像,3939 个验证集图像。
比照学习
比照学习在自我监督学习中的作用是基于这样一个理念:咱们心愿同一类别中不同的图像观具备类似的表征。然而,因为咱们不晓得哪些图像属于同一类别,通常所做的是将同一图像的不同外观的示意拉近。咱们把这些不同的外观称为正对(positive pairs)。
另外,咱们心愿不同类别的图像有不同的外观,使它们的表征彼此远离。不同图像的不同外观的出现与类别无关,会被彼此推开。咱们把这些不同的外观称为负对(negative pairs)。
在这种状况下,一个图像的前景是什么?前景能够被认为是以一种通过批改的形式对待图像的某些局部,它实质上是图像的一种变换。
依据手头的工作,有些转换能够比其余转换工作得更好。SimCLR 表明,利用随机裁剪和色彩抖动能够很好地实现各种工作,包含图像分类。这实质上来自于网格搜寻,从旋转、裁剪、剪切、噪声、含糊、Sobel 滤波等选项中抉择一对变换。
从外观到示意空间的映射是通过神经网络实现的,通常,resnet 用于此目标。上面是从图像到示意的管道
负对是如何产生的?
在同一幅图像中,因为随机裁剪,咱们能够失去多个示意。这样,咱们就能够产生正对。
然而如何生成负对呢?负对是来自不同图像的示意。SimCLR 论文在同一批中创立了这些。如果一个批蕴含 N 个图像,那么对于每个图像,咱们将失去 2 个示意,这总共占 2 * N 个示意。对于一个特定的示意 x,有一个示意与 x 造成正对(与 x 来自同一个图像的示意),其余所有示意(正好是 2 *N–2)与 x 造成负对。
如果咱们手头有大量的负样本,这些示意就会失去改善。然而,在 SimCLR 中,只有当批量较大时,能力实现大量的负样本,这导致了对计算能力的更高要求。MoCo-v2 提供了生成负样本的另一种办法。让咱们具体理解一下。
动静词典
咱们能够用一种略微不同的形式来对待比照学习办法,行将查问与键进行匹配。咱们当初有两个编码器,一个用于查问,另一个用于键。此外,为了失去大量的负样本,咱们须要一个大的键编码字典。
此上下文中的正对示意查问与键匹配。如果查问和键都来自同一个图像,则它们匹配。编码的查问应该与其匹配的键类似,而与其余查问不同。
对于负对,咱们保护一个大字典,其中蕴含以前批处理的编码键。它们作为查问的负样本。咱们以队列的模式保护字典。新的 batch 被入队,较早的 batch 被入列。通过更改此队列的大小,能够更改负采样数。
这种办法的挑战
- 随着键编码器的更改,在稍后工夫点排队的键可能与较早排队的键不统一。为了应用比照学习办法,与查问进行比拟的所有键必须来自雷同或类似的编码器,这样比拟才会有意义且统一。
- 另一个挑战是,应用反向流传学习编码器参数是不可行的,因为这将须要计算队列中所有样本的梯度(这将导致大的计算图)。
为了解决这两个问题,MoCo 将键编码器实现为基于动量的查问编码器的挪动平均值 [1]。这意味着它以这种形式更新要害编码器参数:
其中 m 十分靠近于 1(例如,典型值为 0.999),这确保咱们在不同的工夫从类似的编码器取得编码键。
损失函数 -InfoNCE
咱们心愿查问靠近其所有正样本,远离所有负样本。InfoNC 函数 E 会捕捉它。它代表信息噪声比照预计。对于查问 q 和键 k,InfoNCE 损失函数是:
咱们能够重写为:
当 q 和 k 的相似性增大,q 与负样本的相似性减小时,损失值减小
以下是损失函数的代码:
τ = 0.05
def loss_function(q, k, queue):
# N 是批量大小
N = q.shape[0]
# C 是示意的维数
C = q.shape[1]
# bmm 代表批处理矩阵乘法
# 如果 mat1 是 b×n×m 张量,那么 mat2 是 b×m×p 张量,# 而后输入一个 b×n×p 张量。pos = torch.exp(torch.div(torch.bmm(q.view(N,1,C), k.view(N,C,1)).view(N, 1),τ))
# 在查问和队列张量之间执行矩阵乘法
neg = torch.sum(torch.exp(torch.div(torch.mm(q.view(N,C), torch.t(queue)),τ)), dim=1)
# 求和
denominator = neg + pos
return torch.mean(-torch.log(torch.div(pos,denominator)))
让咱们再看看这个损失函数,并将它与分类穿插熵损失函数进行比拟。
这里 predᵢ是数据点在第 i 类中的概率值预测,trueᵢ是该点属于第 i 类的理论概率值(能够是含糊的,但大多数状况下是一个 one-hot)。
如果你不相熟这个话题,你能够看这个视频来更好地了解穿插熵。另外,请留神,咱们常常通过 softmax 这样的函数将分数转换为概率值:https://www.youtube.com/watch…
咱们能够把信息损失函数看作穿插熵损失。数据样本“q”的正确类是第 r 类,底层分类器基于 softmax,它试图在 K + 1 类之间进行分类。
Info-NCE 还与编码表示之间的互相信息无关;对于这一点的更多细节见 [4]。
MoCo-v2 框架
当初,让咱们把所有的货色放在一起,看看整个 Moco-v2 算法是什么样子的。
步骤 1:
咱们必须失去查问和键编码器。最后,键编码器具备与查问编码器雷同的参数。它们是彼此的复制品。随着训练的进行,键编码器将成为查问编码器的挪动平均值(在这一点上停顿迟缓)。
因为计算能力的限度,咱们应用 Resnet-18 体系结构来实现。在通常的 resnet 架构之上,咱们增加了一些密集的层,以使示意的维数降到 25。这些层中的某些层稍后将充当投影。
# 定义咱们的深度学习架构
resnetq = resnet18(pretrained=False)
classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(resnetq.fc.in_features, 100)),
('added_relu1', nn.ReLU(inplace=True)),
('fc2', nn.Linear(100, 50)),
('added_relu2', nn.ReLU(inplace=True)),
('fc3', nn.Linear(50, 25))
]))
resnetq.fc = classifier
resnetk = copy.deepcopy(resnetq)
# 将 resnet 架构迁徙到设施
resnetq.to(device)
resnetk.to(device)
步骤 2:
当初,咱们曾经有了编码器,并且假如咱们曾经设置了其余重要的数据结构,当初是时候开始训练循环并了解管道了。
这一步是从训练批中获取编码查问和键。咱们用 L2 范数对示意进行规范化。
只是一个约定正告,所有后续步骤中的代码都将位于批处理和 epoch 循环中。咱们还将张量“k”从它的梯度中分离出来,因为咱们不须要计算图中的键编码器局部,因为动量更新方程会更新键编码器。
# 梯度零化
optimizer.zero_grad()
# 检索 xq 和 xk 这两个图像 batch
xq = sample_batched['image1']
xk = sample_batched['image2']
# 把它们移到设施上
xq = xq.to(device)
xk = xk.to(device)
# 获取他们的输入
q = resnetq(xq)
k = resnetk(xk)
k = k.detach()
# 将输入规范化,使它们成为单位向量
q = torch.div(q,torch.norm(q,dim=1).reshape(-1,1))
k = torch.div(k,torch.norm(k,dim=1).reshape(-1,1))
步骤 3:
当初,咱们将查问、键和队列传递给后面定义的 loss 函数,并将值存储在一个列表中。而后,像平常一样,对损失值调用 backward 函数并运行优化器。
# 取得损失值
loss = loss_function(q, k, queue)
# 把这个损失值放到 epoch 损失列表中
epoch_losses_train.append(loss.cpu().data.item())
# 反向流传
loss.backward()
# 运行优化器
optimizer.step()
步骤 4:
咱们将最新的 batch 退出咱们的队列。如果咱们的队列大小大于咱们定义的最大队列大小 (K),那么咱们就从其中取出最老的 batch。能够应用 torch.cat 进行队列操作。
# 更新队列
queue = torch.cat((queue, k), 0)
# 如果队列大于最大队列大小(k),则入列
# batch 大小是 256,能够用变量替换
if queue.shape[0] > K:
queue = queue[256:,:]
步骤 5:
当初咱们进入训练循环的最初一步,即更新键编码器。咱们应用上面的 for 循环来实现这一点。
# 更新 resnet
for θ_k, θ_q in zip(resnetk.parameters(), resnetq.parameters()):
θ_k.data.copy_(momentum*θ_k.data + θ_q.data*(1.0 - momentum))
一些训练细节
训练 resnet-18 模型的 Imagenette 和 Imagewoof 数据集的 GPU 工夫靠近 18 小时。为此,咱们应用了 googlecolab 的 GPU(16GB)。咱们应用的 batch 大小为 256,tau 值为 0.05,学习率为 0.001,最终升高到 1e-5,权重衰减为 1e-6。咱们的队列大小为 8192,键编码器的动量值为 0.999。
后果
前 3 层(将 relu 视为一层)定义了投影头,咱们将其移除用于图像分类的上游工作。在剩下的网络上,咱们训练了一个线性分类器。
咱们失去了 64.2% 的正确率,而应用 10% 的标记训练数据,应用 MoCo-v2。相比之下,应用最先进的监督学习办法,其准确率靠近 95%。
对于 Imagewoof,咱们对 10% 的标记数据失去了 38.6% 的准确率。在这个数据集上进行比照学习的成果低于咱们的预期。咱们狐疑这是因为首先,数据集十分艰难,因为所有类都是狗类。
其次,咱们认为色彩是这些类的一个重要的区别特色。利用色彩抖动可能会导致来自不同类的多个图像彼此混合示意。相比之下,监督办法的准确率靠近 90%。
可能弥合自监督模型和监督模型之间差距的设计变更 :
- 应用更大更宽的模型。
- 通过应用更大的批量和字典大小。
- 应用更多的数据,如果能够的话。同时引入所有未标记的数据。
- 在大量数据上训练大型模型,而后提取它们。
一些有用的链接:
- 谷歌 Colab:https://colab.research.google…
- Imagewoof Github 仓库后果:https://github.com/thunderInf…
- Imagenette Github 仓库后果:https://github.com/thunderInf…
- Imagewoof 数据集链接:https://github.com/thunderInf…
- Imagenette 数据集链接:https://github.com/thunderInf…
参考援用
- Momentum Contrast for Unsupervised Visual Representation Learning, Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, and Ross Girshick(https://arxiv.org/pdf/1911.05…
- Improved Baselines with Momentum Contrastive Learning, Xinlei Chen, Haoqi Fan, Ross Girshick, and Kaiming He(https://arxiv.org/pdf/2003.04…
- A simple framework for contrastive learning of visual representations, Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey E. Hinton.(https://arxiv.org/pdf/2002.05…
- Representation Learning with Contrastive Predictive Coding, Aaron van den Oord, Yazhe Li, and Oriol Vinyals(https://arxiv.org/pdf/1807.03…
原文链接:https://www.analyticsvidhya.c…
欢送关注磐创 AI 博客站:
http://panchuang.net/
sklearn 机器学习中文官网文档:
http://sklearn123.com/
欢送关注磐创博客资源汇总站:
http://docs.panchuang.net/