关于深度学习:EasyCV带你复现更好更快的自监督算法FastConvMAE

3次阅读

共计 11099 个字符,预计需要花费 28 分钟才能阅读完成。

作者:夕陌、谦言、莫申童、临在

导读

自监督学习(Self-Supervised Learning)利用大量无标注的数据进行表征学习,在特定上游工作上对参数进行微调,极大升高了图像工作沉重的标注工作,节俭大量人力老本。近年来,自监督学习在视觉畛域大放异彩,受到了越来越多的关注。在 CV 畛域涌现了如 SIMCLR、MOCO、SwAV、DINO、MoBY、MAE 等一系列工作。其中 MAE 的体现尤为惊艳,大家都被 MAE 简洁高效的性能所吸引,纷纷在 MAE 上进行改良,例如 MixMIM,VideoMAE 等工作。MAE 详解请参考往期文章:MAE 自监督算法介绍和基于 EasyCV 的复现。

ConvMAE 是由上海人工智能实验室和 mmlab 联结发表在 NeurIPS2022 的一项工作,与 MAE 相比,训练雷同的 epoch 数,ImageNet-1K 数据集的 finetune 准确率进步了 1.4%,COCO2017 数据集上微调 25 个 epoch 相比微调 100 个 epoch 的 MAE AP box 晋升 2.9,AP mask 晋升 2.2,语义宰割工作上相比 MAE mIOU 晋升 3.6%。在此基础上,作者提出了 FastConvMAE,进一步优化了训练性能,仅预训练 50 个 epoch,ImageNet Finetuning 的精度就超过 MAE 预训练 1600 个 epoch 的精度 0.77 个点(83.6/84.37)。在检测工作上,精度也超过 ViTDet 和 Swin。

EasyCV 是阿里巴巴开源的基于 Pytorch,以自监督学习和 Transformer 技术为外围的 all-in-one 视觉算法建模工具,笼罩支流的视觉建模工作例如图像分类,度量学习,指标检测,实例 / 语音 / 全景宰割、关键点检测等畛域,具备较强的易用性和扩展性,同时重视性能调优,旨在为社区带来更多更快更强的算法。

近期 FastConvMAE 工作在 EasyCV 框架内首次对外开源,本文将重点介绍 ConvMAE 和 FastConvMAE 的次要工作,以及对应的代码实现,最初提供具体的教程示例如何进行 FastConvMAE 的预训练和上游工作的 finetune。

ConvMAE

ConvMAE 是由上海人工智能实验室和 mmlab 联结发表在 NeurIPS2022 里的一项工作,ConvMAE 的提出证实了应用部分演绎偏置和多尺度的金字塔构造,通过 MAE 的训练形式能够学习到更好的特色示意。该工作提出:

  1. 应用 block-wise mask 策略来确保计算效率。
  2. 输入编码器的多尺度特色,同时捕捉细粒度和粗粒度图像信息。

原文参考:https://arxiv.org/abs/2205.03892

试验结果显示,上述两项策略是简洁而无效的,使得 ConvMAE 在多个视觉工作中相比 MAE 取得了显著晋升。以 ConvMAE-Base 和 MAE-Base 相比为例:在图像分类工作上,ImageNet-1K 数据集的微调准确率进步了 1.4%;在指标检测工作上,COCO2017 微调 25 个 epoch 的 AP box 达到 53.2%,AP mask 达到 47.1%,与微调 100 个 epoch 的 MAE-Base 相比别离晋升 2.9% 和 2.2%;在语义宰割工作上,应用 UperNet 网络头,ConvMAE-Base 在 ADE20K 上的 mIoU 达到 51.7%,相比 MAE-Base 晋升 3.6%。

与 MAE 不同的是,ConvMAE 的编码器将输出图像逐渐形象为多尺度 token embedding,而解码器则重建被 mask 掉的 tokens 对应的像素。对于后面 stage 局部的高分辨率 token embedding,采纳卷积块对部分进行编码,对于前面的低分辨率 token embedding,则应用 transformer 来聚合全局信息。因而,ConvMAE 的编码器在不同阶段能够同时取得部分和全局信息,并生成多尺度特色。

以后的 masked auto encoding 框架,如 BEiT,SimMIM,所采纳的 mask 策略不能间接用于 ConvMAE,因为在前面的 transformer 阶段,所有的 tokens 都须要保留。这导致对大模型进行预训练的计算成本过高,失去了 MAE 在 transformer 编码器中省去 masked tokens 的效率劣势。此外,间接应用 convolution-transformer 构造的编码器进行预训练会导致卷积局部因为随机的 mask 而造成预训练的信息泄露,因此也会升高预训练所得模型的品质。

