关于人工智能:更简单的掩码图像建模框架SimMIM介绍和PyTorch代码实现

21次阅读

共计 14439 个字符,预计需要花费 37 分钟才能阅读完成。

MAE 公布以来,各种应用掩码技术的自监督掩码模型在其根底之上有了更进一步的钻研。在本文中咱们将摸索一篇和 MAE 同期的工作:SimMIM: A Simple Framework for Masked Image Modeling,钻研团队是微软亚研院,并在 PyTorch 中编写它,最初咱们也会提供相干的代码。

SimMIM 的骨干网络是 VIT,相熟自监督学习的基础知识也十分有帮忙,最初咱们还要精通 PyTorch,因为咱们应用它来实现咱们的模型。

图像中的掩码技术

在过来的几年中,比照学习和非比照学习办法始终是计算机视觉 (CV) 的自监督学习 (SSL) 的次要模式,他们中的最先进的 (SOTA) 模型与监督学习处于等同位置。从根本上说,比照学习的目标是教会神经网络将类似的数据点 (正对) 放在一起,并将不同的数据点 (负对) 离开,这是一项须要学习视觉模式的工作。非比照学习克服了与比照学习相干的阻碍(例如,须要大量的标注数据)。

而自然语言解决 (NLP) 为 SSL 应用掩码建模,其中输出的一个随机片段被掩码,模型的指标是依据残余的信息复原它,这样做的印象是将教会模型语法。像 BERT 这样的神经网络就属于这一类,这种形式曾经获得了惊人的性能。

NLP 和视觉之间存在肯定的差别,图像中的局部性十分强,即左近的像素高度相干,因而即便一个像素被屏蔽,通过剖析其街坊也能够绝对容易地推断出它的值。并且照片是间断的不像 NLP 中的标记是离散的,像素是低级原始特色而单词是人类构建的高级概念。

随着 ViT 的呈现,蒙面建模最近已进入计算机视觉畛域,在 ImageNet 分类等上游工作上获得了具备竞争力的分数。然而这些办法很辣手,并且依赖于精密的组件,如像 iGPT 一样像素聚类,以及通过额定的离散变分主动编码器 (dVAE) 进行标记化,这是 BEiT 应用的一种技术。

SimMIM 是一个简略的掩码图像建模框架并且超过了以前的 SOTA 基线,在没有简单的元素的同时放弃了效率。具体来说在提取图像的标记后,SimMIM 通过用可学习的掩码标记替换它们来随机屏蔽一些标记,并用 ViT 对数据进行编码。接下来通过将掩码标记的编码表示传递给线性层来重建缺失局部,损失是预测像素和理论像素之间的 L1 损失除以掩码标记的数量。

Pytorch 实现

SimMIM 很简略而且没有特地简单的操作。咱们假如从一组维度为 batch_size X n_tokens X token_dim 的令牌开始。

fromtorchimport (randn,)

# tokens is currently a dummy tensor.
# Later, it will be replaced by the actual tokens
tokens=randn(batch_size, n_tokens, token_dim)

首先必须确定要屏蔽哪些标记。一种策略是在每个样本的 [0, n_tokens-1](从零开始的索引)范畴内生成一组索引,这些索引将是为该行屏蔽的标记的索引。

fromtorchimport (randn,)

tokens=randn(batch_size, n_tokens, token_dim)

indices_to_mask=randn(batch_size, n_tokens)

# Number of tokens to mask
# 50% of the total number of tokens performs well on average.
# However, for smaller patch sizes, a higher masking ratio is generally better.
# For example, for a patch size of 32, 0.5 performs well but for 
# a patch size of 16, it would be worthwhile to increase it to 0.8.
n_masked_tokens=int(0.5*n_tokens)

# topk returns the k largest elements as well as their indices
# dim=1 tells it to find the maximum values and their indices
# on a per-row basis
# The indices of the tokens that are to be masked is going
# to be the indices of the n_masked_tokens largest values
indices_to_mask=indices_to_mask.topk(
k=n_masked_tokens,
dim=1,
)

# The largest values can be accesses via indices_to_mask.values,
# and their indices can be accessed via indices_to_mask.indices
indices_to_mask=indices_to_mask.indices

indices_to_mask 的形态为 batch_size X n_masked_tokens,每一行都蕴含要为该特定数据点屏蔽标记的索引。应用 indices_to_mask 索引标记会略微简单一些,所以更好的办法是构建大小为 batch_size * n_tokens 的位掩码,其中如果标记 i 被屏蔽,则 bitmaski 为 True,否则为 False .

