Our community has a new member sharing technical content!
A warm welcome!
As a dutiful re-poster, I have to do something to express my delight: repost, repost, repost right away~

Article source | 恒源云社区

Original post | A New Mixed Transformer Module (MTM)

Original author | 咚咚


Abstract

Problem: Although U-Net has achieved great success in medical image segmentation, it lacks the ability to explicitly model long-range dependencies. Vision Transformers, with their inherent ability to capture long-range correlations through self-attention (SA), have recently emerged as an alternative segmentation architecture.
Problem: However, Transformers usually rely on large-scale pre-training and have high computational complexity. Moreover, SA can only model self-affinities within a single sample, ignoring potential correlations across the whole dataset.
Method: The paper proposes a new Mixed Transformer Module (MTM) for simultaneous inter-affinity and intra-affinity learning. MTM first computes intra-window affinities efficiently through Local-Global Gaussian-Weighted Self-Attention (LGG-SA), and then mines connections between data samples through External Attention (EA). Using MTM, a Mixed Transformer U-Net (MT-UNet) is constructed for medical image segmentation.

Method


As shown in Figure 1, the network is based on an encoder-decoder structure.

  1. To reduce computational cost, MTMs are applied only to the deeper layers, which have smaller spatial sizes,
  2. while the shallow layers still use classic convolution operations, since shallow layers mainly focus on local information and contain more high-resolution detail (a minimal layout sketch follows this list).
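As a rough illustration of this hybrid layout (my own sketch, not the authors' code: the channel widths, the simple conv block, and the nn.Identity placeholders for the MTM stages are all assumptions):

import torch
import torch.nn as nn

def conv_stage(c_in, c_out):
    # shallow stage: plain convolutions, cheap at high resolution
    return nn.Sequential(nn.Conv2d(c_in, c_out, 3, padding=1),
                         nn.BatchNorm2d(c_out),
                         nn.ReLU(inplace=True))

class HybridEncoderSketch(nn.Module):
    def __init__(self):
        super(HybridEncoderSketch, self).__init__()
        self.stage1 = conv_stage(1, 64)     # full resolution: convolution only
        self.stage2 = conv_stage(64, 128)   # 1/2 resolution: convolution only
        self.down = nn.MaxPool2d(2)
        # the deep, low-resolution stages are where MTM blocks would go
        self.stage3 = nn.Identity()         # placeholder for an MTM stage
        self.stage4 = nn.Identity()         # placeholder for an MTM stage

    def forward(self, x):
        x = self.stage1(x)
        x = self.stage2(self.down(x))
        x = self.stage3(self.down(x))
        x = self.stage4(self.down(x))
        return x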

MTM

As shown in Figure 2, MTM mainly consists of LGG-SA and EA.

LGG-SA models short- and long-range dependencies at different granularities, while EA mines correlations between samples.

This module is designed to replace the original Transformer encoder, improving its performance on vision tasks and reducing time complexity.
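A minimal sketch of how the two parts could be wired together, using the CSAttention (LGG-SA) and MEAttention (EA) classes given later in this post; the pre-norm residual layout here is my assumption, the exact arrangement is the one in Figure 2:

import torch.nn as nn

class MTMBlockSketch(nn.Module):
    def __init__(self, dim, configs):
        super(MTMBlockSketch, self).__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.lgg_sa = CSAttention(dim, configs)   # intra-sample: local-global Gaussian-weighted SA
        self.norm2 = nn.LayerNorm(dim)
        self.ea = MEAttention(dim, configs)       # inter-sample: shared external memory units

    def forward(self, x):  # x: (b, n, c)
        x = x + self.lgg_sa(self.norm1(x))
        x = x + self.ea(self.norm2(x))
        return x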

LGG-SA (Local-Global Gaussian-Weighted Self-Attention)

A conventional SA module assigns the same level of attention to all tokens. LGG-SA is different: it uses local-global self-attention and a Gaussian mask to focus more on neighboring regions. Experiments show that this approach improves model performance and saves computational resources. The detailed design of the module is shown in Figure 3.

Local-Global Self-Attention

In computer vision, correlations between neighboring regions tend to matter more than those between distant regions, so there is no need to spend the same cost on far-away regions when computing the attention map.

Therefore, local-global self-attention is proposed:

  1. Each local window in stage 1 of the figure above contains four tokens, and local SA computes the intrinsic affinities within each window.
  2. The tokens in each window are then aggregated into one global token that represents the main information of the window. For the aggregation function, Lightweight Dynamic Convolution (LDConv) performs best.
  3. Once the downsampled feature map is obtained, global SA can be performed with much less overhead (stage 2 of the figure above); a rough cost estimate follows this list.
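As a back-of-the-envelope estimate (my own, not a figure from the paper): with \( N \) tokens, window side length \( w \) and channel dimension \( C \), the two stages together scale roughly as

\[
\underbrace{O(N w^{2} C)}_{\text{stage 1: local SA inside windows}} + \underbrace{O\big((N / w^{2})^{2} C\big)}_{\text{stage 2: global SA on aggregated tokens}} \ll O(N^{2} C) \quad \text{(full SA)}.
\]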


where \( X \in R^{H \times W \times C} \) is the input feature map.

The code for the local-window self-attention in stage 1 is as follows:

import math

import numpy as np
import torch
import torch.nn as nn


class WinAttention(nn.Module):
    def __init__(self, configs, dim):
        super(WinAttention, self).__init__()
        self.window_size = configs["win_size"]
        self.attention = Attention(dim, configs)

    def forward(self, x):
        b, n, c = x.shape
        h, w = int(np.sqrt(n)), int(np.sqrt(n))
        x = x.permute(0, 2, 1).contiguous().view(b, c, h, w)
        if h % self.window_size != 0:
            # pad the feature map so the side length is divisible by the window size
            right_size = h + self.window_size - h % self.window_size
            new_x = torch.zeros((b, c, right_size, right_size), device=x.device)
            new_x[:, :, 0:x.shape[2], 0:x.shape[3]] = x[:]
            new_x[:, :, x.shape[2]:, x.shape[3]:] = x[:, :, (x.shape[2] - right_size):,
                                                      (x.shape[3] - right_size):]
            x = new_x
            b, c, h, w = x.shape
        # partition into non-overlapping windows: (b, p, p, win, c)
        x = x.view(b, c, h // self.window_size, self.window_size,
                   w // self.window_size, self.window_size)
        x = x.permute(0, 2, 4, 3, 5, 1).contiguous().view(
            b, h // self.window_size, w // self.window_size,
            self.window_size * self.window_size, c).cuda()
        x = self.attention(x)  # (b, p, p, win, c): self-attention over the tokens of each local window
        return x

The code for the aggregation function is as follows:

class DlightConv(nn.Module):
    def __init__(self, dim, configs):
        super(DlightConv, self).__init__()
        self.linear = nn.Linear(dim, configs["win_size"] * configs["win_size"])
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):  # (b, p, p, win, c)
        h = x
        avg_x = torch.mean(x, dim=-2)  # (b, p, p, c)
        x_prob = self.softmax(self.linear(avg_x))  # (b, p, p, win)
        x = torch.mul(h, x_prob.unsqueeze(-1))  # (b, p, p, win, c)
        x = torch.sum(x, dim=-2)  # (b, p, p, c)
        return x

Gaussian-Weighted Axial Attention

Unlike LSA, which uses the original SA, a Gaussian-Weighted Axial Attention (GWAA) is proposed here. GWAA strengthens the attention weights on neighboring regions through a learnable Gaussian matrix, while its axial attention reduces the time complexity.

  1. The feature at the third row and third column of the stage-2 feature map in the figure above is linearly projected to obtain \( q_{i, j} \).
  2. All features in the same row and the same column as that feature point are linearly projected to obtain \( K_{i, j} \) and \( V_{i, j} \), respectively.
  3. The Euclidean distances between that feature point and all of its keys and values are defined as \( D_{i, j} \).

From these quantities, the final Gaussian-weighted axial attention output is computed and then simplified.
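Reading from the GaussianTrans code below, the output at position \( (i, j) \) can be written as (my reconstruction from the code, not the paper's exact notation):

\[
y_{i,j} = \operatorname{softmax}\big(q_{i,j} K_{i,:}^{T} - (\alpha D^{row}_{i,j} + \beta)\big) V_{i,:} + \operatorname{softmax}\big(q_{i,j} K_{:,j}^{T} - (\alpha D^{col}_{i,j} + \beta)\big) V_{:,j}
\]

where \( \alpha \) (shift) and \( \beta \) (bias) are learnable scalars, \( D^{row}_{i,j} \) and \( D^{col}_{i,j} \) are the squared Euclidean distances from position \( (i, j) \) to the other positions in its row and column, and \( K_{i,:}, V_{i,:} \) (resp. \( K_{:,j}, V_{:,j} \)) are the keys and values of that row (resp. column).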

The axial attention code is as follows:

class Attention(nn.Module):
    def __init__(self, dim, configs, axial=False):
        super(Attention, self).__init__()
        self.axial = axial
        self.dim = dim
        self.num_head = configs["head"]
        self.attention_head_size = int(self.dim / configs["head"])
        self.all_head_size = self.num_head * self.attention_head_size
        self.query_layer = nn.Linear(self.dim, self.all_head_size)
        self.key_layer = nn.Linear(self.dim, self.all_head_size)
        self.value_layer = nn.Linear(self.dim, self.all_head_size)
        self.out = nn.Linear(self.dim, self.dim)
        self.softmax = nn.Softmax(dim=-1)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_head, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x

    def forward(self, x):
        if self.axial:
            # axial branch: row attention first, then column attention
            # x: (b, p, p, c)
            b, h, w, c = x.shape
            mixed_query_layer = self.query_layer(x)
            mixed_key_layer = self.key_layer(x)
            mixed_value_layer = self.value_layer(x)

            # row attention (single-head)
            query_layer_x = mixed_query_layer.view(b * h, w, -1)
            key_layer_x = mixed_key_layer.view(b * h, w, -1).transpose(-1, -2)  # (b*h, -1, w)
            attention_scores_x = torch.matmul(query_layer_x, key_layer_x)  # (b*h, w, w)
            attention_scores_x = attention_scores_x.view(b, -1, w, w)  # (b, h, w, w)

            # column attention (single-head)
            query_layer_y = mixed_query_layer.permute(0, 2, 1, 3).contiguous().view(b * w, h, -1)
            key_layer_y = mixed_key_layer.permute(0, 2, 1, 3).contiguous().view(b * w, h, -1).transpose(-1, -2)  # (b*w, -1, h)
            attention_scores_y = torch.matmul(query_layer_y, key_layer_y)  # (b*w, h, h)
            attention_scores_y = attention_scores_y.view(b, -1, h, h)  # (b, w, h, h)

            # the Gaussian weighting and softmax are applied later in GaussianTrans
            return attention_scores_x, attention_scores_y, mixed_value_layer
        else:
            # standard multi-head self-attention inside each local window
            mixed_query_layer = self.query_layer(x)
            mixed_key_layer = self.key_layer(x)
            mixed_value_layer = self.value_layer(x)
            query_layer = self.transpose_for_scores(mixed_query_layer).permute(
                0, 1, 2, 4, 3, 5).contiguous()  # (b, p, p, head, win, c/head)
            key_layer = self.transpose_for_scores(mixed_key_layer).permute(
                0, 1, 2, 4, 3, 5).contiguous()
            value_layer = self.transpose_for_scores(mixed_value_layer).permute(
                0, 1, 2, 4, 3, 5).contiguous()
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            attention_scores = attention_scores / math.sqrt(self.attention_head_size)
            atten_probs = self.softmax(attention_scores)
            context_layer = torch.matmul(atten_probs, value_layer)  # (b, p, p, head, win, c/head)
            context_layer = context_layer.permute(0, 1, 2, 4, 3, 5).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, )
            context_layer = context_layer.view(*new_context_layer_shape)
            attention_output = self.out(context_layer)
        return attention_output

The Gaussian weighting code is as follows:

class GaussianTrans(nn.Module):
    def __init__(self):
        super(GaussianTrans, self).__init__()
        self.bias = nn.Parameter(-torch.abs(torch.randn(1)))
        self.shift = nn.Parameter(torch.abs(torch.randn(1)))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # x: (b, h, w, c), atten_x_full: (b, h, w, w), atten_y_full: (b, w, h, h), value_full: (b, h, w, c)
        x, atten_x_full, atten_y_full, value_full = x
        new_value_full = torch.zeros_like(value_full)
        for r in range(x.shape[1]):  # row
            for c in range(x.shape[2]):  # col
                atten_x = atten_x_full[:, r, c, :]  # (b, w)
                atten_y = atten_y_full[:, c, r, :]  # (b, h)
                # squared Euclidean distances to the other positions in the same row / column
                dis_x = torch.tensor([(i - c)**2 for i in range(x.shape[2])]).cuda()  # (w, )
                dis_y = torch.tensor([(j - r)**2 for j in range(x.shape[1])]).cuda()  # (h, )
                # learnable Gaussian weighting: nearer positions get larger scores
                dis_x = -(self.shift * dis_x + self.bias).cuda()
                dis_y = -(self.shift * dis_y + self.bias).cuda()
                atten_x = self.softmax(dis_x + atten_x)
                atten_y = self.softmax(dis_y + atten_y)
                new_value_full[:, r, c, :] = torch.sum(
                    atten_x.unsqueeze(dim=-1) * value_full[:, r, :, :] +
                    atten_y.unsqueeze(dim=-1) * value_full[:, :, c, :],
                    dim=-2)
        return new_value_full

The complete local-global self-attention code is as follows:

class CSAttention(nn.Module):
    def __init__(self, dim, configs):
        super(CSAttention, self).__init__()
        self.win_atten = WinAttention(configs, dim)
        self.dlightconv = DlightConv(dim, configs)
        self.global_atten = Attention(dim, configs, axial=True)
        self.gaussiantrans = GaussianTrans()
        #self.conv = nn.Conv2d(dim, dim, 3, padding=1)
        #self.maxpool = nn.MaxPool2d(2)
        self.up = nn.UpsamplingBilinear2d(scale_factor=4)
        self.queeze = nn.Conv2d(2 * dim, dim, 1)

    def forward(self, x):
        '''
        :param x: size(b, n, c)
        :return:
        '''
        origin_size = x.shape
        _, origin_h, origin_w, _ = origin_size[0], int(np.sqrt(
            origin_size[1])), int(np.sqrt(origin_size[1])), origin_size[2]
        x = self.win_atten(x)  # (b, p, p, win, c): local window self-attention
        b, p, p, win, c = x.shape
        h = x.view(b, p, p, int(np.sqrt(win)), int(np.sqrt(win)),
                   c).permute(0, 1, 3, 2, 4, 5).contiguous()
        h = h.view(b, p * int(np.sqrt(win)), p * int(np.sqrt(win)),
                   c).permute(0, 3, 1, 2).contiguous()  # (b, c, h, w) local branch
        x = self.dlightconv(x)  # (b, p, p, c): aggregate each window into one token
        atten_x, atten_y, mixed_value = self.global_atten(
            x)  # (b, h, w, w), (b, w, h, h), (b, h, w, c); here h and w both equal p
        gaussian_input = (x, atten_x, atten_y, mixed_value)
        x = self.gaussiantrans(gaussian_input)  # (b, h, w, c): Gaussian-weighted axial attention
        x = x.permute(0, 3, 1, 2).contiguous()  # (b, c, h, w)
        x = self.up(x)  # upsample the global branch back to the window-level resolution
        x = self.queeze(torch.cat((x, h), dim=1)).permute(0, 2, 3, 1).contiguous()  # fuse local and global branches
        x = x[:, :origin_h, :origin_w, :].contiguous()
        x = x.view(b, -1, c)
        return x
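A hypothetical usage sketch (the configs values, feature size and channel dimension are my own choices, not from the paper; a GPU is required because the code above calls .cuda() internally):

configs = {"win_size": 4, "head": 8}
csa = CSAttention(dim=256, configs=configs).cuda()
tokens = torch.randn(2, 28 * 28, 256).cuda()  # (b, n, c): a flattened 28x28 feature map
out = csa(tokens)                             # (b, n, c): same token layout as the input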
EA

External Attention (EA) is used to address the problem that SA cannot exploit the relationships between different input data samples.

Unlike self-attention, which computes attention scores from each sample's own linear transformations, EA shares two memory units, \( M_K \) and \( M_V \), across all data samples (as shown in Figure 2); they describe the most important information of the entire dataset.
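For reference, the external attention computation can be summarized as follows (my summary; the normalization matches the MEAttention code below):

\[
A = \operatorname{Norm}\big(F M_K^{T}\big), \qquad F_{out} = A M_V
\]

where \( F \in R^{N \times d} \) is the input feature, \( M_K \) and \( M_V \) are the shared memory units (the linear_0 and linear_1 layers in the code), and Norm is the double normalization: a softmax over the token dimension followed by \( \ell_1 \) normalization over the memory dimension.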

The EA code is as follows:

class MEAttention(nn.Module):
    def __init__(self, dim, configs):
        super(MEAttention, self).__init__()
        self.num_heads = configs["head"]
        self.coef = 4
        self.query_liner = nn.Linear(dim, dim * self.coef)
        self.num_heads = self.coef * self.num_heads
        self.k = 256 // self.coef
        self.linear_0 = nn.Linear(dim * self.coef // self.num_heads, self.k)  # memory unit M_K
        self.linear_1 = nn.Linear(self.k, dim * self.coef // self.num_heads)  # memory unit M_V
        self.proj = nn.Linear(dim * self.coef, dim)

    def forward(self, x):
        B, N, C = x.shape
        x = self.query_liner(x)  # (b, n, 4c)
        x = x.view(B, N, self.num_heads, -1).permute(0, 2, 1, 3)  # (b, h, n, 4c/h)
        attn = self.linear_0(x)  # (b, h, n, k)
        attn = attn.softmax(dim=-2)  # softmax over the token dimension
        attn = attn / (1e-9 + attn.sum(dim=-1, keepdim=True))  # double normalization over the memory dimension
        x = self.linear_1(attn).permute(0, 2, 1, 3).reshape(B, N, -1)
        x = self.proj(x)
        return x

EXPERIMENTS