针对这些问题,ConvMAE 提出了混合 convolution-transformer 架构。ConvMAE 采纳分块 mask 策略 (block-wise masking strategy):,首先随机在前期的获取 transformer token 中生成前期的 mask,而后对 mask 固定地位逐渐进行上采样到晚期卷积阶段的高分辨率。这样,前期解决的 token 能够齐全拆散为 masked tokens 和 visible tokens,从而并继承了 MAE 应用稠密 encoder 的计算效率。

上面将别离针对 encoder、mask 策略以及 decoder 局部开展介绍。

Encoder

如总体流程图所示,encoder 包含 3 个阶段,每个阶段输入的特色维度别离是:H/4 × W/4, H/8 × W/8, H/16 × W/16,其中 H × W 为输出图像分辨率。前两个是卷积阶段,应用卷积模块将输出转换为 token embeddings E1 ∈ R^(H/4 × W/4 ×C1) and E2 ∈ R^(H/8 × W/8 ×C2)。其中卷积模块用 5 × 5 的卷积代替 self-attention 操作。前两个阶段的感触野较小次要捕获图像的部分特色,第三个阶段应用 transformer 模块,将粗粒度特色交融, 并将感触野扩大到整个图像,取得 token embeddings E3 ∈ R(H/16 × W/16 ×C3)。在每个阶段之间,应用 stride 为 2 的卷积对 tokens 进行下采样。

其余蕴含 transformer 的构造,如 CPT、Container、Uniformer、CMT、Swin 等,在第一阶段的输出用绝对地位编码或零填充卷积代替相对地位编码,而作者发现在第 3 个 transformer stage 中应用相对地位编码可取得最优性能。class token 也从编码器中移除。

Mask 策略

MAE、BEiT 等,对输出 patch 采纳随机 mask。但同样的策略不能间接利用于 ConvMAE 编码器:如果独立地从 stage- 1 的 H /4 × W/ 4 个 tokens 中随机抽取 mask,将导致降采样后的 stage- 3 的简直所有 token 都有局部可见信息,使得编码器不再稠密。因而作者提出,从 stage- 3 的输出 tokens 中以同样比例(例如 75%)生成 mask,再对 mask 上采样 2 倍和 4 倍,别离作为 stage- 2 和 stage- 1 的 mask。这样,ConvMAE 在 3 个阶段都只含有很少的(例如 25%)可见 token,从而使得预训练时编码器的效率不受影响。而解码器的工作 e 则放弃雷同,即重建编码过程中被 mask 掉的 tokens。

同时,前 2 个阶段的 5X5 卷积操作会在 masked patches 的边缘处透露不可见 token 的重建答案。为了防止这种状况保障预训练的品质,作者在前两个阶段采纳了 masked convolution, 使被 mask 掉的区域不参加编码过程。

Decoder

原始 MAE 的 decoder 的输出以编码器的输入和 mask 掉的 tokens 作为输出,而后通过重叠的 transformer blocks 进行图像重建。ConvMAE 编码器取得多尺度特色 E1、E2、E3,同时捕捉细粒度和粗粒度图像信息。为了更好地的预训练,作者通过 stride- 4 和 stride- 2 卷积将 E1 和 E2 下采样到 E3 的雷同大小,并进行多尺度特色交融,再通过一个 linear 层失去最终要输出给 decoder 的可见 token。指标函数和 MAE 雷同,仅采纳 MSE 作为损失函数,计算预测向量和被 mask 掉像素值之前的 MSE loss,即只思考 mask 掉的 patches 的重建。

上游工作

预训练之后,ConvMAE 能够输入多尺度的特色用于检测宰割工作。

检测工作中,先将第 stage- 3 的输入特色 E3 通过 2 ×2 最大池化取得 E4。因为 ConvMAE stage- 3 有 11 个 self-attention 层(ConvMAE-base),计算成本过高,作者参考 ViT 的 benchmark 将 stage- 3 中除第 1、4、7、11 之外的所有 global self-attention layers 替换为了 Window size7×7 的  local self-attention 层。批改后的 local self-attention 依然由预训练的 global self-attention 进行初始化。global transformer blocks 之间共享 global relative position bias,local transformer blocks 之间共享 local relative position bias,这样就大大加重了 stage- 3 的计算和 GPU 内存开销。而后将多尺度特色 E1、E2、E3、E4 送入 MaskRCNN head 进行指标检测。

