关于算法:Vision-MLP之S2MLP-V1V2-SpatialShift-MLP

53次阅读

共计 9923 个字符,预计需要花费 25 分钟才能阅读完成。

Vision MLP 之 S2-MLP V1&V2 : Spatial-Shift MLP Architecture for Vision

原始文档:https://www.yuque.com/lart/pa…

这里将会总结对于 S2-MLP 的两篇文章。这两篇文章外围思路是一样的,即基于空间偏移操作替换空间 MLP。

从摘要了解文章

V1

Recently, visual Transformer (ViT) and its following works _abandon the convolution and exploit the self-attention operation_, attaining a comparable or even higher accuracy than CNNs. More recently, MLP-Mixer _abandons both the convolution and the self-attention operation_, proposing an architecture containing only MLP layers.
To achieve cross-patch communications, it devises an additional token-mixing MLP besides the channel-mixing MLP. It achieves promising results when training on an extremely large-scale dataset. _But it cannot achieve as outstanding performance as its CNN and ViT counterparts when training on medium-scale datasets such as ImageNet1K and ImageNet21K_. _The performance drop of MLP-Mixer motivates us to rethink the token-mixing MLP_.

这里引出了本文的次要内容,即改良空间 MLP。

We discover that the token-mixing MLP is a variant of the depthwise convolution with a global reception field and spatial-specific configuration. But _the global reception field and the spatial-specific property make token-mixing MLP prone to over-fitting_.

指出了空间 MLP 的问题,因为 其全局感触野和空间特定的属性使得模型容易过拟合

In this paper, we propose a novel pure MLP architecture, spatial-shift MLP (S2-MLP). Different from MLP-Mixer, our S2-MLP only contains channel-mixing MLP.

这里提到仅有通道 MLP,阐明想到了新的方法来扩张通道 MLP 的感触野还能够保留点运算。

We utilize a _spatial-shift operation for communications between patches_. It has a local reception field and is spatial-agnostic. It is parameter-free and efficient for computation.

引出本文的核心内容,也就是题目中提到的空间偏移操作。看上去这一操作不带参数,仅仅是用来调整特色的一个解决伎俩。
Spatial-Shift 操作能够参考这里的几篇文章:https://www.yuque.com/lart/architecture/conv#i8nnp

The proposed S2-MLP attains higher recognition accuracy than MLP-Mixer when training on ImageNet-1K dataset. Meanwhile, S2-MLP accomplishes as excellent performance as ViT on ImageNet-1K dataset with considerably _simpler architecture and fewer FLOPs and parameters_.

V2

Recently, MLP-based vision backbones emerge. MLP-based vision architectures with less inductive bias achieve competitive performance in image recognition compared with CNNs and vision Transformers. Among them, spatial-shift MLP (S2-MLP), adopting the straightforward spatial-shift operation, achieves better performance than the pioneering works including MLP-mixer and ResMLP. More recently, using smaller patches with a pyramid structure, Vision Permutator (ViP) and Global Filter Network (GFNet) achieve better performance than S2-MLP.

这里引出了金字塔构造,看来 V2 版本要应用相似的结构。

In this paper, we improve the S2-MLP vision backbone. We expand the feature map along the channel dimension and split the expanded feature map into several parts. We conduct different spatial-shift operations on split parts.

仍然连续了空间偏移的策略,然而不晓得相较于 V1 版本改变如何

Meanwhile, we _exploit the split-attention operation to fuse these split parts_.

这里还引入了 split-attention(ResNeSt)来交融分组。难道这里是要应用并行分支?

Moreover, like the counterparts, we adopt _smaller-scale patches and use a pyramid structure for boosting the image recognition accuracy_.
We term the improved spatial-shift MLP vision backbone as S2-MLPv2. Using 55M parameters, our medium-scale model, S2-MLPv2-Medium achieves an 83.6% top-1 accuracy on the ImageNet-1K benchmark using 224×224 images without self-attention and external training data.

在我看来,V2 相较于 V1,次要是借鉴了 CycleFC 的一些想法,并进行了适应性的调整。整体改变有两方面:

  1. 引入多分支解决的思维,并利用 Split-Attention 来交融不同分支。
  2. 受现有工作的启发,应用更小的 patch 和分层金字塔构造。

次要内容

外围构造比拟

V1 中,整体流程连续的是 MLP-Mixer 的思路,依然放弃直筒状构造。

MLP-Mixer 的结构图:

