关于深度学习:使用Pytorch手写ViT-VisionTransformer

6次阅读

共计 6317 个字符,预计需要花费 16 分钟才能阅读完成。

《The Attention is all you need》的论文彻底改变了自然语言解决的世界,基于 Transformer 的架构成为自然语言解决工作的的规范。

只管基于卷积的架构在图像分类工作中依然是最先进的技术,但论文《An image is worth 16×16 words: transformer for image recognition at scale》表明,计算机视觉中 CNNs 的依赖也不是必要的,间接对图像进行分块,而后应用序纯 transformer 能够很好地实现图像分类工作。

在 ViT 中,图像被宰割成小块,并将这些小块的线性嵌入序列作为 Transformer 的输出。对图像进行补丁解决形式与 NLP 应用程序中的标记 (单词) 雷同。

因为不足 CNN 固有的演绎偏差(如局部性),Transformers 在数据量有余的状况下不能很好地泛化。然而当在大型数据集上进行训练时,它在多个图像识别基准上的确达到或超过了最先进的程度。在深刻本文之前,如果你从未据说过 Transformer 架构,我强烈建议你查看 The Illustrated Transformer。

在开始实现之前,咱们先看看 ViT 架构

能够看到输出图像被分解成 16×16 的扁平化块,而后应用一般的全连贯层对这些块进行嵌入操作,并在它们后面蕴含非凡的 cls token 和地位嵌入。

线性投影的张量被传递给规范的 Transformer 编码器,最初传递给 MLP 头,用于分类目标。

首先咱们从导入库开始,一步一步实现论文中提到的 ViT 模型:

 import matplotlib.pyplot as plt
 from PIL import Image
 
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
 from torchsummary import summary
 from torchvision.transforms import Compose, Resize, ToTensor
 
 from einops import rearrange, reduce, repeat
 from einops.layers.torch import Rearrange, Reduce

为了调试咱们的模型,还须要一张图片来进行测试:

 img = Image.open('penguin.jpg')
 
 fig = plt.figure()
 plt.imshow(img)
 plt.show()

图片还须要一些预处理:

 transform = Compose([Resize((224, 224)),
     ToTensor(),])
 
 x = transform(img)
 x = x.unsqueeze(0)
 print(x.shape)

通过下面的预处理,咱们的张量大小为 torch.Size([1,3,224,224])。接下来,咱们开始依照论文实现 ViT。

切分补丁和投影

将图像分成多个补丁,并将它们展平。

以下是论文的原话:

咱们能够很容易地应用 einops 来实现它。

 patch_size = 16
 patches = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)

下一步是对它们进行投影:这能够应用规范线性层轻松实现,但本文中应用卷积层(应用 kernel_size 和 stride 等于 patch_size 取得的,这压根以进步性能)。让咱们在 PatchEmbedding 类中解决问题。

 class PatchEmbedding(nn.Module):
     def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
         self.patch_size = patch_size
         super().__init__()
         self.projection = nn.Sequential(nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
             Rearrange('b e (h) (w) -> b (h w) e'),
         ) # this breaks down the image in s1xs2 patches, and then flat them
                 
     def forward(self, x: Tensor) -> Tensor:
         x = self.projection(x)
         return x

为了测试咱们的代码,能够调用 PatchEmbedding()(x).shape,失去:

 torch.Size([1, 196, 768])

CLS 令牌和地位嵌入

