关于人工智能:ICLR-2023-PromptPG当强化学习遇见大规模语言模型

3次阅读

共计 4269 个字符,预计需要花费 11 分钟才能阅读完成。

转载自机器之心
编辑:一点人工一点智能
原文:ICLR 2023 | PromptPG:当强化学习遇见大规模语言模型

PromptPG 办法在答复问题的准确性上超过最优基准(Few-shot CoT GPT-3)5.31%。

数学推理是人类智能的一项外围能力,但对于机器来说,抽象思维和逻辑推理依然是一个很大的挑战。大规模预训练语言模型,如 GPT-3 和 GPT-4,在文本模式的数学推理(如数学应用题)上曾经获得了显著的停顿。然而,目前咱们还不分明这些模型是否解决波及到异构信息(如表格数据)的更简单的问题。为了填补这一空白,来自 UCLA 和艾伦人工智能研究院(AI2)的钻研人员推出了 Tabular Math Word Problems (TabMWP),这是一个蕴含了 38,431 个凋谢畛域问题的数据集,须要同时在文本和表格数据上进行数学推理失去正确答案。TabMWP 中的每个问题都与一个上下文相关联,这个上下文蕴含图片、文本或结构化格局的表格。

钻研人员在 TabMWP 上评估了包含 Few-shot GPT-3 等不同的预训练模型。正如已有的钻研发现,Few-shot GPT-3 很依赖 in-context 示例的抉择,这导致其在随机抉择示例的状况下性能相当不稳固。这种不稳固在解决像 TabMWP 这样简单的推理问题时体现得更加重大。为了解决这一问题,作者提出了 PromptPG 办法,这种办法将示例的抉择转化成强化学习中的 contextual bandit 问题,并且利用 Policy Gradient 训练一个策略网络来学习从大量的训练数据中抉择最优的 in-context 示例。试验结果表明,他们提出的 PromptPG 办法在答复问题的准确性上超过最优基准(Few-shot CoT GPT-3)5.31%,并且绝对于随机抉择的 in-context examples,他们的办法显著升高了预测的方差,晋升了这类办法的稳定性。

· 论文链接:https://arxiv.org/abs/2209.14610
· 代码链接:https://github.com/lupantech/PromptPG
· 我的项目主页:https://promptpg.github.io
· 数据可视化:https://promptpg.github.io/explore

01  TabMWP 数据集

上面是来自 TabMWP 数据集的两个例子。其中一个是答案为数值类型的自在文本问题(free-text),另一个是答案为文本类型的多项选择题(multi-choice)。能够看到,每个问题都提供了一个蕴含分步推理的解答。要解决 TabMWP 中的问题,零碎必须同时具备查表和多步数学推理的能力。举下图中的例子来说,要答复“how much will she spend (if Tracy buys three kinds of breads)”,咱们须要先在表格中查找出三种面包对应的价格,再计算购买每种面包的费用,并对它们求和已失去最终的费用。

如下表的统计所示,TabMWP 数据集蕴含 38,431 个表格数学问题。其中 74.7% 的问题属于自在文本问题,25.3% 的问题属于多选题。TabMWP 共有 28,876 个不同的问题,6,153 个不同的答案和 35,442 个不同的解答,表明其在问题散布方面具备丰盛的多样性。这些问题均匀长度为 22.1 个单词,解答均匀长度为 49.5 个单词,这表明 TabMWP 具备词汇的丰富性。TabMWP 的一个显著特点是,每个问题都附带有一个表格上下文,如果没有表格,问题将无奈解决。TabMWP 总共有 37,644 个不同的表格,表格均匀有 5.9 行和 2.2 列,12.9 个单元格,最大可达 54 个单元格。这些统计数据表明,TabMWP 中的表格也具备丰盛的多样性。

TabMWP 数据集有两种不同的问题类型以及五种不同的答案类型:

TabMWP 中的每个问题都有一个表格上下文,它以图像、半结构化文本和结构化三种格局示意。这为开发不同类型的推理模型提供了可能性。

相比于已有的数据集,TabMWP 同时须要表格了解和数学推理能力来答复问题。另外,TabMWP 每道题都有具体的多步推理过程,在数据集大小、表格类型、问题类型和答案类型上有显著的劣势。据本文所知,TabMWP 是第一个在凋谢畛域表格场景下的数学推理数据集。

02  PromptPG 办法

思考到大规模预训练模型例如 GPT-3 在解决数学应用题方面获得的胜利,作者首先应用 Few-shot GPT-3 在 TabMWP 上建设了一个基准。他们从训练集中随机抉择一些上下文示例以及测试样本形成提醒(prompt),提醒 GPT-3 预测答案。然而,最近的钻研表明,这种基于随机抉择的 few-shot 学习在不同的上下文示例抉择上可能会体现得十分不稳固。在解决相似 TabMWP 这样的简单推理问题时,随机抉择的成果可能会更差,因为其问题波及到不同类型和格局的表格。

为了解决这个问题,作者提出了一种改良办法:通过 Policy Gradient 进行提醒学习,从大量的训练数据中学习抉择上下文示例,称为 PromptPG。如图 2 所示,策略网络学习从候选池(candidate examples)中找到最佳的 in-context example,其优化指标是在与 GPT-3 环境交互时最大化给定训练示例(training example)的预测处分。抉择示例的策略网络是一个基于固定参数的 BERT 语言模型和一个参数可学习的单层神经网络。在实现优化学习后,PromptPG 能够对不同的测试题目,动静地从候选示例中选出不同的最优示例,从而最大化进步 GPT-3 的推理性能。

