关于机器学习:用于Transformer的6种注意力的数学原理和代码实现

56次阅读

共计 9772 个字符,预计需要花费 25 分钟才能阅读完成。

Transformer 的杰出体现让注意力机制呈现在深度学习的各处。本文整顿了深度学习中最罕用的 6 种注意力机制的数学原理和代码实现。

1、Full Attention

2017 的《Attention is All You Need》中的编码器 - 解码器构造实现中提出。它构造并不简单,所以不难理解。

上图 1. 左侧显示了 Scaled Dot-Product Attention 的机制。当咱们有多个注意力时,咱们称之为多头注意力(右),这也是最常见的注意力的模式公式如下:

公式 1

这里 Q(Query)、K(Key)和 V(values)被认为是它的输出,dₖ(输出维度)被用来升高复杂度和计算成本。这个公式能够说是深度学习中注意力机制倒退的开始。上面咱们看一下它的代码:

class FullAttention(nn.Module):
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)
        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)
        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)
        if self.output_attention:
            return (V.contiguous(), A)
        else:
            return (V.contiguous(), None)

2、ProbSparse Attention

借助“Transformer Dissection: A Unified Understanding of Transformer’s Attention via the lens of Kernel”中的信息咱们能够将公式批改为上面的公式 2。第 i 个 query 的 attention 就被定义为一个概率模式的核平滑办法(kernel smoother):

公式 2

从公式 2,咱们能够定义第 i 个查问的稠密度测量如下:

最初,注意力块的最终公式是上面的公式 4。

代码如下:

class ProbAttention(nn.Module):
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(ProbAttention, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
    def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
        # Q [B, H, L, D]
        B, H, L_K, E = K.shape
        _, _, L_Q, _ = Q.shape
        # calculate the sampled Q_K
        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q
        K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
        Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)
        # find the Top_k query with sparisty measurement
        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
        M_top = M.topk(n_top, sorted=False)[1]
        # use the reduced Q to calculate Q_K
        Q_reduce = Q[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :] # factor*ln(L_q)
        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k
        return Q_K, M_top
    def _get_initial_context(self, V, L_Q):
        B, H, L_V, D = V.shape
        if not self.mask_flag:
            # V_sum = V.sum(dim=-2)
            V_sum = V.mean(dim=-2)
            contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
        else: # use mask
            assert(L_Q == L_V) # requires that L_Q == L_V, i.e. for self-attention only
            contex = V.cumsum(dim=-2)
        return contex
    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
        B, H, L_V, D = V.shape
        if self.mask_flag:
            attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)

        attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)

        context_in[torch.arange(B)[:, None, None],
                   torch.arange(H)[None, :, None],
                   index, :] = torch.matmul(attn, V).type_as(context_in)
        if self.output_attention:
            attns = (torch.ones([B, H, L_V, L_V])/L_V).type_as(attn).to(attn.device)
            attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
            return (context_in, attns)
        else:
            return (context_in, None)
    def forward(self, queries, keys, values, attn_mask):
        B, L_Q, H, D = queries.shape
        _, L_K, _, _ = keys.shape

        queries = queries.transpose(2,1)
        keys = keys.transpose(2,1)
        values = values.transpose(2,1)

        U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k)
        u = self.factor * np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q) 

        U_part = U_part if U_part<L_K else L_K
        u = u if u<L_Q else L_Q
        
        scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u) 

        # add scale factor
        scale = self.scale or 1./sqrt(D)
        if scale is not None:
            scores_top = scores_top * scale
        # get the context
        context = self._get_initial_context(values, L_Q)
        # update the context with selected top_k queries
        context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)
        
        return context.transpose(2,1).contiguous(), attn

这也是 Informer 这个用于长序列工夫序列预测的新型 Transformer 中应用的注意力。

3、LogSparse Attention

咱们之前探讨的注意力有两个毛病:1. 与地位无关 2. 内存的瓶颈。为了应答这两个问题,钻研人员应用了卷积算子和 LogSparse Transformers。

Transformer 中相邻层之间不同注意力机制的图示

卷积自注意力显示在(右)中,它应用步长为 1,内核大小为 k 的卷积层将输出(具备适当的填充)转换为 Q /K。这种地位感知能够依据(左)中的形态正确匹配最相干的特色