从图中能够看到,不同于 MLP-Mixer 中的 Pre-Norm 构造,S2MLP 应用的是 Post-Norm 构造。
另外,S2MLP 的改变次要集中在空间 MLP 的地位,由原来的 Spatial-MLP(Linear->GeLU->Linear) 转变为 Spatial-Shifted Channel-MLP(Linear->GeLU->Spatial-Shift->Lienar)
对于空间偏移的外围伪代码如下:

能够看到,这里就是将 输出划分成四个不同的分组,各自沿着不同的轴向(H 和 W 轴)偏移 ,因为实现的起因,在边界局部会有反复值呈现。 分组数依赖于方向的数量,这里默认应用 4,即向四个方向偏移。
尽管从单个空间偏移模块上来看,仅仅关联了相邻的 patch,然而从整体重叠后的构造来看,能够实现一个近似的长距离交互过程。

而在 V2 版本相较于 V1 版本引入了多分支解决的策略,并且在结构上开始应用 Pre-Norm 模式。

对于多分支构造的结构思路与 CycleFC 十分相似。不同支路应用不同的解决策略,同时在多分支整合时,应用了 Split-Attention 的形式进行交融。

Split-Attention: Vision Permutator (Hou et al., 2021) adopts split attention proposed in ResNeSt (Zhang et al., 2020) for enhancing multiple feature maps from different operations. 本文借鉴应用来交融多分支。
次要操作过程:

  1. 输出 $K$ 个特色图(能够来自不同分支)$\mathbf{X} = \{X_k \in \mathbb{R}^{N \times C}\}^{K}_{k=1}, \, N=HW$
  2. 将所有特诊图的列求和后的后果累加:$a \in \mathbb{R}^{C} = \sum_{k=1}^{K}\sum_{n=1}^{N}\mathbf{X}_{k}[n, :]$
  3. 通过重叠的全连贯层进行变换,失去针对不同特色图的通道注意力 logits:$\hat{a} \in \mathbb{R}^{KC} = \sigma(a W_1) W_2, \, W_1 \in \mathbb{R}^{C \times \bar{C}}, \, W_2 \in \mathbb{R}^{\bar{C} \times KC}$
  4. 应用 reshape 来调整注意力向量的形态:$\hat{a} \in \mathbb{R}^{KC} \rightarrow \hat{A} \in \mathbb{R}^{K \times C}$
  5. 应用 softmax 沿着索引 $k$ 计算,来取得针对不同样本的归一化注意力权重:$\bar{A}[:, c] \in \mathbb{R}^{K} = \text{softmax}(\hat{A}[:, c])$
  6. 对输出的 $K$ 个特色图加权求和失去后果 $Y$,其一行的后果能够示意为:$Y[n, :] \in \mathbb{R}^{C} = \sum_{k=1}^{K} X_{k}[n, :] \odot \bar{A}[k, :]$

不过须要留神的是,这里第三个分支是一个恒等分支,间接将输出的局部通道取了过去,这一点连续了 GhostNet 的想法,而不同于 CycleFC,应用的是一个独立的通道 MLP。

GhostNet 的外围构造:

对于该多分支构造的外围伪代码如下:

其余细节

Spatial-Shift 与 Depthwise Convolution 的关系

实际上,四个方向的偏移都是能够通过特定的卷积核结构来实现的:

所以分组空间偏移操作能够通过为 Depthwise Convolution 的不同分组指定对应下面的卷积核来实现。

实际上实现偏移的办法十分多,除了文中提到的切片索引和结构核的 depthwise convolution 的形式,还能够通过分组 torch.roll 和自定义 offset 的 deform_conv2d 来实现。

import torch
import torch.nn.functional as F
from torchvision.ops import deform_conv2d

xs = torch.meshgrid(torch.arange(5), torch.arange(5))
x = torch.stack(xs, dim=0)
x = x.unsqueeze(0).repeat(1, 4, 1, 1).float()

direct_shift = torch.clone(x)
direct_shift[:, 0:2, :, 1:] = torch.clone(direct_shift[:, 0:2, :, :4])
direct_shift[:, 2:4, :, :4] = torch.clone(direct_shift[:, 2:4, :, 1:])
direct_shift[:, 4:6, 1:, :] = torch.clone(direct_shift[:, 4:6, :4, :])
direct_shift[:, 6:8, :4, :] = torch.clone(direct_shift[:, 6:8, 1:, :])
print(direct_shift)

pad_x = F.pad(x, pad=[1, 1, 1, 1], mode="replicate")  # 这里须要借助 padding 来保留边界的数据

