乐趣区

关于人工智能:了解-Transformers-是如何思考的

Transformer 模型是 AI 零碎的根底。曾经有了数不清的对于 “Transformer 如何工作 ” 的外围构造图表。

然而这些图表没有提供任何直观的计算该模型的框架示意。当研究者对于 Transformer 如何工作抱有趣味时,直观的获取他运行的机制变得非常有用。

Thinking Like Transformers 这篇论文中提出了 transformer 类的计算框架,这个框架间接计算和模拟 Transformer 计算。应用 RASP 编程语言,使每个程序编译成一个非凡的 Transformer。

在这篇博客中,我用 Python 复现了 RASP 的变体 (RASPy)。该语言大抵与原始版本相当,然而多了一些我认为很乏味的变动。通过这些语言,作者 Gail Weiss 的工作,提供了一套具备挑战性的乏味且正确的形式能够帮忙理解其工作原理。

!pip install git+https://github.com/srush/RASPy

在说起语言自身前,让咱们先看一个例子,看看用 Transformers 编码是什么样的。这是一些计算翻转的代码,即反向输出序列。代码自身用两个 Transformer 层利用 attention 和数学计算达到这个后果。

def flip():
    length = (key(1) == query(1)).value(1)
    flip = (key(length - indices - 1) == query(indices)).value(tokens)
    return flip
flip()

文章目录

  • 第一局部:Transformers 作为代码
  • 第二局部:用 Transformers 编写程序

Transformers 作为代码

咱们的指标是定义一套计算模式来最小化 Transformers 的表白。咱们将通过类比,形容每个语言结构及其在 Transformers 中的对应。(正式语言标准请在本文底部查看论文全文链接)。

这个语言的外围单元是将一个序列转换成雷同长度的另一个序列的序列操作。我前面将其称之为 transforms。

输出

在一个 Transformer 中,根本层是一个模型的前馈输出。这个输出通常蕴含原始的 token 和地位信息。

在代码中,tokens 的特色示意最简略的 transform,它返回通过模型的 tokens,默认输出序列是 “hello”:

tokens

如果咱们想要扭转 transform 里的输出,咱们应用输出办法进行传值。

tokens.input([5, 2, 4, 5, 2, 2])

作为 Transformers,咱们不能间接承受这些序列的地位。然而为了模仿地位嵌入,咱们能够获取地位的索引:

indices
sop = indices
sop.input("goodbye")

前馈网络

通过输出层后,咱们达到了前馈网络层。在 Transformer 中,这一步能够对于序列的每一个元素独立的利用数学运算。

在代码中,咱们通过在 transforms 上计算示意这一步。在每一个序列的元素中都会进行独立的数学运算。

tokens == "l"

后果是一个新的 transform,一旦重构新的输出就会依照重构形式计算:

model = tokens * 2 - 1
model.input([1, 2, 3, 5, 2])

该运算能够组合多个 Transforms,举个例子,以上述的 token 和 indices 为例,这里能够类别 Transformer 能够跟踪多个片段信息:

model = tokens - 5 + indices
model.input([1, 2, 3, 5, 2])
(tokens == "l") | (indices == 1)

咱们提供了一些辅助函数让写 transforms 变得更简略,举例来说,where 提供了一个相似 if 性能的构造。

where((tokens == "h") | (tokens == "l"), tokens, "q")

map 使咱们能够定义本人的操作,例如一个字符串以 int 转换。(用户应审慎应用能够应用的简略神经网络计算的操作)

atoi = tokens.map(lambda x: ord(x) - ord('0'))
atoi.input("31234")

函数 (functions) 能够容易的形容这些 transforms 的级联。举例来说,上面是利用了 where 和 atoi 和加 2 的操作

def atoi(seq=tokens):
    return seq.map(lambda x: ord(x) - ord('0')) 

op = (atoi(where(tokens == "-", "0", tokens)) + 2)
op.input("02-13")

注意力筛选器

到开始利用注意力机制事件就变得开始乏味起来了。这将容许序列间的不同元素进行信息替换。

咱们开始定义 key 和 query 的概念,Keys 和 Queries 能够间接从下面的 transforms 创立。举个例子,如果咱们想要定义一个 key 咱们称作 key

key(tokens)

对于 query 也一样

query(tokens)

标量能够作为 keyquery 应用,他们会播送到根底序列的长度。

query(1)

咱们创立了筛选器来利用 key 和 query 之间的操作。这对应于一个二进制矩阵,批示每个 query 要关注哪个 key。与 Transformers 不同,这个注意力矩阵未退出权重。

eq = (key(tokens) == query(tokens))
eq

一些例子:

  • 选择器的匹配地位偏移 1:
offset = (key(indices) == query(indices - 1))
offset
  • key 早于 query 的选择器:
before = key(indices) < query(indices)
before
  • key 晚于 query 的选择器:
after = key(indices) > query(indices)
after

选择器能够通过布尔操作合并。比方,这个选择器将 before 和 eq 做合并,咱们通过在矩阵中蕴含一对键和值来显示这一点。

before & eq

应用注意力机制

给一个注意力选择器,咱们能够提供一个序列值做聚合操作。咱们通过累加那些选择器选过的真值做聚合。

(请留神:在原始论文中,他们应用一个均匀聚合操作并且展现了一个奇妙的构造,其中均匀聚合可能代表总和计算。RASPy 默认状况下应用累加来使其简单化并防止碎片化。实际上,这意味着 raspy 可能低估了所须要的层数。基于平均值的模型可能须要这个层数的两倍)

留神聚合操作使咱们可能计算直方图之类的性能。

(key(tokens) == query(tokens)).value(1)

