关于pytorch:带掩码的自编码器MAE详解和Pytorch代码实现

40次阅读

共计 9152 个字符,预计需要花费 23 分钟才能阅读完成。

监督学习是训练机器学习模型的传统办法,它在训练时每一个察看到的数据都须要有标注好的标签。如果咱们有一种训练机器学习模型的办法不须要收集标签,会怎么样? 如果咱们从收集的雷同数据中提取标签呢? 这种类型的学习算法被称为自监督学习。这种办法在自然语言解决中工作得很好。一个例子是 BERT¹,谷歌自 2019 年以来始终在其搜索引擎中应用 BERT¹。可怜的是,对于计算机视觉来说,状况并非如此。

Facebook AI 的 kaiming 大神等人提出了一种带掩码自编码器(MAE)²,它基于(ViT)³架构。他们的办法在 ImageNet 上的体现要好于从零开始训练的 VIT。在本文中,咱们将深入研究他们的办法,并理解如何在代码中实现它。

带掩码自编码器(MAE)

对输出图像的 patches 进行随机掩码,而后重建缺失的像素。MAE 基于两个外围设计。首先,开发了一个非对称的编码器 - 解码器架构,其中编码器仅对可见的 patches 子集 (没有掩码的 tokens) 进行操作,同时还有一个轻量级的解码器,能够从潜在示意和掩码 tokens 重建原始图像。其次,发现对输出图像进行高比例的掩码,例如 75%,会产生有意义的自监督工作。将这两种设计联合起来,可能高效地训练大型模型:放慢模型训练速度 (3 倍甚至更多) 并进步精度。

此阶段称为预训练,因为 MAE 模型稍后将用于上游工作,例如图像分类。模型在 pretext 上的体现在自监督中并不重要,这些工作的重点是让模型学习一个预期蕴含良好语义的两头示意。在预训练阶段之后,解码器将被多层感知器 (MLP) 头或线性层取代,作为分类器输入对上游工作的预测。

模型架构

编码器

编码器是 ViT。它承受张量形态为 (batch_size, RGB_channels, height, width) 的图像。通过执行线性投影为每个 Patch 取得嵌入,这是通过 2D 卷积层来实现。而后张量在最初一个维度被展平(压扁),变成 (batch_size, encoder_embed_dim, num_visible_patches),并 转置为形态(batch_size、num_visible_patches、encoder_embed_dim)的张量。

class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""
    def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

正如原始 Transformer 论文中提到的,地位编码增加了无关每个 Patch 地位的信息。作者应用“sine-cosine”版本而不是可学习的地位嵌入。上面的这个实现是一维版本。

def get_sinusoid_encoding_table(n_position, d_hid): 
  
    def get_position_angle_vec(position): 
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 

    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 
    return torch.FloatTensor(sinusoid_table).unsqueeze(0)

与 Transformer 相似,每个块由 norm 层、多头注意力模块和前馈层组成。两头输入形态是(batch_size、num_visible_patches、encoder_embed_dim)。多头注意力模块的代码如下:

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = attn_head_dim if attn_head_dim is not None else dim // num_heads
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) if qkv_bias else None
        self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) if qkv_bias else None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) if self.q_bias is not None else None
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj_drop(self.proj(x))
        return x

Transformer 模块的代码如下:

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
        self.norm2 = norm_layer(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, int(dim * mlp_ratio)), act_layer(), nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(attn_drop)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

这部分仅用于上游工作的微调。论文的模型遵循 ViT 架构,该架构具备用于分类的类令牌(patch)。因而,他们增加了一个虚构令牌,然而论文中也说到他们的办法在没有它的状况下也能够运行良好,因为对其余令牌执行了均匀池化操作。在这里也蕴含了实现的均匀池化版本。之后,增加一个线性层作为分类器。最终的张量形态是 (batch_size, num_classes)。

综上所述,编码器实现如下:

class Encoder(nn.Module)
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=nn.LayerNorm, num_classes=0, **block_kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        # Patch embedding
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # Positional encoding
        self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)

        # Transformer blocks
        self.blocks = nn.ModuleList([Block(**block_kwargs) for i in range(depth)])  # various arguments are not shown here for brevity purposes
        self.norm =  norm_layer(embed_dim)
        
        # Classifier (for fine-tuning only)
        self.fc_norm = norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x, mask):
        x = self.patch_embed(x)
        x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()
        B, _, C = x.shape
        if mask is not None:  # for pretraining only
            x = x[~mask].reshape(B, -1, C) # ~mask means visible
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        if self.num_classes > 0:  # for fine-tuning only
            x = self.fc_norm(x.mean(1))  # average pooling
            x = self.head(x)
        return x

解码器

与编码器相似,解码器由一系列 transformer 块组成。在解码器的末端,有一个由 norm 层和前馈层组成的分类器。输出张量的形态为 batch_size, num_patches,decoder_embed_dim) 而最终输入张量的形态为 (batch_size, num_patches, 3 patch_size * 2)。

class Decoder(nn.Module):
    def __init__(self, patch_size=16, embed_dim=768, norm_layer=nn.LayerNorm, num_classes=768, **block_kwargs):
        super().__init__()
        self.num_classes = num_classes
        assert num_classes == 3 * patch_size ** 2
        self.num_features = self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.blocks = nn.ModuleList([Block(**block_kwargs) for i in range(depth)])  # various arguments are not shown here for brevity purposes
        self.norm =  norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x, return_token_num):
        for blk in self.blocks:
            x = blk(x)
        if return_token_num > 0:
            x = self.head(self.norm(x[:, -return_token_num:])) # only return the mask tokens predict pixels
        else:
            x = self.head(self.norm(x))
        return x