fromtorchimport (
randn,
zeros,
)

tokens=randn(batch_size, n_tokens, token_dim)

indices_to_mask=randn(batch_size, n_tokens)

n_masked_tokens=int(0.5*n_tokens)

indices_to_mask=indices_to_mask.topk(
k=n_masked_tokens,
dim=1,
)

indices_to_mask=indices_to_mask.indices

# Initially, bitmask is simply full of zeros (i.e., False)
bitmask=zeros(batch_size, n_tokens)

# What this line does is as follows:
# For every row i, bitmask[i][j] is replaced
# by the value argument (in this case 1), where j takes every value
# in indices_to_mask[i].
# For example, if indices_to_mask[3] is
# [2, 4, 7], then bitmask[3][2], bitmask[3][4], and bitmask[3][7]
# are all set to 1.
bitmask=bitmask.scatter(
dim=1,
index=indices_to_mask,
value=1,
)
bitmask=bitmask.bool()

要应用位掩码首先要通过 VIT 从输出产生令牌。咱们这里应用的 ViT 来自 timm 包,然而它能够很容易地为转换为其余实现。

fromtorchimport (
randn,
zeros,
)

# vit is assumed to be a vision transformer from timm
# To get tokens from a timm ViT, one must call its patch_embed method
# tokens is now of shape batch_size X n_tokens X token_dim
# Keep in mind that input is image data and of size 
# batch_size X n_channels X height X width
tokens=vit.patch_embed(input)

indices_to_mask=randn(batch_size, n_tokens)

n_masked_tokens=int(0.5*n_tokens)

indices_to_mask=indices_to_mask.topk(
k=n_masked_tokens,
dim=1,
)

indices_to_mask=indices_to_mask.indices

bitmask=zeros(batch_size, n_tokens)

bitmask=bitmask.scatter(
dim=1,
index=indices_to_mask,
value=1,
)
bitmask=bitmask.bool()

下一步应用掩码令牌替换切片。PyTorch 不容许以 Inplace 的形式批改变量,不能间接将掩码标记赋值给令牌[bitmask];所以必须用掩码标记填充形态为 batch_size n_tokens token_dim 的张量(维度雷同),

fromtorchimport (
randn,
zeros,
)
fromtorch.nnimport (Parameter,)

tokens=vit.patch_embed(input)

# The mask token itself is simply a vector of dimension token_dim
mask_token=Parameter(randn(token_dim))

# mask_token is repeated to make it the same shape as tokens
# mask_tokens is now of size batch_size X n_tokens X token_dim
mask_tokens=mask_token.repeat(batch_size, n_tokens, 1)

indices_to_mask=randn(batch_size, n_tokens)

n_masked_tokens=int(0.5*n_tokens)

indices_to_mask=indices_to_mask.topk(
k=n_masked_tokens,
dim=1,
)

indices_to_mask=indices_to_mask.indices

bitmask=zeros(batch_size, n_tokens)

bitmask=bitmask.scatter(
dim=1,
index=indices_to_mask,
value=1,
)
bitmask=bitmask.bool()

这样就实现了掩码的过程

fromtorchimport (
randn,
zeros,
)
fromtorch.nnimport (Parameter,)

tokens=vit.patch_embed(input)

mask_token=Parameter(randn(token_dim))

mask_tokens=mask_token.repeat(batch_size, n_tokens, 1)

indices_to_mask=randn(batch_size, n_tokens)

n_masked_tokens=int(0.5*n_tokens)

indices_to_mask=indices_to_mask.topk(
k=n_masked_tokens,
dim=1,
)

indices_to_mask=indices_to_mask.indices

bitmask=zeros(batch_size, n_tokens)

bitmask=bitmask.scatter(
dim=1,
index=indices_to_mask,
value=1,
)
bitmask=bitmask.bool()

# bitmask must have the same number of axes as tokens and mask_tokens
# Therefore, unsqueeze(2) adds an axis to it and it is now of shape batch_size X n_tokens X 1
bitmask=bitmask.unsqueeze(2)

# ~bitmask turns True to False and False to True
# Here, all that is taking place is (~bitmask) is multiplied by tokens
# to zero out every token that is supposed to be masked, and the result is added
# to bitmask*mask_tokens, in which everything is 0 except the tokens that are
# supposed to mask.
tokens= (~bitmask)*tokens+bitmask*mask_tokens

