关于深度学习:EasyNLP中文文图生成模型带你秒变艺术家

3次阅读

共计 11931 个字符,预计需要花费 30 分钟才能阅读完成。

作者:汪诚愚、刘婷婷

导读

宣物莫大于言,存形莫长于画。–【晋】陆机

多模态数据(文本、图像、声音)是人类意识、了解和表白世间万物的重要载体。近年来,多模态数据的爆炸性增长促成了内容互联网的凋敝,也带来了大量多模态内容了解和生成的需要。与常见的跨模态了解工作不同,文到图的生成工作是风行的跨模态生成工作,旨在生成与给定文本对应的图像。这一文图生成的工作,极大地开释了 AI 的想象力,也激发了人类的创意。典型的模型例如 OpenAI 开发的 DALL- E 和 DALL-E2。近期,业界也训练出了更大、更新的文图生成模型,例如 Google 提出的 Parti 和 Imagen。

然而,上述模型个别不能用于解决中文的需要,而且上述模型的参数量宏大,很难被开源社区的宽广用户间接用来 Fine-tune 和推理。本次,EasyNLP 开源框架再次迎来大降级,集成了先进的文图生成架构 Transformer+VQGAN,同时,向开源社区收费凋谢不同参数量的中文文图生成模型的 Checkpoint,以及相应 Fine-tune 和推理接口。用户能够在咱们凋谢的 Checkpoint 根底上进行大量畛域相干的微调,在不耗费大量计算资源的状况下,就能一键进行各种艺术创作。

EasyNLP 是阿里云机器学习 PAI 团队基于 PyTorch 开发的易用且丰盛的中文 NLP 算法框架,并且提供了从训练到部署的一站式 NLP 开发体验。EasyNLP 提供了简洁的接口供用户开发 NLP 模型,包含 NLP 利用 AppZoo、预训练模型 ModelZoo、数据仓库 DataHub 等个性。因为跨模态了解和生成需要的一直减少,EasyNLP 也反对各种跨模态模型,特地是中文畛域的跨模态模型,推向开源社区。例如,在先前的工作中,EasyNLP 曾经对中文图文检索 CLIP 模型进行了反对(看这里)。咱们心愿可能服务更多的 NLP 和多模态算法开发者和研究者,也心愿和社区一起推动 NLP / 多模态技术的倒退和模型落地。本文简要介绍文图生成的技术,以及如何在 EasyNLP 框架中如何轻松实现文图生成,带你秒变艺术家。本文结尾的展现图片即为咱们模型创作的作品。

文图生成模型简述

上面以几个经典的基于 Transformer 的工作为例,简略介绍文图生成模型的技术。DALL- E 由 OpenAI 提出,采取两阶段的办法生成图像。在第一阶段,训练一个 dVAE(discrete variational autoencoder)的模型将 256×256 的 RGB 图片转化为 32×32 的 image token,这一步骤将图片进行信息压缩和离散化,不便进行文本到图像的生成。第二阶段,DALL- E 训练一个自回归的 Transformer 模型,将文本输出转化为上述 1024 个 image token。

由清华大学等单位提出的 CogView 模型对上述两阶段文图生成的过程进行了进一步的优化。在下图中,CogView 采纳了 sentence piece 作为 text tokenizer 使得输出文本的空间表白更加丰盛,并且在模型的 Fine-tune 过程中采纳了多种技术,例如图像的超分、格调迁徙等。

ERNIE-ViLG 模型思考进一步思考了 Transformer 模型学习常识的可迁移性,同时学习了从文本生成图像和从图像生成文本这两种工作。其架构图如下所示:

随着文图生成技术的一直倒退,新的模型和技术不断涌现。举例来说,OFA 将多种跨模态的生成工作对立在同一个模型架构中。DALL-E 2 同样由 OpenAI 提出,是 DALL- E 模型的升级版,思考了层次化的图像生成技术,模型利用 CLIP encoder 作为编码器,更好地融入了 CLIP 预训练的跨模态表征。Google 进一步提出了 Diffusion Model 的架构,能无效生成高清大图,如下所示:

在本文中,咱们不再对这些细节进行赘述。感兴趣的读者能够进一步查阅参考文献。