把所有货色放在一起——MAE 架构

MAE 用于对掩码图像进行预训练。首先,屏蔽的输出被发送到编码器。而后,它们被传递到前馈层以更改嵌入维度以匹配解码器。在传递给解码器之前,被掩码的 Patch 被输出进去。地位编码再次利用于残缺的图像块集,包含可见的和被掩码遮蔽的。

在论文中,作者对蕴含所有 Patch 的列表进行了打乱,以便正确插入 Patch 的掩码。这部分在本篇文章中没有实现,因为在 PyTorch 上实现并不简略。所以这里应用的是地位编码在被增加到 Patch 之前被相应地打乱的做法。

class MAE(nn.Module):
    def __init__(self, ...):  # various arguments are not shown here for brevity purposes
        super().__init__()
        self.encoder = Encoder(img_size, patch_size, in_chans, embed_dim, norm_layer, num_classes=0, **block_kwargs)
        self.decoder = Decoder(patch_size, embed_dim, norm_layer, num_classes, **block_kwargs)
        self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=False)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.pos_embed = get_sinusoid_encoding_table(self.encoder.patch_embed.num_patches, decoder_embed_dim)
    
    def forward(self, x, mask):
        x_vis = self.encoder(x, mask)
        x_vis = self.encoder_to_decoder(x_vis)
        B, N, C = x_vis.shape
        expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
        pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
        pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C)
        x_full = torch.cat([x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1)
        x = self.decoder(x_full, pos_emd_mask.shape[1]) # [B, N_mask, 3 * 16 * 16]
        return x

训练过程

对于自监督预训练,论文发现简略的逐像素均匀相对损失作为指标函数成果很好。并且他们应用的数据集是 ImageNet-1K 训练集。

在上游的微调阶段,解码器被移除,编码器在雷同的数据集上进行训练。数据与预训练略有不同,因为编码器当初应用残缺的图像块集(没有屏蔽)。因而,当初的 Patch 数量与预训练阶段不同。

如果您你晓得用于预训练的模型是否依然能够用于微调,答案是必定的。编码器次要由注意力模块、norm 层和前馈层组成。要查看 Patch 数量(索引 1)的变动是否影响前向传递,咱们须要查看每一层的参数张量的形态。

  • norm 层中的参数的形态为(batch, 1, encoder_embed_dim)。它能够在前向流传期间沿着补丁维度(索引 1)进行播送,因而它不依赖于补丁维度的大小。
  • 前馈层有一个形态为 (in_channels, out_channels) 的权重矩阵和一个形态为 (out_channels,) 的偏置矩阵,两者都不依赖于 patch 的数量。
  • 注意力模块实质上执行一系列线性投影。因而,出于同样的起因,patch 的数量也不会影响参数张量的形态。

因为并行处理容许将数据分批输出,所以批处理中的 Patch 数量是须要保持一致的。

后果

让咱们看看原始论文中报道的预训练阶段的重建图像。看起来 MAE 在重建图像方面做得很好,即便 80% 的像素被遮蔽了。

ImageNet 验证图像的示例后果。从左到右: 遮蔽图像、重建图像、实在图像。掩蔽率为 80%。

MAE 在微调的上游工作上也体现良好,例如 ImageNet-1K 数据集上的图像分类。与监督形式相比,在应用 MAE 预训练进行训练时比应用的基线 ViT-Large 实际上体现更好。

论文中还包含对上游工作和各种融化钻研的迁徙学习试验的基准后果。有趣味的能够再看看原论文。

探讨

如果您相熟 BERT,您可能会留神到 BERT 和 MAE 的办法之间的相似之处。在 BERT 的预训练中,咱们遮蔽了一部分文本,模型的工作是预测它们。此外,因为咱们当初应用的是基于 Transformer 的架构,因而说这种办法在视觉上与 BERT 等效也不是不适合的。

然而论文中说这种办法早于 BERT。例如,过来对图像自监督的尝试应用重叠去噪自编码器和图像修复作为 pretext task。MAE 自身也应用主动编码器作为模型和相似于图像修复的 pretext task。

如果是这样的话,是什么让 MAE 工作比以前模型好呢?我认为关键在于 ViT 架构。在他们的论文中,作者提到卷积神经网络在将掩码标记和地位嵌入等“指标”集成到其中时存在问题,而 ViT 解决了这种架构差距。如果是这样,那么咱们将看到在自然语言解决中开发的另一个想法在计算机视觉中胜利实现。之前是 attention 机制,而后 Transformer 的概念以 Vision Transformers 的模式借用到计算机视觉中,当初是整个 BERT 预训练过程。

论断

我对将来自监督的视觉必须提供的货色感到兴奋。鉴于 BERT 在自然语言解决方面的胜利,像 MAE 这样的掩码建模办法将有益于计算机视觉。图像数据很容易取得,但标记它们可能很耗时。通过这种办法,人们能够通过治理比 ImageNet 大得多的数据集来扩大预训练过程,而无需放心标记。后劲是有限的。咱们是否会见证计算机视觉的另一次振兴,只有工夫能力证实。

援用

  1. Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pretraining of deep bidirectional transformers for language understanding. In NAACL, 2019.
  2. Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, and Ross Girshick. Masked autoencoders are scalable vision learners. arXiv:2111.06377, 2021.
  3. Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby. An image is worth 16×16 words: Transformers for image recognition at scale. In ICLR, 2021.

作者:Stephen Lau

正文完
 0