而后就是地位嵌入

fromtorchimport (
randn,
zeros,
)
fromtorch.nnimport (Parameter,)

tokens=vit.patch_embed(input)

mask_token=Parameter(randn(token_dim))

mask_tokens=mask_token.repeat(batch_size, n_tokens, 1)

indices_to_mask=randn(batch_size, n_tokens)

n_masked_tokens=int(0.5*n_tokens)

indices_to_mask=indices_to_mask.topk(
k=n_masked_tokens,
dim=1,
)

indices_to_mask=indices_to_mask.indices

bitmask=zeros(batch_size, n_tokens)

bitmask=bitmask.scatter(
dim=1,
index=indices_to_mask,
value=1,
)
bitmask=bitmask.bool()

bitmask=bitmask.unsqueeze(2)

tokens= (~bitmask)*tokens+bitmask*mask_tokens

# In timm, a ViT's position embedding is accessible via vit.pos_embed
# The reason for vit.pos_embed[:, 1:] in place of simply vit.pos_embed
# is that the first position embedding vector is for the class token,
# which is not used for self-supervised learning.
tokens=tokens+vit.pos_embed[:, 1:]

令牌能够被输出到 ViT 取得它的编码表示。

fromtorchimport (
randn,
zeros,
)
fromtorch.nnimport (Parameter,)

tokens=vit.patch_embed(input)

mask_token=Parameter(randn(token_dim))

mask_tokens=mask_token.repeat(batch_size, n_tokens, 1)

indices_to_mask=randn(batch_size, n_tokens)

n_masked_tokens=int(0.5*n_tokens)

indices_to_mask=indices_to_mask.topk(
k=n_masked_tokens,
dim=1,
)

indices_to_mask=indices_to_mask.indices

bitmask=zeros(batch_size, n_tokens)

bitmask=bitmask.scatter(
dim=1,
index=indices_to_mask,
value=1,
)
bitmask=bitmask.bool()

bitmask=bitmask.unsqueeze(2)

tokens= (~bitmask)*tokens+bitmask*mask_tokens

tokens=tokens+vit.pos_embed[:, 1:]

# The encoded representation of tokens
encoded=vit.blocks(tokens)

被屏蔽的令牌将从编码中获取,而后它们通过线性层来重建像素值。

fromtorchimport (
randn,
zeros,
)
fromtorch.nnimport (
Linear,
Parameter,
)

tokens=vit.patch_embed(input)

mask_token=Parameter(randn(token_dim))

mask_tokens=mask_token.repeat(batch_size, n_tokens, 1)

indices_to_mask=randn(batch_size, n_tokens)

n_masked_tokens=int(0.5*n_tokens)

indices_to_mask=indices_to_mask.topk(
k=n_masked_tokens,
dim=1,
)

indices_to_mask=indices_to_mask.indices

bitmask=zeros(batch_size, n_tokens)

bitmask=bitmask.scatter(
dim=1,
index=indices_to_mask,
value=1,
)
bitmask=bitmask.bool()

bitmask=bitmask.unsqueeze(2)

tokens= (~bitmask)*tokens+bitmask*mask_tokens

tokens=tokens+vit.pos_embed[:, 1:]

encoded=vit.blocks(tokens)

# To index input and encoded with bitmask,
# the axis that was added must be removed.
# This reverts bit_mask to a size of batch_size X n_tokens
bitmask=bitmask.squeeze(2)

# The encoded mask tokens, of shape batch_size X n_masked_tokens X token_dim
masked_tokens_encoded=encoded[bitmask]

# In timm, A ViT's patch height and width are vit.patch_embed.patch_size
patch_height=patch_width=vit.patch_embed.patch_size

# The input is the tokens,
# the output is the reconstructed raw pixel values.
# Therefore, the output shape is 3 (for 3 channels)
# multiplied by patch_height*patch_width, which is the original shape
# of the patches before they were tokenized
decoder_out_dim=3*patch_height*patch_width
decoder=Linear(
in_features=token_dim,
out_features=decoder_out_dim,
)

# The reconstructed pixels, of shape batch_size X n_masked_tokens X 3*patch_height*patch_width
masked_patches_reconstructed=decoder(masked_tokens_encoded)

最初 masked_patchesde 的重构与初始数据中的原始像素进行比拟。因为输出的 patches 不可用,因而必须对输出进行 patche 解决。PyTorch 的 reshap 函数有一些限度,用 torch 进行拼接。重塑将产生不正确的输入。所以一个简略的解决方案是 einops(它是一个不便用于操作张量的库,并且与框架无关)。

