背景
Stable Diffusion(SD)是一种风行的AI生成内容(AI Generated Content,AIGC)模型,能在文字输出的根底上生成各种格调多样的图像。在目前的AIGC方向,SD是开源社区最热门的模型。然而,SD可能生成高颜值的图像,十分依赖于用户提供的Prompt。如果没有好的Prompt,SD往往无奈生成用户预期的图像,极大的影响用户的应用体验。在先前的工作中,阿里云机器学习PAI团队在AIGC方向做了很多摸索,包含PAI-Diffusion中文模型的开源、基于Blade的推理优化等,并且推出一系列行业解决方案。为了晋升SD系列模型的易用性、升高应用门槛、开释AI模型的发明后劲,咱们提出并训练实现面向SD主动Prompt丑化器,使得用户只有输出一个极其简略的Prompt,就能够失去一系列通过语言模型优化过的、细节满满的Prompt,帮忙您更简略地生成高颜值图像。在下文中,咱们具体介绍PAI主动Prompt生成模型的性能和背地的技术干货。
一键体验Prompt主动生成
在具体介绍PAI主动Prompt生成模型前,咱们首先给出体验Prompt生成成果。在以下的示例中,咱们别离比照了原始Prompt和咱们生成Prompt在Stable Diffusion v1.5底座模型上生成图像的成果。对于每个Prompt,咱们随机生成两张图片进行比照。
咱们也在ModelScope上构建了一个Demo以供用户体验(链接)。只有输出一个简略的Prompt,咱们的模型就能够扩写成一个细节形容具体的Prompt,而后应用Stable Diffusion进行文图生成。
Prompt丑化器背地的技术
在本节中,咱们具体介绍如何训练一个优良的Prompt丑化器。整体架构如下图所示:
底座模型
PAI-Prompt生成模型底座基于BLOOM(BigScience Language Open-science Open-access Multilingual),由BigScience训练并且开源。BLOOM具备Decoder-only模型架构,和GPT-3十分类似,最大具备1760亿参数。在咱们的模型训练时,咱们抉择了具备11亿参数的BLOOM模型进行持续训练,其起因在于它的参数量不大,当它用于线上部署时,推理速度较快,而且训练和推理老本都相当可控,具备很高的实用价值。
无需数据标注的SFT
因为模型的训练须要高质量和低质量的Prompt对,这些数据个别很难间接去做标注。因而咱们应用如下办法主动话地收集训练数据。
1.摘要生成:首先,咱们收集开源的高质量Prompt数据集,作为语言模型生成的Target。在这种场景下,因为不足低质量的Prompt,咱们能够应用ChatGPT等大模型生成Summary作为低质量的Prompt。以下是一个Summary的例子:
Instruction: Summarize this image description in 10 words or less and ignore words like archdaily, wallpaper, highly detailed, 8k, [r/earthporn]. Check English. Ignore modifiers 'by xxx', 'with xxx' or 'in xxx'. Ignore adjective.Input: a beautiful very detailed illustration of abandoned urbex unfinished building city nature industrial architecture architecture building spaceport by caspar david friedrich, scumm bar meadow nature synthwave, archdaily, wallpaper, highly detailed, trending on artstation.Output: of abandoned urban building in nature.Input: portrait painting of a lost boy by cedric peyravernay and greg ruthkowski, in the style of dishonored concept art, concept design, trending on artstation \nOutput:
2.Prompt扩大:利用低质量的Prompt,应用ChatGPT生成更高质量的Prompt。以下是一个Prompt生成的例子:
Instruction: create a detailed and creative description of the 'input'. Your response should include specific details about the colors, textures, and overall composition of the painting, as well as any unique features or elements that make it stand out.Please provide a clear and concise response that captures the essence of the painting while also encouraging creativity and originality in your description. You may consider describing the setting or environment depicted in the painting.Input: Digital painting of a girl with candy hat.
3.图像题目生成:咱们收集了高质量的图文对,对图像进行image captioning,生成更多可供训练模型的Prompt。
最终,失去的数据会进行好看值和一致性筛选,咱们保留品质较高的数据用于SFT。
面向SD的强化学习优化
RLHF(Reinforcement Learning from Human Feedback)对ChatGPT等大模型的成果晋升有重要的作用。在咱们的利用中,咱们设计了面向Stable Diffusion的强化学习算法,优化Prompt生成模型。
对于Reward Model,咱们在失去图文对数据根底上,应用美学值评分模型来给图片打分,并应用一个语言模型来拟合对应Pprompt->美学值评分,将此作为咱们的打分模型。此外,咱们还采纳最先进的强化学习算法PPO来进一步优化模型,处分函数应用打分模型和一致性得分加权:
reward = a * score_model(prompt) + b * consistency_model(raw_prompt, prompt)
这样能够进一步增强咱们生成Prompt的好看性和图文一致性。在实现了上述三阶段训练当前,咱们的模型在小参数规模下(1.1B)的成果不亚于ChatGPT生成Prompt的成果,示例如下:
模型调用
如果想疾速体验模型成果,能够拜访咱们在ModelScope社区的创空间页面链接。同时,咱们也在huggingface等开源社区上架了这一模型,应用接口如下:
from transformers import AutoTokenizer, AutoModelForCausalLMtokenizer = AutoTokenizer.from_pretrained('alibaba-pai/pai-bloom-1b1-text2prompt-sd')model = AutoModelForCausalLM.from_pretrained('alibaba-pai/pai-bloom-1b1-text2prompt-sd').eval().cuda()raw_prompt = '1 girl'input = f'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'input_ids = tokenizer.encode(input, return_tensors='pt').cuda()outputs = model.generate( input_ids, max_length=384, do_sample=True, temperature=1.0, top_k=50, top_p=0.95, repetition_penalty=1.2, num_return_sequences=5)prompts = tokenizer.batch_decode(outputs[:, input_ids.size(1):], skip_special_tokens=True)prompts = [p.strip() for p in prompts]print(prompts)
将来瞻望
在这一期的工作中,咱们提出并训练实现面向SD主动Prompt丑化器,使得用户只有输出一个极其简略的Prompt,就能够失去一系列通过语言模型优化过的Prompt,帮忙您更简略地生成高颜值图像。在将来,咱们打算减少这一类模型对各种类SD模型的适配,丰盛PAI-AIGC的算法和产品能力。
点击立刻收费试用云产品 开启云上实际之旅!
作者:曹庭锋、汪诚愚、吴梓恒、黄俊
原文链接
本文为阿里云原创内容,未经容许不得转载。