乐趣区

关于注意力:DeepMind-提出-Perceiver使用RNN的方式进行注意力通过交叉注意力节省计算量附使用方法

明天要解读的论文来自 DeepMind,论文名为《Perceiver: General Perception with Iterative Attention》,文中介绍了一种基于 Transformer 的构造,不对数据做任何假如,不须要批改网络结构,就能够利用于各种模态的数据。

咱们人在感知世界的时候,是通过同时解决各个模态的高维数据,而当初深度学习中应用的办法,都会引入很多畛域内的常识,比方当初简直所有的视觉办法,都引入了”局部性“的假如,即在一张图像内,部分的特色是有用的,这也是 CNN 有用的基本原理。引入这些有帮忙信息的同时,也将模型的作用范畴限度在了某一个模态以内。

在这篇论文中,作者提出了 Perceiver,它是一个基于 Transformer 的模型,简直没有做任何对于输出数据之间关系的结构性假如,然而也与 ConvNets 一样,能够扩大到数十万的输出上。

In this paper we introduce the Perceiver – a model that builds upon Transformers and hence makes few architectural assumptions about the relationship between its inputs, but that also scales to hundreds of thousands of inputs, like ConvNets.

作者提出的构造,达到甚至超过了精心设计用于某一个模态的模型的成果。在试验中,作者用了 ImageNet 的图像数据,AudioSet 的视频和音频数据,以及 3D 点云数据。

办法

应用了两个局部来构建网络:

  1. 应用穿插注意力机制(cross attention)来将一个输出向量(文中叫做 byte array)与一个隐向量映射为一个隐向量
  2. 应用 transformer 塔将一个隐向量映射为另一个同样大小的隐向量

输出向量的大小被输出数据所决定,这个个别会很大,例如一张 ImageNet 中的图像,有 224*224 维,也就是 50176 维。而隐向量是模型中的一个超参数,能够人为管制,这个个别很小,作者在 ImageNet 中应用了 1024 维。

所提办法的关键在于:通过一个低维的注意力瓶颈层,将输出的高维数据,映射到低维,再将它送入深度的 transformer 中。

这样做的益处是,如果仅间接应用 transformer 层,那么面临最大的问题是,训练太消耗工夫,以及须要十分大的显存。作者在文中剖析,transformer 的工夫复杂度为序列长度的二次关系,即 O(M^2),这里 M 指序列长度。应用文中提出的穿插注意力机制,变成了 O(MN),而个别能够设置 N 远小于 M。

接下来是一些我的了解:

相熟注意力机制的都晓得,它包含三个局部,别离是 Q、K 和 V。个别的作用形式是,序列长度是多少,那么 Q、K 和 V 的长度就是多少。但这一点其实是没有必要的。对于一张图,咱们不须要每一个地位,都须要一个查问向量(Q)。这样就容易了解,作者提出的构造。对于长度为 M 的序列中的每一个元素,咱们会有 N 个查问向量作用于它,所以工夫复杂度就变为了 O(MN)。当有了这样 N 个后果当前,再送入传统的 Transformer 构造,这样就极大水平上缩小了运算量和显存的占用。

迭代式的注意力机制

瓶颈层可能会限度网络捕获必要信息的能力,为了缓解这个景象,Perceiver 应用多个 byte-attend 层,也就是穿插注意力层,当网络须要具体的输出信息时候,它就可能取得到这些信息。

最初,借助这样的迭代的注意力机制,能够将网络设计成权值共享的模式(最终的网络结构十分相似于 RNN)。权值共享使得参数量减少约 10 倍,缩小了网络的过拟合,进步了验证集上的性能。

试验局部

在 ImageNet 上的试验

在 ImageNet 上的试验后果。红色的办法代表设计模型时引入了一些特定常识,蓝色的办法代表没有引入。能够看到 Perceiver 达到了十分有竞争力的成果。

将图像像素随机打乱

这里作者将图像中的像素随机打乱,Fixed 代表所有图像都是用同一个打乱的形式,Random 代表每张图都是随机打乱,能够看到,当进行随机打乱时,其余办法的性能大幅降落。

前面一列是每个模型输出单元的感触野。

这里可能会有一个问题,就是,既然咱们晓得图像中部分的信息是有用的,为什么不利用它呢?作者的思考次要是,这样能够失去一个利用范畴更广的模型,因为如果面临的是多模态工作,比方视频、音频、嗅觉传感器和触摸传感器等等数据,再去手动设计输出数据的交互模式是十分艰难的。

注意力可视化

这里展现的是穿插注意力可视化的后果。

其中,蓝色代表是第一层网络的可视化后果,绿色代表第 2 - 7 层网络的后果,橙色代表第八层网络的可视化后果。第一行是每层抽了一个注意力图作为特写。

从图中能够看到,所提办法没有取部分的信息,而是以一种相似格网的模式扫描整张图。

视频音频的后果

应用了 AudioSet 数据集,独自应用视频或音频,或者两者联合应用,都达到了最好的后果。

点云数据

在点云数据的后果中,PointNet ++ 应用了额定的几何特色,以及更多的加强技术。蓝色的办法都没有应用这些技术。在蓝色的外面,成果是最好的。

应用办法

装置

pip install perceiver-pytorch

应用

import torch
from perceiver_pytorch import Perceiver

model = Perceiver(
    input_channels = 3,          # 序列中每一个元素的维度
    input_axis = 2,              # 输出数据的坐标数(用于构建地位编码,图像的话就是 2:x 和 y)num_freq_bands = 6,          # number of freq bands, with original value (2 * K + 1)
    max_freq = 10.,              # maximum frequency, hyperparameter depending on how fine the data is
    depth = 6,                   # 网络深度
    num_latents = 256,           # 隐向量的个数
    cross_dim = 512,             # 穿插注意力的维度
    latent_dim = 512,            # 隐向量的维度
    cross_heads = 1,             # 穿插注意力的头数
    latent_heads = 8,            # 隐自注意力的头数
    cross_dim_head = 64,
    latent_dim_head = 64,
    num_classes = 1000,          # 最终输入类别数
    attn_dropout = 0.,
    ff_dropout = 0.,
)

img = torch.randn(1, 224, 224, 3) # imagenet 图像数据

model(img) # (1, 1000)

参考资料

  • 论文链接:https://arxiv.org/pdf/2103.03206.pdf
  • 代码:https://github.com/lucidrains/perceiver-pytorch(非官方实现)

写在最初:如果感觉这篇文章对您有帮忙,欢送点赞珍藏评论反对我,谢谢!
也欢送关注我的公众号:算法小哥克里斯。

举荐浏览:

  • Chris:将注意力机制引入 ResNet,视觉畛域涨点技巧来了!附应用办法
  • Chris:Facebook AI 提出 TimeSformer:齐全基于 Transformer 的视频了解框架
退出移动版