须要留神的是,patches 和令牌(Token)是不同的。patches 是从 batch_size 3 height width 重塑为 batch_size n_tokens 3 patch_height * patch_width 的数据,而令牌是通过沿最终轴线性变换 patches 创立的。

fromeinopsimport (rearrange,)
fromtorchimport (
randn,
zeros,
)
fromtorch.nnimport (
Linear,
Parameter,
)

tokens=vit.patch_embed(input)

mask_token=torch.nn.Parameter(torch.randn(token_dim))

mask_tokens=self.mask_token.repeat(batch_size, n_tokens, 1)

indices_to_mask=randn(batch_size, n_tokens)

n_masked_tokens=int(0.5*n_tokens)

indices_to_mask=indices_to_mask.topk(
k=n_masked_tokens,
dim=1,
)

indices_to_mask=indices_to_mask.indices

bitmask=zeros(batch_size, n_tokens)

bitmask=bitmask.scatter(
dim=1,
index=indices_to_mask,
value=1,
)
bitmask=bitmask.bool()

bitmask=bitmask.unsqueeze(2)

tokens= (~bitmask)*tokens+bitmask*mask_tokens

tokens=tokens+vit.pos_embed[:, 1:]

encoded=vit.blocks(tokens)

bitmask=bitmask.squeeze(2)

masked_tokens_encoded=encoded[bitmask]

patch_height=patch_width=vit.patch_embed.patch_size

decoder_out_dim=3*patch_height*patch_width
decoder=Linear(
in_features=token_dim,
out_features=decoder_out_dim,
)

masked_patches_reconstructed=decoder(masked_tokens_encoded)

# patterns tells einops how to rearrange the tensor
# Its layout is as follows: 'shape_before -> shape_after'
# In this case, the shape before would be batch_size X n_channels X height X width,
# and the shape after would be batch_size X n_tokens X n_channels*patch_height*patch_width
# However, in einops, variables that are in shape_before must be in shape_after as well and vice versa
# For example, in this case, height is in shape_before but not shape_after.
# Therefore, shape_before and shape_after must be restructured.
# Particularly, two new variables can be introduced, n_patches_height and n_patches_width,
# that say how many patches are along the height and width axes respectively.
# Thus, height = n_patches_height * patch_height,
# width = n_patches_width * patch_width, and 
# n_tokens = n_patches_height * n_patches width
# Multiplying two variables in einops is denoted by (x y).
pattern= ('batch_size n_channels (n_patches_height patch_height) (n_patches_width patch_width) ->'
'batch_size (n_patches_height n_patches_width) (n_channels patch_height patch_width)'
  )

# einops.rearrange is like torch.reshape
# einops cannot infer patch_height and patch_width,
# so they must be passed manually
# patches is now of shape batch_size X n_tokens X 3*patch_height*patch_width
patches=rearrange(
tensor=input,
pattern=pattern,
patch_height=patch_height,
patch_width=patch_width,
)

得对应于 masked_patches_reconstructed 的 patche 局部,

fromeinopsimport (rearrange,)
fromtorchimport (
randn,
zeros,
)
fromtorch.nnimport (
Linear,
Parameter,
)

tokens=vit.patch_embed(input)

mask_token=torch.nn.Parameter(torch.randn(token_dim))

mask_tokens=self.mask_token.repeat(batch_size, n_tokens, 1)

indices_to_mask=randn(batch_size, n_tokens)

n_masked_tokens=int(0.5*n_tokens)

indices_to_mask=indices_to_mask.topk(
k=n_masked_tokens,
dim=1,
)

indices_to_mask=indices_to_mask.indices

bitmask=zeros(batch_size, n_tokens)

bitmask=bitmask.scatter(
dim=1,
index=indices_to_mask,
value=1,
)
bitmask=bitmask.bool()

bitmask=bitmask.unsqueeze(2)

tokens= (~bitmask)*tokens+bitmask*mask_tokens

tokens=tokens+vit.pos_embed[:, 1:]

encoded=vit.blocks(tokens)

bitmask=bitmask.squeeze(2)

masked_tokens_encoded=encoded[bitmask]

patch_height=patch_width=vit.patch_embed.patch_size

decoder_out_dim=3*patch_height*patch_width
decoder=Linear(
in_features=token_dim,
out_features=decoder_out_dim,
)

masked_patches_reconstructed=decoder(masked_tokens_encoded)