EasyNLP 文图生成模型

因为前述模型的规模往往在数十亿、百亿参数级别,宏大的模型尽管能生成品质较大的图片,而后对计算资源和预训练数据的要求使得这些模型很难在开源社区广泛应用,尤其在须要面向垂直畛域的状况下。在本节中,咱们具体介绍 EasyNLP 提供的中文文图生成模型,它在较小参数量的状况下,仍然具备良好的文图生成成果。

模型架构

模型框架图如下图所示:

思考到 Transformer 模型复杂度随序列长度呈二次方增长,文图生成模型的训练个别以图像矢量量化和自回归训练两阶段联合的形式进行。

图像矢量量化是指将图像进行离散化编码,如将 256×256 的 RGB 图像进行 16 倍降采样,失去 16×16 的离散化序列,序列中的每个 image token 对应于 codebook 中的示意。常见的图像矢量量化办法包含:VQVAE、VQVAE- 2 和 VQGAN 等。咱们采纳 VQGAN 在 ImageNet 上训练的 f16_16384(16 倍降采样,词表大小为 16384)的模型权重来生成图像的离散化序列。

自回归训练是指将文本序列和图像序列作为输出,在图像局部,每个 image token 仅与文本序列的 tokens 和其之前的 image tokens 进行 attention 计算。咱们采纳 GPT 作为 backbone,可能适应不同模型规模的生成工作。在模型预测阶段,输出文本序列,模型以自回归的形式逐渐生成定长的图像序列,再通过 VQGAN decoder 重构为图像。

开源模型参数设置

模型配置 pai-painter-base-zh pai-painter-large-zh
参数量(Parameters) 202M 433M
层数(Number of Layers) 12 24
注意力头数(Attention Heads) 12 16
隐向量维度(Hidden Size) 768 1024
文本长度(Text Length) 32 32
图像序列长度(Image Length) 16 x 16 16 x 16
图像尺寸(Image Size) 256 x 256 256 x 256
VQGAN 词表大小(Codebook Size) 16384 16384

模型实现

在 EasyNLP 框架中,咱们在模型层构建基于 minGPT 的 backbone 构建模型,外围局部如下所示:

self.first_stage_model = VQModel(ckpt_path=vqgan_ckpt_path).eval()
self.transformer = GPT(self.config)

VQModel 的 Encoding 阶段过程为:

# in easynlp/appzoo/text2image_generation/model.py

@torch.no_grad()
def encode_to_z(self, x):
    quant_z, _, info = self.first_stage_model.encode(x)
    indices = info[2].view(quant_z.shape[0], -1)
    return quant_z, indices

x = inputs['image']
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
# one step to produce the logits
_, z_indices = self.encode_to_z(x)  # z_indice: torch.Size([batch_size, 256]) 

VQModel 的 Decoding 阶段过程为:

# in easynlp/appzoo/text2image_generation/model.py

@torch.no_grad()
def decode_to_img(self, index, zshape):
    bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
    quant_z = self.first_stage_model.quantize.get_codebook_entry(index.reshape(-1), shape=bhwc)
    x = self.first_stage_model.decode(quant_z)
    return x

# sample 为训练阶段的后果生成,与预测阶段的 generate 相似,详解见下文 generate
index_sample = self.sample(z_start_indices, c_indices,
                           steps=z_indices.shape[1],
                           ...)
x_sample = self.decode_to_img(index_sample, quant_z.shape)

Transformer 采纳 minGPT 进行构建,输出图像的离散编码,输入文本 token。前向流传过程为:

# in easynlp/appzoo/text2image_generation/model.py

def forward(self, inputs):
    x = inputs['image']
    c = inputs['text']
    x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
    # one step to produce the logits
    _, z_indices = self.encode_to_z(x)  # z_indice: torch.Size([batch_size, 256]) 
    c_indices = c
    
    if self.training and self.pkeep < 1.0:
        mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
                                                     device=z_indices.device))
        mask = mask.round().to(dtype=torch.int64)
        r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
        a_indices = mask*z_indices+(1-mask)*r_indices
    
    else:
        a_indices = z_indices
        cz_indices = torch.cat((c_indices, a_indices), dim=1)
        # target includes all sequence elements (no need to handle first one
        # differently because we are conditioning)
        target = z_indices
        # make the prediction
        logits, _ = self.transformer(cz_indices[:, :-1])
        # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
        logits = logits[:, c_indices.shape[1]-1:]
    return logits, target