他们不是应用步长为 1,卷积核大小 1,而是应用步长为 1,核大小为 k 的随便卷积(以确保模型无法访问将来的点)将输出转换为 Q 和 K

代码实现

class Attention(nn.Module):
    def __init__(self, n_head, n_embd, win_len, scale, q_len, sub_len, sparse=None, attn_pdrop=0.1, resid_pdrop=0.1):
        super(Attention, self).__init__()

        if(sparse):
            print('Activate log sparse!')
            mask = self.log_mask(win_len, sub_len)
        else:
            mask = torch.tril(torch.ones(win_len, win_len)).view(1, 1, win_len, win_len)

        self.register_buffer('mask_tri', mask)
        self.n_head = n_head
        self.split_size = n_embd * self.n_head
        self.scale = scale
        self.q_len = q_len
        self.query_key = nn.Conv1d(n_embd, n_embd * n_head * 2, self.q_len)
        self.value = Conv1D(n_embd * n_head, 1, n_embd)
        self.c_proj = Conv1D(n_embd, 1, n_embd * self.n_head)
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)

    def log_mask(self, win_len, sub_len):
        mask = torch.zeros((win_len, win_len), dtype=torch.float)
        for i in range(win_len):
            mask[i] = self.row_mask(i, sub_len, win_len)
        return mask.view(1, 1, mask.size(0), mask.size(1))

    def row_mask(self, index, sub_len, win_len):
        """
        Remark:
        1 . Currently, dense matrices with sparse multiplication are not supported by Pytorch. Efficient implementation
            should deal with CUDA kernel, which we haven't implemented yet.
        2 . Our default setting here use Local attention and Restart attention.
        3 . For index-th row, if its past is smaller than the number of cells the last
            cell can attend, we can allow current cell to attend all past cells to fully
            utilize parallel computing in dense matrices with sparse multiplication."""
        log_l = math.ceil(np.log2(sub_len))
        mask = torch.zeros((win_len), dtype=torch.float)
        if((win_len // sub_len) * 2 * (log_l) > index):
            mask[:(index + 1)] = 1
        else:
            while(index >= 0):
                if((index - log_l + 1) < 0):
                    mask[:index] = 1
                    break
                mask[index - log_l + 1:(index + 1)] = 1  # Local attention
                for i in range(0, log_l):
                    new_index = index - log_l + 1 - 2**i
                    if((index - new_index) <= sub_len and new_index >= 0):
                        mask[new_index] = 1
                index -= sub_len
        return mask

    def attn(self, query: torch.Tensor, key, value: torch.Tensor, activation="Softmax"):
        activation = activation_dict[activation](dim=-1)
        pre_att = torch.matmul(query, key)
        if self.scale:
            pre_att = pre_att / math.sqrt(value.size(-1))
        mask = self.mask_tri[:, :, :pre_att.size(-2), :pre_att.size(-1)]
        pre_att = pre_att * mask + -1e9 * (1 - mask)
        pre_att = activation(pre_att)
        pre_att = self.attn_dropout(pre_att)
        attn = torch.matmul(pre_att, value)

        return attn

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x):

        value = self.value(x)
        qk_x = nn.functional.pad(x.permute(0, 2, 1), pad=(self.q_len - 1, 0))
        query_key = self.query_key(qk_x).permute(0, 2, 1)
        query, key = query_key.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        attn = self.attn(query, key, value)
        attn = self.merge_heads(attn)
        attn = self.c_proj(attn)
        attn = self.resid_dropout(attn)
        return attn

class Conv1D(nn.Module):
    def __init__(self, out_dim, rf, in_dim):
        super(Conv1D, self).__init__()
        self.rf = rf
        self.out_dim = out_dim
        if rf == 1:
            w = torch.empty(in_dim, out_dim)
            nn.init.normal_(w, std=0.02)
            self.w = Parameter(w)
            self.b = Parameter(torch.zeros(out_dim))
        else:
            raise NotImplementedError

    def forward(self, x):
        if self.rf == 1:
            size_out = x.size()[:-1] + (self.out_dim,)
            x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w)
            x = x.view(*size_out)
        else:
            raise NotImplementedError
        return x

来自:https://github.com/AIStream-P…

4、LSH Attention