与 BERT 的分类令牌相似,一个可学习的嵌入被事后增加到嵌入补丁的序列中。而后将地位嵌入增加到补丁嵌入中以保留地位信息。这里应用规范可学习的一维地位嵌入。

 class PatchEmbedding(nn.Module):
     def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
         self.patch_size = patch_size
         super().__init__()
         self.projection = nn.Sequential(nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
             Rearrange('b e (h) (w) -> b (h w) e'),
         ) # this breaks down the image in s1xs2 patches, and then flat them
         
         self.cls_token = nn.Parameter(torch.randn(1,1, emb_size))
         self.positions = nn.Parameter(torch.randn((img_size // patch_size) **2 + 1, emb_size))
 
         
     def forward(self, x: Tensor) -> Tensor:
         b, _, _, _ = x.shape
         x = self.projection(x)
         cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
         x = torch.cat([cls_tokens, x], dim=1) #prepending the cls token
         x += self.positions
         return x

这样生成的嵌入向量序列用将作编码器的输出。

Transformer 编码器 (Vaswani et al., 2017) 由多头自注意力和 MLP 块的交替层组成。在每个块之前利用 Layer Norm (LN),并在每个块之后增加残差连贯。

注意力机制

注意力机制须要三个输出:查问、键和值。而后它应用查问和键计算注意力矩阵。

这里将实现一个多头注意力机制,次要概念是应用查问和键之间的乘积来理解序列中的每个元素对其余元素的重要性。稍后将应用这些信息对值进行缩放。能够为查问、键和值矩阵应用 3 个不同的线性层,也能够将它们交融为一个。

 class MultiHeadAttention(nn.Module):
     def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
         super().__init__()
         self.emb_size = emb_size
         self.num_heads = num_heads
         self.qkv = nn.Linear(emb_size, emb_size * 3) # queries, keys and values matrix
         self.att_drop = nn.Dropout(dropout)
         self.projection = nn.Linear(emb_size, emb_size)
         
     def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
         # split keys, queries and values in num_heads
         qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
         queries, keys, values = qkv[0], qkv[1], qkv[2]
         # sum up over the last axis
         energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len
         
         if mask is not None:
             fill_value = torch.finfo(torch.float32).min
             energy.mask_fill(~mask, fill_value)
             
         scaling = self.emb_size ** (1/2)
         
         att = F.softmax(energy, dim=-1) / scaling
         att = self.att_drop(att)
         out = torch.einsum('bhal, bhlv -> bhav', att, values) # sum over the third axis
         out = rearrange(out, "b h n d -> b n (h d)")
         out = self.projection(out)
         
         return out

残差连贯

从上图中能够看出,transformer 块有残差连贯。

咱们可应用一个包装器来执行残差加法,这样能够复用:

 class ResidualAdd(nn.Module):
     def __init__(self, fn):
         super().__init__()
         self.fn = fn
         
     def forward(self, x, **kwargs):
         res = x
         x = self.fn(x, **kwargs)
         x += res
         return x

注意力块的输入被传递到一个全连贯层。最初一层由两层组成,它们通过因子 L 进行上采样:

 class FeedForwardBlock(nn.Sequential):
     def __init__(self, emb_size: int, L: int = 4, drop_p: float = 0.):
         super().__init__(nn.Linear(emb_size, L * emb_size),
             nn.GELU(),
             nn.Dropout(drop_p),
             nn.Linear(L * emb_size, emb_size),
         )

Transformer 编码器块

下面的分块步骤都曾经实现了,上面咱们将这些块整合成编码器:

 class TransformerEncoderBlock(nn.Sequential):
     def __init__(self, emb_size: int = 768, drop_p: float = 0., forward_expansion: int = 4,
                  forward_drop_p: float = 0.,
                  **kwargs):
                  
         super().__init__(
             ResidualAdd(nn.Sequential(nn.LayerNorm(emb_size),
                 MultiHeadAttention(emb_size, **kwargs),
                 nn.Dropout(drop_p)
             )),
             ResidualAdd(nn.Sequential(nn.LayerNorm(emb_size),
                 FeedForwardBlock(emb_size, L=forward_expansion, drop_p=forward_drop_p),
                 nn.Dropout(drop_p)
             )
             ))

要测试这部分代码,能够间接调用:

 patches_embedded = PatchEmbedding()(x)
 print(TransformerEncoderBlock()(patches_embedded).shape)

这样会返回 torch.Size([1,197,768])

Transformer 编码器

因为只须要编码器,所以能够应用下面编写的 TransformerEncoderBlock 进行构建

 class TransformerEncoder(nn.Sequential):
     def __init__(self, depth: int = 12, **kwargs):
         super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])

分类头

因为 ViT 是分类工作,所以最初要有一个进行分类人物的分类头,这个非常简单:计算整个序列的简略平均值之后是一个规范的全连贯,它给出了类概率。

 class ClassificationHead(nn.Sequential):
     def __init__(self, emb_size: int = 768, n_classes: int = 1000):
         super().__init__(Reduce('b n e -> b e', reduction='mean'),
             nn.LayerNorm(emb_size), 
             nn.Linear(emb_size, n_classes))

整合所有的组件——VisionTransformer

将咱们下面构建的所有内容整合,最终就能够失去 ViT 了。

 class ViT(nn.Sequential):
     def __init__(self,     
                 in_channels: int = 3,
                 patch_size: int = 16,
                 emb_size: int = 768,
                 img_size: int = 224,
                 depth: int = 12,
                 n_classes: int = 1000,
                 **kwargs):
         super().__init__(PatchEmbedding(in_channels, patch_size, emb_size, img_size),
             TransformerEncoder(depth, emb_size=emb_size, **kwargs),
             ClassificationHead(emb_size, n_classes)
         )

查看咱们构建的模型,能够应用 torchsummary 来查看后果:

 print(summary(ViT(), (3,224,224), device='cpu'))

将失去:

 ================================================================
 Total params: 86,415,592
 Trainable params: 86,415,592
 Non-trainable params: 0
 ----------------------------------------------------------------
 Input size (MB): 0.57
 Forward/backward pass size (MB): 364.33
 Params size (MB): 329.65
 Estimated Total Size (MB): 694.56
 ----------------------------------------------------------------

总结

本篇文章应用 Pytorch 中实现 Vision Transformer,通过咱们本人的手动实现能够更好的了解 ViT 的架构,为了加深印象咱们再看下论文中提供的与现有技术的比拟:

本文代码:

https://avoid.overfit.cn/post/da052c915f4b4309b5e6b139a69394c1

作者:Alessandro Lamberti

正文完
 0