多模态人工智能是一种新型 AI 范式,是指图像、文本、语音、视频等多种数据类型,与多种智能解决算法相结合,以期实现更高的性能。
近日,PyTorch 官网公布了一个 domain library–TorchMultimodal, 用于 SoTA 多任务、多模态模型的大规模训练。
该库提供了:
- 可组合的 building block(module、transforms、损失函数)用于减速模型开发
- 从已发表的钻研、训练及评估脚本中提取的 SoTA 模型架构 (FLAVA, MDETR, Omnivore)
- 用于测试这些模型的 notebook
TorchMultimodal 库仍在踊跃开发中,详情请关注:
https://github.com/facebookresearch/multimodal#installation
TorchMultimodal 开发背景
随着技术的提高,能了解多种类型输出(文本、图像、视频和音频信号),并能利用这种了解来生成不同模式的输入(句子、图片、视频)的 AI 模型越来越引发关注。
FAIR 最近的钻研工作(如 FLAVA、Omnivore 和 data2vec)表明, 用于了解的多模态模型与单模态模型相比更有劣势,并且在某些状况下正在创始全新的 SOTA。
相似 Make-a-video 以及 Make-a-scene 这样的生成模型,正在从新定义古代人工智能零碎的能力边界。
为了促成 PyTorch 生态中多模态 AI 的倒退, TorchMultimodal 库应运而生,其解决思路为:
- 提供可组合的 building block, 利用这些 building block,钻研人员能够在本人的工作流中减速模型开发和试验。模块化设计也升高了迁徙到新模态数据的难度。
-
提供了用于训练和评估钻研中最新模型的端到端示例。 这些示例中用到了一些高阶个性,如集成 FSDP 和用于扩大模型及批尺寸的 activation checkpointing。
初识 TorchMultimodal
TorchMultimodal 是一个 PyTorch domain library, 用于多任务多模态模型的大规模训练。 它提供:
1. Building Block
模块及可组合 building block 汇合,如模型、交融层、损失函数、数据集和实用程序,例如:
- 温度比照损失 (Contrastive Loss with Temperature): 罕用于训练模型的函数,如 CLIP 和 FLAVA。此外还包含在 ALBEF 等模型中应用的 ImageTextContrastiveLoss 等变量。
- Codebook layer: 通过向量空间中的最近邻查找压缩高维数据,它也是 VQVAE 的重要组成部分。
- Shifted-window Attention: window 基于 multi-head self attention,是 Swin 3D Transformer 等编码器的重要组件。
- CLIP 组件: 由 OpenAI 公布,是一个在学习文本和图像表征方面十分无效的模型。
- Multimodal GPT: 与生成程序联合时,可将 OpenAI 的 GPT 架构扩大为更适宜多模态生成的形象。
-
MultiHeadAttention: 基于 attention 的模型的一个要害组件,反对 auto-regressive 和 decoding。
2. 示例
一组示例展现了如何将 building block 与 PyTorch 组件和公共基础设施 (Lightning, TorchMetrics) 联合,从而复制文献中发表的 SOTA 模型。目前提供了五个示例,其中包含:
-
FLAVA: CVPR 接管论文的官网代码,包含一个对于 FLAVA 微调的教程。
查看论文:https://arxiv.org/abs/2112.04482
-
MDETR: 与 NYU 的作者单干提供了一个例子,加重了 PyTorch 生态系统中互操作性 (interoperability) 痛点,包含一个应用 MDETR 进行 phrase grounding 和可视化问答的 notebook。
查看论文:https://arxiv.org/abs/2104.12763
-
Omnivore: TorchMultimodal 中解决视频和 3D 数据的模型的第一个例子,包含用于摸索模型的 notebook。
查看论文:https://arxiv.org/abs/2204.08058
-
MUGEN: auto-regressive 生成和检索的根底工作,包含应用 OpenAI coinrun 丰盛的大规模合成数据集生成和检索 text-video 的 demo。
查看论文:https://arxiv.org/abs/2204.08058
-
ALBEF: 模型代码,包含用该模型解决视觉问答问题的 notebook。
查看论文:https://arxiv.org/abs/2107.07651
以下代码展现了几个与 CLIP 相干的 TorchMultimodal 组件的用法:
# instantiate clip transform
clip_transform = CLIPTransform()
# pass the transform to your dataset. Here we use coco captions
dataset = CocoCaptions(root= ..., annFile=..., transforms=clip_transform)
dataloader = DataLoader(dataset, batch_size=16)
# instantiate model. Here we use clip with vit-L as the image encoder
model= clip_vit_l14()
# define loss and other things needed for training
clip_loss = ContrastiveLossWithTemperature()
optim = torch.optim.AdamW(model.parameters(), lr = 1e-5)
epochs = 1
# write your train loop
for _ in range(epochs):
for batch_idx, batch in enumerate(dataloader):
image, text = batch
image_embeddings, text_embeddings = model(image, text)
loss = contrastive_loss_with_temperature(image_embeddings, text_embeddings)
loss.backward()
optimizer.step()
装置 TorchMultimodal
Python ≥ 3.7,安不装置 CUDA 反对均可。
以下代码以装置 conda 为例
前提条件
1. 装置 conda 环境
conda create -n torch-multimodal python=\<python_version\>
conda activate torch-multimodal
2. 装置 PyTorch、torchvision 以及 torchtext
参阅 PyTorch 文档:
https://pytorch.org/get-started/locally/
# Use the current CUDA version as seen [here](https://pytorch.org/get-started/locally/)
# Select the nightly Pytorch build, Linux as the OS, and conda. Pick the most recent CUDA version.
conda install pytorch torchvision torchtext pytorch-cuda=\<cuda_version\> -c pytorch-nightly -c nvidia
# For CPU-only install
conda install pytorch torchvision torchtext cpuonly -c pytorch-nightly
从二进制文件装置
在 Linux 上,实用于 Python 3.7、3.8 和 3.9 的 Nightly binary 可通过 pip wheels 装置。目前只通过 PyPI 反对 Linux 平台。
python -m pip install torchmultimodal-nightly
源码装置
开发者也能够通过源码构建并运行示例:
git clone --recursive https://github.com/facebookresearch/multimodal.git multimodal
cd multimodal
pip install -e .
以上就是对于 TorchMultimodal 的简略介绍。除代码外,PyTorch 官网还公布了一个对于微调多模态模型的基础教程, 以及一篇对于如何应用 PyTorch Distributed PyTorch (FSDP and activation checkpointing) 技术扩大这些模型的 blog。
后续咱们将针对这篇 blog 进行汉化整顿。欢送继续关注 PyTorch 开发者社区公众号!
—— 完 ——