Reformer 的论文抉择了部分敏感哈希的 angular 变体。它们首先束缚每个输出向量的 L2 范数 (行将向量投影到一个单位球面上),而后利用一系列的旋转,最初找到每个旋转向量所属的切片。这样一来就须要找到最近邻的值,这就须要部分敏感哈希(LSH)了,它可能疾速在高维空间中找到最近邻。一个部分敏感哈希算法能够将每个向量 x 转换为 hash h(x),和这个 x 凑近的哈希更有可能有着雷同的哈希值,而距离远的则不会。作者心愿最近的向量最可能失去雷同的哈希值,或者 hash-bucket 大小类似的更有可能雷同。

部分敏感哈希算法应用球投影点的随机旋转,通过 argmax 在有符号的轴投影上建设 bucket。在这个高度简化的 2D 形容中,对于三个不同的角哈希,两个点 x 和 y 不太可能共享雷同的哈希桶 (上图),除非它们的球面投影彼此靠近 (下图)。

通过固定一个大小为 [dₖ, b/2] 的随机矩阵 R 来取得 b 个哈希值。h(x) = argmax([xR;-xR]) 其中 [u;v] 示意两个向量的串联。这样就能够应用 LSH,将查问地位 I 带入重写公式 1:

下图咱们能够示意性地解释 LSH 注意力:

  • 原始的注意力矩阵通常是稠密的,但不利于计算
  • LSH Attention 基于哈希桶进行键的排序进行查问
  • 在排序后的留神矩阵中,来自同一桶的对将汇集在对角线左近
  • 采纳批处理办法,m 个间断查问的块互相解决,一个块返回。

代码很长为了节约工夫这里就不贴了:

https://github.com/lucidrains…

5、Sparse Attention(Generating Long Sequences with Sparse Transformers)

OpenAI 的 Sparse Attention,通过“只保留小区域内的数值、强制让大部分注意力为零”的形式,来缩小 Attention 的计算量。通过 top- k 抉择,将留神进化为稠密留神。这样,保留最有助于引起留神的局部,并删除其余无关的信息。这种选择性办法在保留重要信息和打消噪声方面是无效的。注意力能够更多地集中在最有奉献的价值因素上。

代码:https://github.com/openai/spa…

6、Single-Headed Attention(Single Headed Attention RNN: Stop Thinking With Your Head)

SHA-RNN 模型的注意力是简化到只保留了一个头并且惟一的矩阵乘法呈现在 query (下图 Q) 那里,A 是缩放点乘注意力 (Scaled Dot-Product Attention),是向量之间的运算。所以这种计算量比拟小,可能疾速的进行训练,就像它介绍的那样:

Obtain strong results on a byte level language modeling dataset (enwik8) in under 24 hours on a single GPU (12GB Titan V)

代码:https://github.com/Smerity/sh…

援用:

  1. Kitaev, N., Ł. Kaiser, and A. Levskaya, Reformer: The efficient transformer. arXiv preprint arXiv:2001.04451, 2020.
  2. Li, S., et al., Enhancing the locality and breaking the memory bottleneck of transformer on time series forecasting. Advances in Neural Information Processing Systems, 2019. 32.
  3. Zhou, H., et al. Informer: Beyond efficient transformer for long sequence time-series forecasting. in Proceedings of AAAI. 2021.
  4. Vaswani, A., et al., Attention is all you need. Advances in neural information processing systems, 2017. 30.
  5. Rewon Child, Scott Gray, Alec Radford, Ilya Sutskever Generating Long Sequences with Sparse Transformers
  6. Stephen Merity,Single Headed Attention RNN: Stop Thinking With Your Head

留神机制的倒退到当初远远不止这些,在本篇文章中只整顿了一些常见的注意力机制,心愿对你有所帮忙。

另外就是来自 Erasmus University 的 Gianni Brauwers 和 Flavius Frasincar 在 TKDE 上发表的《A General Survey on Attention Mechanisms in Deep Learning》综述论文,提供了一个对于深度学习注意力机制的重要概述。各种注意力机制通过一个由注意力模型,对立符号和一个全面的分类注意力机制组成的框架来进行解释,还有注意力模型评估的各种办法。

有趣味和有资源的话能够进行浏览,女神 Alexandra Elbakyan 的网站还未提供该论文。

https://www.overfit.cn/post/739299d8be4e4ddc8f5804b37c6c82ad

作者:Reza Yazdanfar

正文完
 0