在预测阶段,输出为文本 token,输入为 256*256 的图像。首先,将输出文本预处理为 token 序列:

# in easynlp/appzoo/text2image_generation/predictor.py

def preprocess(self, in_data):
    if not in_data:
        raise RuntimeError("Input data should not be None.")

    if not isinstance(in_data, list):
        in_data = [in_data]
    rst = {"idx": [], "input_ids": []}
    max_seq_length = -1
    for record in in_data:
        if "sequence_length" not in record:
            break
        max_seq_length = max(max_seq_length, record["sequence_length"])
    max_seq_length = self.sequence_length if (max_seq_length == -1) else max_seq_length

    for record in in_data:
        text= record[self.first_sequence]
        try:
            self.MUTEX.acquire()
            text_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
            text_ids = text_ids[: self.text_len]
            n_pad = self.text_len - len(text_ids)
            text_ids += [self.pad_id] * n_pad
            text_ids = np.array(text_ids) + self.img_vocab_size

        finally:
            self.MUTEX.release()

        rst["idx"].append(record["idx"]) 
        rst["input_ids"].append(text_ids)
    return rst

逐渐生成长度为 16*16 的图像离散 token 序列:

# in easynlp/appzoo/text2image_generation/model.py

def generate(self, inputs, top_k=100, temperature=1.0):
    cidx = inputs
    sample = True
    steps = 256
    for k in range(steps):
        x_cond = cidx
        logits, _ = self.transformer(x_cond)
        # pluck the logits at the final step and scale by temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop probabilities to only the top k options
        if top_k is not None:
            logits = self.top_k_logits(logits, top_k)
        # apply softmax to convert to probabilities
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # sample from the distribution or take the most likely
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)
        # append to the sequence and continue
        cidx = torch.cat((cidx, ix), dim=1)
    img_idx = cidx[:, 32:]
    return img_idx

最初,咱们调用 VQModel 的 Decoding 过程将这些图像离散 token 序列转换为图像。

模型成果

咱们在四个中文的公开数据集 COCO-CN、MUGE、Flickr8k-CN、Flickr30k-CN 上验证了 EasyNLP 框架中文图生成模型的成果。同时,咱们比照了这个模型和 CogView、DALL- E 的成果,如下所示:

其中,
1)MUGE 是天池平台颁布的电商场景的中文大规模多模态评测基准(http://tianchi.aliyun.com/muge)。为了不便计算指标,MUGE 咱们采纳 valid 数据集的后果,其余数据集采纳 test 数据集的后果。

2)CogView 源自 https://github.com/THUDM/CogView

3)DALL- E 模型没有公开的官网代码。曾经公开的局部只蕴含 VQVAE 的代码,不包含 Transformer 局部。咱们基于广受关注的 https://github.com/lucidrains… 版本的代码和该版本举荐的 checkpoits 进行复现,checkpoints 为 2.09 亿参数,为 OpenAI 的 DALL- E 模型参数量的 1 /100。(OpenAI 版本 DALL- E 为 120 亿参数,其中 CLIP 为 4 亿参数)。

经典案例

咱们别离在自然风景数据集 COCO-CN 上 Fine-tune 了 base 和 large 级别的模型,如下展现了模型的成果:

示例 1:一只俏皮的狗正跑过草地

示例 2:一片水域的风景以日落为背景

咱们也积攒了阿里团体的海量电商商品数据,微调失去了面向电商商品的文图生成模型。成果如下:

示例 3:女童套头毛衣打底衫秋冬针织衫童装儿童内搭上衣

示例 4:春夏真皮工作鞋女深色软皮久站舒服下班面试职业皮鞋

除了反对特定畛域的利用,文图生成也极大地辅助了人类的艺术创作。应用训练失去的模型,咱们能够秒变“中国国画艺术大师”,示例如下所示:

更多的示例请观赏:

应用教程

观赏了模型生成的作品之后,如果咱们想 DIY,训练本人的文图生成模型,应该如何进行呢?以下咱们简要介绍在 EasyNLP 框架对预训练的文图生成模型进行 Fine-tune 和推理。

装置 EasyNLP

用户能够间接参考链接的阐明装置 EasyNLP 算法框架。

数据筹备

首先筹备训练数据与验证数据,为 tsv 文件。这一文件蕴含以制表符 \t 分隔的两列,第一列为索引号,第二列为文本,第三列为图片的 base64 编码。用于测试的输出文件为两列,仅蕴含索引号和文本。

为了不便开发者,咱们也提供了转换图片到 base64 编码的示例代码:

import base64
from io import BytesIO
from PIL import Image

img = Image.open(fn)
img_buffer = BytesIO()
img.save(img_buffer, format=img.format)
byte_data = img_buffer.getvalue()
base64_str = base64.b64encode(byte_data) # bytes

下列文件曾经实现预处理,可用于测试:

# train
https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_train_text_imgbase64.tsv

# valid
https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_val_text_imgbase64.tsv

# test
https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_test.text.tsv

模型训练

咱们采纳以下命令对模型进行

easynlp \
    --mode=train \
    --worker_gpu=1 \
    --tables=MUGE_val_text_imgbase64.tsv,MUGE_val_text_imgbase64.tsv \
    --input_schema=idx:str:1,text:str:1,imgbase64:str:1 \
    --first_sequence=text \
    --second_sequence=imgbase64 \
    --checkpoint_dir=./finetuned_model/ \
    --learning_rate=4e-5 \
    --epoch_num=1 \
    --random_seed=42 \
    --logging_steps=100 \
    --save_checkpoint_steps=1000 \
    --sequence_length=288 \
    --micro_batch_size=16 \
    --app_name=text2image_generation \
    --user_defined_parameters='
        pretrain_model_name_or_path=alibaba-pai/pai-painter-large-zh
        size=256
        text_len=32
        img_len=256
        img_vocab_size=16384
    ' 

咱们提供 base 和 large 两个版本的预训练模型,pretrain_model_name_or_path 别离为 alibaba-pai/pai-painter-base-zh 和 alibaba-pai/pai-painter-large-zh。

训练实现后模型被保留到./finetuned_model/。

模型批量推理

模型训练结束后,咱们能够将其用于图像生成,其示例如下:

easynlp \
    --mode=predict \
    --worker_gpu=1 \
    --tables=MUGE_test.text.tsv \
    --input_schema=idx:str:1,text:str:1 \
    --first_sequence=text \
    --outputs=./T2I_outputs.tsv \
    --output_schema=idx,text,gen_imgbase64 \
    --checkpoint_dir=./finetuned_model/ \
    --sequence_length=288 \
    --micro_batch_size=8 \
    --app_name=text2image_generation \
    --user_defined_parameters='
        size=256
        text_len=32
        img_len=256
        img_vocab_size=16384
    '

后果存储在一个 tsv 文件中,每行对应输出中的一个文本,输入的图像以 base64 编码。

应用 Pipeline 接口疾速体验文图生成成果

为了进一步不便开发者应用,咱们在 EasyNLP 框架内也实现了 Inference Pipeline 性能。用户能够应用如下命令调用 Fine-tune 过的电商场景下的文图生成模型:

# 间接构建 pipeline
default_ecommercial_pipeline = pipeline("pai-painter-commercial-base-zh")

# 模型预测
data = ["宽松 T 恤"]
results = default_ecommercial_pipeline(data)  # results 的每一条是生成图像的 base64 编码

# base64 转换为图像
def base64_to_image(imgbase64_str):
    image = Image.open(BytesIO(base64.urlsafe_b64decode(imgbase64_str)))
    return image

# 保留以文本命名的图像
for text, result in zip(data, results):
    imgpath = '{}.png'.format(text)
    imgbase64_str = result['gen_imgbase64']
    image = base64_to_image(imgbase64_str)
    image.save(imgpath)
    print('text: {}, save generated image: {}'.format(text, imgpath))