以下为 PromptPG 的学习算法。

03  试验与剖析

3.1 预训练与微调

表 3 比照了 PromptPG 和不同基准在 TabMWP 数据集上的后果。能够看到,TAPEX 因为在表格数据上进行了预训练,在类似参数量的前提下,其比 UnifiedQA 的体现要更好。对于 TAPEX 和 UnifiedQA 来说,进步模型的参数量都能够进步预测的准确性。此外,在 TabMWP 上进行模型的微调也能够极大地晋升预测的准确性。

3.2 大规模语言模型

GPT-3 在没有任何微调的状况下(Zero-shot GPT-3),能够获得与微调过的 UnifiedQA 以及 TAPEX 模型相近的准确性。如果 Few-shot GPT-3 模型随机抉择两个 in-context 示例作为 GPT-3 的提醒,其相比 Zero-shot GPT-3 能够进一步晋升 0.17%。通过让 Few-shot GPT-3 在生成最终答案前生成多步的两头步骤(Few-shot-CoT GPT-3),钻研人员能够失去最优的基准模型,其准确率达到了 62.92%。

3.3 PromptPG

区别于随机抉择 in-context 示例,本文提出的 PromptPG 通过 Policy Gradient 训练一个策略网络来抉择更适合的 in-context 示例,在 TabMWP 上获得了最高的预测后果(68.23%),其均匀预测准确率超过最好基准模型(Few-shot-CoT GPT-3)5.31%。值得注意的是,对于简直所有的问题类型、答案类型和问题难度,PromptPG 都展现出了其在预测准确率上的劣势。尽管如此,PromptPG 间隔人类 90.22% 的体现则还有很大的晋升空间。

3.4 融化试验

表 4 表明,TabMWP 的所有输出元素(问题文本、表格信息、选项信息)都对正确答复问题至关重要。只有所有的问题元素作为输出信息,Zero-shot GPT-3 才获得了其绝对最高的均匀预测准确率(59.50%)。

3.5 不同的示例抉择

作为比照试验,钻研人员还比拟了其余不同示例抉择的办法。如表 5 所示,抉择与测试问题雷同的题型或者答案类型能够帮忙模型找到更相干的示例,并进步答复的准确性。抉择最简单的示例则并不能稳固地进步答复准确性。在候选示例中固定抉择两个最好的示例,能够小幅度进步准确性,并升高方差。抉择语义上最靠近测试问题的示例能够达到最靠近 PromptPG 办法的准确性。总体来说,PromptPG 全面展示了其在晋升预测准确性和升高预测方差上的劣势。

下图展现了 PromptPG 抉择的示例以及最终的预测后果。能够看到,PromptPG 办法能够抉择与测试题目具备相似的数学能力的示例,从而进步 Few-shot GPT-3 的推理性能。

3.6 预测胜利的例子

以下展现了 PromptPG 对一个自在文本问题的正确答复。这个问题要求对表格中的八个数字别离进行加法和除法计算以失去平均值。

在如下的例子中,模型被要求了解一个税收报告,并计算扣税后的工资。

以下展现了 PromptPG 对多选题问题的正确预测。给定的表格一共有 9 行和 6 列。模型胜利地定位到了表格中的指标单元格,并进行多步推理以预测正确答案。

在以下的例子中,模型须要比拟估算和总成本,以验证 Ariana 是否有足够的钱。

3.7 预测失败的例子

以下展现了 PromptPG 对自在文本问题的谬误预测。模型检索到了谬误的玫瑰石英价格,从而错误计算了三个物品的老本总和。

在以下的例子中,问题提供了一个形象的茎叶表。模型无奈了解这个特定畛域的表格,并且不足高级逻辑推理能力从而失去了谬误的答案。

以下的例子表明,现有的模型仿佛不具备对数字排序的能力。

在以下的例子中,表格中没有呈现与问题提到的以后工夫完全一致的工夫,因而模型无奈精确定位到下一站的登程工夫。

以下的例子中,模型很难精确实现一长串数字的算术运算。

04  论断与瞻望

作者提出了 TabMWP,这是第一个针对表格语境的数学问题求解的大规模数据集。TabMWP 蕴含了 38,431 个凋谢畛域的问题,其中包含两种问题类型和五种答案类型,每个问题都标注了多步的解答过程。作者应用了最先进的 QA 和 TableQA 办法,在预训练和微调设置下对 TabMWP 进行了全面的试验,以及应用大型预训练语言模型 GPT-3 进行评估。作者进一步提出了一种全新的强化学习办法 PromptPG,该办法利用 Policy Gradient 学习从训练数据中抉择最优的实例用于提醒用于 GPT-3 模型。试验结果表明,与随机抉择相比,PromptPG 的性能显著优于现有的基线,并且缩小了预测中的性能不稳定性。

1. 书籍举荐 -《深度强化学习》

2. 书籍举荐 -《强化学习与最优控制》

3. 张俊林:由 ChatGPT 反思大语言模型(LLM)的技术精要

4. 复旦邱锡鹏:深度分析 ChatGPT 类大语言模型的关键技术

5. 谷歌复用 30 年前经典算法,CV 引入强化学习,网友:视觉 RLHF 要来了?

正文完
 0