视觉上咱们遵循图表构造,Query 在右边,Key 在上边,Value 在上面,输入在左边

一些注意力机制操作甚至不须要用到输出 token。举例来说,去计算序列长度,咱们创立一个 ” select all ” 的注意力筛选器并且给他赋值。

length = (key(1) == query(1)).value(1)
length = length.name("length")
length

这里有更多简单的例子,上面将一步一步展现。(这有点像做采访一样)

咱们想要计算一个序列的相邻值的和,首先咱们向前截断:

WINDOW=3
s1 = (key(indices) >= query(indices - WINDOW + 1))  
s1

而后咱们向后截断:

s2 = (key(indices) <= query(indices))
s2

两者相交:

sel = s1 & s2
sel

最终聚合:

sum2 = sel.value(tokens) 
sum2.input([1,3,2,2,2])

这里有个能够计算累计求和的例子,咱们这里引入一个给 transform 命名的能力来帮忙你调试。

def cumsum(seq=tokens):
    x = (before | (key(indices) == query(indices))).value(seq)
    return x.name("cumsum")
cumsum().input([3, 1, -2, 3, 1])

这个语言反对编译更加简单的 transforms。他同时通过跟踪每一个运算操作计算层。

这里有个 2 层 transform 的例子,第一个对应于计算长度,第二个对应于累积总和。

x = cumsum(length - indices)
x.input([3, 2, 3, 5])

用 transformers 进行编程

应用这个函数库,咱们能够编写实现一个简单工作,Gail Weiss 给过我一个极其挑战的问题来突破这个步骤:咱们能够加载一个增加任意长度数字的 Transformer 吗?

例如:给一个字符串 “19492+23919”, 咱们能够加载正确的输入吗?

如果你想本人尝试,咱们提供了一个 版本 你能够本人试试。

挑战一:抉择一个给定的索引

加载一个在索引 i 处全元素都有值的序列

def index(i, seq=tokens):
    x = (key(indices) == query(i)).value(seq)
    return x.name("index")
index(1)

挑战二:转换

通过 i 地位将所有 token 挪动到右侧。

def shift(i=1, default="_", seq=tokens):
    x = (key(indices) == query(indices-i)).value(seq, default)
    return x.name("shift")
shift(2)

挑战三:最小化

计算序列的最小值。(这一步开始变得艰难,咱们版本用了 2 层注意力机制)

def minimum(seq=tokens):
    sel1 = before & (key(seq) == query(seq))
    sel2 = key(seq) < query(seq)
    less = (sel1 | sel2).value(1)
    x = (key(less) == query(0)).value(seq)
    return x.name("min")
minimum()([5,3,2,5,2])

挑战四:第一索引

计算有 token q 的第一索引 (2 层)

def first(q, seq=tokens):
    return minimum(where(seq == q, indices, 99))
first("l")

挑战五:右对齐

右对齐一个填充序列。例:”ralign().inputs('xyz___') ='—xyz'” (2 层)

def ralign(default="-", sop=tokens):
    c = (key(sop) == query("_")).value(1)
    x = (key(indices + c) == query(indices)).value(sop, default)
    return x.name("ralign")
ralign()("xyz__")

挑战六:拆散

把一个序列在 token “v” 处拆散成两局部而后右对齐 (2 层):

def split(v, i, sop=tokens):

    mid = (key(sop) == query(v)).value(indices)
    if i == 0:
        x = ralign("0", where(indices < mid, sop, "_"))
        return x
    else:
        x = where(indices > mid, sop, "0")
        return x
split("+", 1)("xyz+zyr")
split("+", 0)("xyz+zyr")

挑战七:滑动

将非凡 token “<” 替换为最靠近的 “<” value (2 层):

def slide(match, seq=tokens):
    x = cumsum(match) 
    y = ((key(x) == query(x + 1)) & (key(match) == query(True))).value(seq)
    seq =  where(match, seq, y)
    return seq.name("slide")
slide(tokens != "<").input("xxxh<<<l")

挑战八:减少

你要执行两个数字的增加。这是步骤。

add().input("683+345")
  1. 分成两局部。转制成整形。退出

“683+345”=> [0, 0, 0, 9, 12, 8]

  1. 计算携带条款。三种可能性:1 个携带,0 不携带,< 兴许有携带。

[0, 0, 0, 9, 12, 8] =>“00<100”

  1. 滑动进位系数

“00<100”=> 001100″

  1. 实现加法

这些都是 1 行代码。残缺的零碎是 6 个注意力机制。(只管 Gail 说,如果你足够仔细则能够在 5 个中实现!)。

def add(sop=tokens):
    # 0) Parse and add
    x = atoi(split("+", 0, sop)) + atoi(split("+", 1, sop))
    # 1) Check for carries 
    carry = shift(-1, "0", where(x > 9, "1", where(x == 9, "<", "0")))
    # 2) In parallel, slide carries to their column                                         
    carries = atoi(slide(carry != "<", carry))
    # 3) Add in carries.                                                                                  
    return (x + carries) % 10
add()("683+345")
683 + 345
1028

完满搞定!

参考资料 & 文内链接:

  • 如果你对这个主题感兴趣想理解更多,请查看论文:Thinking Like Transformers
  • 以及理解更多 RASP 语言
  • 如果你对「形式语言和神经网络」(FLaNN) 感兴趣或者有意识感兴趣的人,欢送邀请他们退出咱们的 线上社区!
  • 本篇博文,蕴含库、Notebook 和博文的内容
  • 本博客文章由 Sasha Rush 和 Gail Weiss 独特编写

英文原文:Thinking Like Transformers

译者:innovation64 (李洋)

退出移动版