pattern= ('batch_size n_channels (n_patches_height patch_height) (n_patches_width patch_width) ->'
'batch_size (n_patches_height n_patches_width) (n_channels patch_height patch_width)'
  )

patches=einops.rearrange(
tensor=input,
pattern=pattern,
patch_height=patch_height,
patch_width=patch_width,
)

# Similar to how masked_tokens_encoded was computed
maskes_patches_original=patches[bitmask]

评估损失。

fromeinopsimport (rearrange,)
fromtorchimport (
randn,
zeros,
)
fromtorch.nnimport (
Linear,
Parameter,
)
fromtorch.nn.functionalimport (l1_loss,)

tokens=vit.patch_embed(input)

mask_token=torch.nn.Parameter(torch.randn(token_dim))

mask_tokens=self.mask_token.repeat(batch_size, n_tokens, 1)

indices_to_mask=randn(batch_size, n_tokens)

n_masked_tokens=int(0.5*n_tokens)

indices_to_mask=indices_to_mask.topk(
k=n_masked_tokens,
dim=1,
)

indices_to_mask=indices_to_mask.indices
bitmask=zeros(batch_size, n_tokens)

bitmask=bitmask.scatter(
dim=1,
index=indices_to_mask,
value=1,
)
bitmask=bitmask.bool()

bitmask=bitmask.unsqueeze(2)

tokens= (~bitmask)*tokens+bitmask*mask_tokens

tokens=tokens+vit.pos_embed[:, 1:]

encoded=vit.blocks(tokens)

bitmask=bitmask.squeeze(2)

masked_tokens_encoded=encoded[bitmask]

patch_height=patch_width=vit.patch_embed.patch_size

decoder_out_dim=3*patch_height*patch_width
decoder=Linear(
in_features=token_dim,
out_features=decoder_out_dim,
)

masked_patches_reconstructed=decoder(masked_tokens_encoded)

pattern= ('batch_size n_channels (n_patches_height patch_height) (n_patches_width patch_width) ->'
'batch_size (n_patches_height n_patches_width) (n_channels patch_height patch_width)'
  )

patches=einops.rearrange(
tensor=input,
pattern=pattern,
patch_height=patch_height,
patch_width=patch_width,
)

maskes_patches_original=patches[bitmask]

# The loss is the L1 difference between 
# the predicted pixel values and the ground truth,
# divided by the number of masked patches
loss=l1_loss(
input=masked_patches_reconstructed,
target=maskes_patches_original,
)/n_masked_tokens

把下面的代码封装成类并减少一些辅助函数,这里就不贴了有趣味的看下最初的源代码。而后应用的时候如下:

fromtimmimport (create_model,)
fromtorch.nn.functionalimport (l1_loss,)
fromtorch.optimimport (AdamW,)

vit=create_model(
'vit_small_patch32_224',
num_classes=0,
)
simmim=SimMIM(
vit=vit,
masking_ratio=0.5,
)
optimizer=AdamW(params=simmim.parameters(),
lr=1e-4,
weight_decay=5e-2,
)

forepochinrange(n_epochs):
forinputindataloader:
n_masked_tokens, masked_patches_reconstructed, masked_patches_original=simmim(input)

loss=l1_loss(
input=masked_patches_reconstructed,
target=maskes_patches_original,
)
loss/=n_masked_tokens
loss.backward()

optimizer.backward()
optimizer.zero_grad()

下面的代码能够配置各种超参数(例如学习率,应用余弦退火,但为简略起见,此处省略)。咱们的 dataloader 仅返回随机调整大小的裁剪、随机程度翻转和标准化的图像。

应用下面的代码,任何 VIT 都能够在大量未正文的数据上进行训练,并且能够很好地学习上游工作。看起来很简略吧,这也就是论文的名字 sample 的起源。

总结

在本文中,咱们介绍 SimMIM,这是一种受掩码建模启发的弱小 SSL 算法,其中一部分输出数据被掩码,模型的指标是最小化重建损失。为了更相熟模型的运行形式咱们还是用 Pytorch 对其进行了实现,这样能够帮忙咱们理解模型的细节。

援用:

A Simple Framework for Contrastive Learning of Visual Representations

Exploring Simple Siamese Representation Learning

SimMIM: A Simple Framework for Masked Image Modeling

本文代码:

https://avoid.overfit.cn/post/8729cb4ea9d5402db115b13ca80b3d9e

作者:Borna Ahmadzadeh

正文完
 0