roll_shift = torch.cat(
    [torch.roll(pad_x[:, c * 2 : (c + 1) * 2, ...], shifts=(shift_h, shift_w), dims=(2, 3))
        for c, (shift_h, shift_w) in enumerate([(0, 1), (0, -1), (1, 0), (-1, 0)])
    ],
    dim=1,
)
roll_shift = roll_shift[..., 1:6, 1:6]
print(roll_shift)

k1 = torch.FloatTensor([[0, 0, 0], [1, 0, 0], [0, 0, 0]]).reshape(1, 1, 3, 3)
k2 = torch.FloatTensor([[0, 0, 0], [0, 0, 1], [0, 0, 0]]).reshape(1, 1, 3, 3)
k3 = torch.FloatTensor([[0, 1, 0], [0, 0, 0], [0, 0, 0]]).reshape(1, 1, 3, 3)
k4 = torch.FloatTensor([[0, 0, 0], [0, 0, 0], [0, 1, 0]]).reshape(1, 1, 3, 3)
weight = torch.cat([k1, k1, k2, k2, k3, k3, k4, k4], dim=0)  # 每个输入通道对应一个输出通道
conv_shift = F.conv2d(pad_x, weight=weight, groups=8)
print(conv_shift)

offset = torch.empty(1, 2 * 8 * 1 * 1, 1, 1)
for c, (rel_offset_h, rel_offset_w) in enumerate([(0, -1), (0, -1), (0, 1), (0, 1), (-1, 0), (-1, 0), (1, 0), (1, 0)]):
    offset[0, c * 2 + 0, 0, 0] = rel_offset_h
    offset[0, c * 2 + 1, 0, 0] = rel_offset_w
offset = offset.repeat(1, 1, 7, 7).float()
weight = torch.eye(8).reshape(8, 8, 1, 1).float()
deconv_shift = deform_conv2d(pad_x, offset=offset, weight=weight)
deconv_shift = deconv_shift[..., 1:6, 1:6]
print(deconv_shift)

"""
tensor([[[[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.]],

         [[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.]],

         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]],

         [[1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.],
          [4., 4., 4., 4., 4.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]]]])
tensor([[[[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.]],

         [[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.]],

         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]],

         [[1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.],
          [4., 4., 4., 4., 4.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]]]])
tensor([[[[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.]],

         [[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.]],

         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]],

         [[1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.],
          [4., 4., 4., 4., 4.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]]]])
tensor([[[[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.]],

         [[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.]],

         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]],

         [[1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.],
          [4., 4., 4., 4., 4.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]]]])
"""

偏移方向的影响

试验是在 ImageNet 的子集上跑的。

V1 中针对不同的偏移方向进行了融化试验,这里的模型中都是依照方向个数对通道分组。从后果中能够看到:

  • 偏移的确能够带来性能增益。
  • a 和 b:四个方向和八个方向相比,差别并不大。
  • e 和 f:程度偏移成果更好。
  • c 和 e/f:两个轴的偏移要好于单个轴的偏移。

输出尺寸以及 patchsize 的影响

试验是在 ImageNet 的子集上跑的。

V1 中在固定 patchsize 后,不同的输出尺寸 WxH 的体现也不同。过大的 patchsize 成果也不好,会失落更多的细节信息,然而却能够无效晋升推理速度。

金字塔构造的有效性

V2 中,结构了两个不同的构造,一个有着更小的 patch,并且应用金字塔构造,另一个更大的 patch,不应用金字塔构造。能够看到,同时受害于小 patchsize 带来的细节信息的性能加强和金字塔构造带来的更优的计算效率,前者取得了更好的体现。

Split-Attention 的成果

V2 将 split-attention 与特色间接相加取均匀比照。能够看到,前者更优。不过这里参数量也不一样了,其实更正当的比拟应该最起码是加几层带参数的构造来交融三分支的特色。

三分支构造的有效性

这里的试验阐明有些含糊,作者说道“In this section, we evaluate the influence of removing one of them.”然而却没有阐明去掉特定分支后其余构造的调整形式。

试验后果

试验后果间接看 V2 论文的表格即可:

链接

  • 论文:

    • V1:https://arxiv.org/pdf/2106.07477.pdf
    • V2:https://arxiv.org/pdf/2108.01072.pdf
  • 参考代码:

    • CycleFC 的代码有能够借鉴之处: https://github.com/ShoufaChen/CycleMLP/blob/main/cycle_mlp.py

正文完
 0