而宰割工作保留了 stage- 3 的构造。

Benchmark

图像分类

ConvMAE 基于 ImageNet-1K,mask 掉 25% 的 input token 做预训练,Decoder 局部是一个 8 层的 transformer,embedding 维度是 512,head 是 12 个。预训练参数和分类 finetuning 后果如下:

BEiT 预训练 300 个 epoch,finetune 的精度达到 83.0%,linear-prob 的精度是 37.6%。与 BEiT 相比,ConVMAE 仅须要 25% 的 token 和一个轻量级的 decoder finetune 可达到 85%,linear-prob 能够达到 70.9%。与原来的 MAE 相比,预训练雷同的 1600 个 epoch,ConVMAE 比 MAE 晋升 1.4 个点。与 SimMIM(backbone 应用 Swin-B)相比晋升了 1 个点。

检测

作者用 ConvMAE 替换 Mask-RCNN 的 backbone,加载 ConvMAE 的预训练模型训练 COCO 数据集。

与 ViT 在 COCO 数据集上 finetune100 个 epoch 的后果相比,ConVMAE 仅 finetune 25 个 epoch 在 APbox 和 APmask 就晋升了 2.9 和 2.2 个点。

与 ViTDet 和 MIMDet 相比,ConvMAE finetune epoch 更少、参数更少,别离超过了它们 2.0% 和 1.7%。

与 Swin 和 MViTv2 相比,在 APbox/APmask,其性能别离高出 4.0%/3.6% 和 2.2%/1.4%。

宰割

作者用 ConvMAE 替换 UperNet 的 backbone,加载 ConvMAE 的预训练模型训练 ADE20K 数据集。

从后果中能够看出,相比与 DeiT, Swin,MoCo-v3 等网络 ConvMAE 获得了更高的性能(51.7%)。表明 ConvMAE 的多尺度特色大大放大了预训练 Backbone 和上游网络之间的传输差距。

Fast ConvMAE

ConvMAE 尽管在分类、检测、宰割等上游工作中有了精度晋升,并解决了 pretraining-finetuning 的差别问题,然而模型的预训练仍然耗时,ConvMAE 的后果中,模型预训练了 1600 个 epoch,因而作者又在 ConvMAE 的根底之上做了进一步的性能优化,提出了 Fast ConvMAE,FastConvMAE 提出了 mask 互补和 deocder 交融 的计划,来实现疾速的 mask 建模计划,进一步缩短了预训练的工夫,从原来预训练的 1600epoch 缩短到了 50epoch。 FastConvMAE 的正式论文作者会在将来收回。

首先,FastConvMAE 翻新地设计出 decoder 相互交融的 Mixture of Reconstructor (MoR),能够让 masked patches 从不同的 tokenizer 中学习到互补的信息,包含 EMA 的 self-ensembling 性质,DINO 的 similarity-discrimination 能力, 以及 CLIP 的 multimodal 常识。MoR 次要包含两个局部,Partially-Shared Decoder(PS-Decoder)和 Mixture of Tokenizer(MoT), PS-Decoder 能够防止不同 tokenizer 的不同常识之间会产生梯度的抵触,MoT 是用来生成不同的 token 作为 masked patches 的 target。