除了电商场景,咱们还提供了以下场景的模型:

  • 自然风光场景:“pai-painter-scenery-base-zh”
  • 中国山水画场景:“pai-painter-painting-base-zh”
    在下面的代码当中替换“pai-painter-commercial-base-zh”,就能够间接体验,欢送试用。对于用户 Fine-tune 的文图生成模型,咱们也凋谢了自定义模型加载的 Pipeline 接口:
# 加载模型,构建 pipeline
local_model_path = ...
text_to_image_pipeline = pipeline("text2image_generation", local_model_path)

# 模型预测
data = ["xxxx"]
results = text_to_image_pipeline(data)  # results 的每一条是生成图像的 base64 编码 

将来瞻望

在这一期的工作中,咱们在 EasyNLP 框架中集成了中文文图生成性能,同时凋谢了模型的 Checkpoint,不便开源社区用户在资源无限状况下进行大量畛域相干的微调,进行各种艺术创作。在将来,咱们打算在 EasyNLP 框架中推出更多相干模型,敬请期待。咱们也将在 EasyNLP 框架中集成更多 SOTA 模型(特地是中文模型),来反对各种 NLP 和多模态工作。此外,阿里云机器学习 PAI 团队也在继续推动中文多模态模型的自研工作,欢送用户继续关注咱们,也欢送退出咱们的开源社区,共建中文 NLP 和多模态算法库!

Github 地址:https://github.com/alibaba/EasyNLP

Reference

  1. Chengyu Wang, Minghui Qiu, Taolin Zhang, Tingting Liu, Lei Li, Jianing Wang, Ming Wang, Jun Huang, Wei Lin. EasyNLP: A Comprehensive and Easy-to-use Toolkit for Natural Language Processing. arXiv
  2. Aditya Ramesh, Mikhail Pavlov, Gabriel Goh, Scott Gray, Chelsea Voss, Alec Radford, Mark Chen, Ilya Sutskever. Zero-Shot Text-to-Image Generation. ICML 2021: 8821-8831
  3. Ming Ding, Zhuoyi Yang, Wenyi Hong, Wendi Zheng, Chang Zhou, Da Yin, Junyang Lin, Xu Zou, Zhou Shao, Hongxia Yang, Jie Tang. CogView: Mastering Text-to-Image Generation via Transformers. NeurIPS 2021: 19822-19835
  4. Han Zhang, Weichong Yin, Yewei Fang, Lanxin Li, Boqiang Duan, Zhihua Wu, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang. ERNIE-ViLG: Unified Generative Pre-training for Bidirectional Vision-Language Generation. arXiv
  5. Peng Wang, An Yang, Rui Men, Junyang Lin, Shuai Bai, Zhikang Li, Jianxin Ma, Chang Zhou, Jingren Zhou, Hongxia Yang. Unifying Architectures, Tasks, and Modalities Through a Simple Sequence-to-Sequence Learning Framework. ICML 2022
  6. Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, Mark Chen. Hierarchical Text-Conditional Image Generation with CLIP Latents. arXiv
  7. Van Den Oord A, Vinyals O. Neural discrete representation learning. NIPS 2017
  8. Esser P, Rombach R, Ommer B. Taming transformers for high-resolution image synthesis. CVPR 2021: 12873-12883.
  9. Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho, David J. Fleet, Mohammad Norouzi: Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding. arXiv
  10. Jiahui Yu, Yuanzhong Xu, Jing Yu Koh, Thang Luong, Gunjan Baid, Zirui Wang, Vijay Vasudevan, Alexander Ku, Yinfei Yang, Burcu Karagol Ayan, Ben Hutchinson, Wei Han, Zarana Parekh, Xin Li, Han Zhang, Jason Baldridge, Yonghui Wu. Scaling Autoregressive Models for Content-Rich Text-to-Image Generation. arXiv

阿里灵杰回顾
● 阿里灵杰:阿里云机器学习 PAI 开源中文 NLP 算法框架 EasyNLP,助力 NLP 大模型落地
● 阿里灵杰:预训练常识度量较量夺冠!阿里云 PAI 公布常识预训练工具
● 阿里灵杰:EasyNLP 带你玩转 CLIP 图文检索

正文完
 0