同时 Mask 局部采纳了互补策略,原来的 mask 每次只会保留例如 25% 的 tokens,FastConvMAE 将 mask 分成了 4 份,每一份都保留 25%,4 份 mask 之间互补。这样,相当于 1 张图片被分成了 4 张图片进行学习,实践上达到了 4 倍的学习效果。

    def random_masking(self, x, mask_ratio=None):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N = x.shape[0]
        L = self.num_patches
        len_keep = int(L * (1 - mask_ratio))
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep1 = ids_shuffle[:, :len_keep]
        ids_keep2 = ids_shuffle[:, len_keep:2 * len_keep]
        ids_keep3 = ids_shuffle[:, 2 * len_keep:3 * len_keep]
        ids_keep4 = ids_shuffle[:, 3 * len_keep:]

        # generate the binary mask: 0 is keep, 1 is remove
        mask1 = torch.ones([N, L], device=x.device)
        mask1[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask1 = torch.gather(mask1, dim=1, index=ids_restore)

        mask2 = torch.ones([N, L], device=x.device)
        mask2[:, len_keep:2 * len_keep] = 0
        # unshuffle to get the binary mask
        mask2 = torch.gather(mask2, dim=1, index=ids_restore)

        mask3 = torch.ones([N, L], device=x.device)
        mask3[:, 2 * len_keep:3 * len_keep] = 0
        # unshuffle to get the binary mask
        mask3 = torch.gather(mask3, dim=1, index=ids_restore)

        mask4 = torch.ones([N, L], device=x.device)
        mask4[:, 3 * len_keep:4 * len_keep] = 0
        # unshuffle to get the binary mask
        mask4 = torch.gather(mask4, dim=1, index=ids_restore)

        return [ids_keep1, ids_keep2, ids_keep3,
                ids_keep4], [mask1, mask2, mask3, mask4], ids_restore

前两个卷积阶段将输出转换为 embeddings tokens E1 和 E2。而后 E1 和 E2 别离从 4 份 mask 中获取 4 份可见的 tokens 并进行拼接,作为 decoder 的输出,Decoder 解决的是拼接后的 tokens。代码参考如下:

   def encoder_forward(self, x, mask_ratio):
        # embed patches
        ids_keep, masks, ids_restore = self.random_masking(x, mask_ratio)
        mask_for_patch1 = [1 - mask.reshape(-1, 14, 14).unsqueeze(-1).repeat(1, 1, 1, 16).reshape(-1, 14, 14, 4, 4).permute(0, 1, 3, 2, 4).reshape(x.shape[0], 56, 56).unsqueeze(1)
          for mask in masks
        ]
        mask_for_patch2 = [1 - mask.reshape(-1, 14, 14).unsqueeze(-1).repeat(1, 1, 1, 4).reshape(-1, 14, 14, 2, 2).permute(0, 1, 3, 2, 4).reshape(x.shape[0], 28, 28).unsqueeze(1)
          for mask in masks
        ]

        s1 = self.patch_embed1(x)
        s1 = self.pos_drop(s1)
        for blk in self.blocks1:
            s1 = blk(s1, mask_for_patch1)

        s2 = self.patch_embed2(s1)
        for blk in self.blocks2:
            s2 = blk(s2, mask_for_patch2)

        stage1_embed = self.stage1_output_decode(s1).flatten(2).permute(0, 2, 1)
        stage2_embed = self.stage2_output_decode(s2).flatten(2).permute(0, 2, 1)
        stage1_embed_1 = torch.gather(
          stage1_embed,
          dim=1,
          index=ids_keep[0].unsqueeze(-1).repeat(1, 1, stage1_embed.shape[-1]))
        stage2_embed_1 = torch.gather(
          stage2_embed,
          dim=1,
          index=ids_keep[0].unsqueeze(-1).repeat(1, 1, stage2_embed.shape[-1]))
        stage1_embed_2 = torch.gather(
          stage1_embed,
          dim=1,
          index=ids_keep[1].unsqueeze(-1).repeat(1, 1, stage1_embed.shape[-1]))
        stage2_embed_2 = torch.gather(
          stage2_embed,
          dim=1,
          index=ids_keep[1].unsqueeze(-1).repeat(1, 1, stage2_embed.shape[-1]))
        stage1_embed_3 = torch.gather(
          stage1_embed,
          dim=1,
          index=ids_keep[2].unsqueeze(-1).repeat(1, 1, stage1_embed.shape[-1]))
        stage2_embed_3 = torch.gather(
          stage2_embed,
          dim=1,
          index=ids_keep[2].unsqueeze(-1).repeat(1, 1, stage2_embed.shape[-1]))
        stage1_embed_4 = torch.gather(
          stage1_embed,
          dim=1,
          index=ids_keep[3].unsqueeze(-1).repeat(1, 1, stage1_embed.shape[-1]))
        stage2_embed_4 = torch.gather(
          stage2_embed,
          dim=1,
          index=ids_keep[3].unsqueeze(-1).repeat(1, 1, stage2_embed.shape[-1]))
        stage1_embed = torch.cat([stage1_embed_1, stage1_embed_2, stage1_embed_3, stage1_embed_4])
        stage2_embed = torch.cat([stage2_embed_1, stage2_embed_2, stage2_embed_3, stage2_embed_4])

        x = self.patch_embed3(s2)
        x = x.flatten(2).permute(0, 2, 1)
        x = self.patch_embed4(x)
        # add pos embed w/o cls token
        x = x + self.pos_embed

        x1 = torch.gather(x, dim=1, index=ids_keep[0].unsqueeze(-1).repeat(1, 1, x.shape[-1]))
        x2 = torch.gather(x, dim=1, index=ids_keep[1].unsqueeze(-1).repeat(1, 1, x.shape[-1]))
        x3 = torch.gather(x, dim=1, index=ids_keep[2].unsqueeze(-1).repeat(1, 1, x.shape[-1]))
        x4 = torch.gather(x, dim=1, index=ids_keep[3].unsqueeze(-1).repeat(1, 1, x.shape[-1]))
        x = torch.cat([x1, x2, x3, x4])

        # apply Transformer blocks
        for blk in self.blocks3:
            x = blk(x)

        x = x + stage1_embed + stage2_embed
        x = self.norm(x)
        mask = torch.cat([masks[0], masks[1], masks[2], masks[3]])
        return x, mask, ids_restore

Benchmark

EasyCV 复现的后果如下:

ImageNet Pretrained

Config Epochs Download
fast_convmae_vit_base_patch16_8xb64_50e 50 model- log

ImageNet Finetuning

Algorithm Fintune Config Pretrained Config Top-1 Download
Fast ConvMAE(EasyCV) fast_convmae_vit_base_patch16_8xb64_100e_fintune fast_convmae_vit_base_patch16_8xb64_50e 84.4% fintune model- log
Fast ConvMAE(官网) 84.4%

Object Detection

Algorithm Eval Config Pretrained Config mAP (Box) mAP (Mask) Download
Fast ConvMAE(EasyCV) mask_rcnn_conv_vitdet_50e_coco fast_convmae_vit_base_patch16_8xb64_50e 51.3% 45.6% finetune model
Fast ConvMAE(官网) 51.0% 45.4%

从后果能够看出,仅预训练 50 个 epoch,ImageNet Finetuning 的精度就超过 MAE 预训练 1600 个 epoch 的精度 0.77 个点(83.6/84.37)。在检测工作上,精度也超过 ViTDet 和 Swin。

FastConvMAE 的更多官网后果请参考:https://github.com/Alpha-VL/F…。

Tutorial

一、装置依赖包

如果是在本地开发环境运行,能够参考该链接装置环境。若应用 PAI-DSW 进行试验则无需装置相干依赖,在 PAI-DSW docker 中已内置相干环境。

二、数据筹备

数据筹备请参考文档:https://github.com/alibaba/Ea…

三、模型预训练

FastConvMAE 占用显存较大,倡议应用 A100 资源。(FastConvMAE 一次 forward-backward 等价于 ConvMAE forward-backward 4 次)

在 EasyCV 中,应用配置文件的模式来实现对模型参数、数据输出及增广形式、训练策略的配置,仅通过批改配置文件中的参数设置,就能够实现试验配置进行训练。

配置 EasyCV 门路

# 查看 easycv 装置地位
import easycv
print(easycv.__file__)
$ export PYTHONPATH=$PYTHONPATH:${your EasyCV root path}

训练

$ python -m torch.distributed.launch --nproc_per_node=8 --master_port=29930 \
tools/train.py \
configs/selfsup/fast_convmae/fast_convmae_vit_base_patch16_8xb64_50e.py \
--work_dir ./work_dir \
--launcher pytorch

上游工作 finetune

下载预训练模型

$ wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/FastConvMAE/pretrained/epoch_50.pth
  • 单卡
$ python tools/train.py \
${CONFIG_FILE} \
--work_dir ./work_dir \
--load_from=./epoch_50.pth
  • 多卡
$ python -m torch.distributed.launch --nproc_per_node=8 --master_port=29930 \
tools/train.py \
${CONFIG_FILE} \
--work_dir ./work_dir \
--launcher pytorch \
--load_from=./epoch_50.pth

分类工作 CONFIG_FILE 请参考:https://github.com/alibaba/Ea…

分类工作 CONFIG_FILE 请参考:https://github.com/alibaba/Ea…

Reference

EasyCV:https://github.com/alibaba/Ea…

EasyCV 往期分享

  • 基于 EasyCV 复现 DETR 和 DAB-DETR,Object Query 的正确打开方式
  • 基于 EasyCV 复现 ViTDet:单层特色超过 FPN
  • MAE 自监督算法介绍和基于 EasyCV 的复现
  • EasyCV 开源|开箱即用的视觉自监督 +Transformer 算法库
正文完
 0