Transformer 模型是 AI 系統(tǒng)的基礎(chǔ)。已經(jīng)有了數(shù)不清的關(guān)于 "Transformer 如何工作" 的核心結(jié)構(gòu)圖表。
但是這些圖表沒有提供任何直觀的計(jì)算該模型的框架表示。當(dāng)研究者對(duì)于 Transformer 如何工作抱有興趣時(shí),直觀的獲取他運(yùn)行的機(jī)制變得十分有用。
Thinking Like Transformers 這篇論文中提出了 transformer 類的計(jì)算框架,這個(gè)框架直接計(jì)算和模仿 Transformer 計(jì)算。使用 RASP 編程語(yǔ)言,使每個(gè)程序編譯成一個(gè)特殊的 Transformer。
在這篇博客中,我用 python 復(fù)現(xiàn)了 RASP 的變體 (RASPy)。該語(yǔ)言大致與原始版本相當(dāng),但是多了一些我認(rèn)為很有趣的變化。通過這些語(yǔ)言,作者 Gail Weiss 的工作,提供了一套具有挑戰(zhàn)性的有趣且正確的方式可以幫助了解其工作原理。
!pip?install?git+https://github.com/srush/RASPy
在說起語(yǔ)言本身前,讓我們先看一個(gè)例子,看看用 Transformers 編碼是什么樣的。這是一些計(jì)算翻轉(zhuǎn)的代碼,即反向輸入序列。代碼本身用兩個(gè) Transformer 層應(yīng)用 attention 和數(shù)學(xué)計(jì)算到達(dá)這個(gè)結(jié)果。
def?flip(): ????length?=?(key(1)?==?query(1)).value(1) ????flip?=?(key(length?-?indices?-?1)?==?query(indices)).value(tokens) ????return?flip flip()
?
?
本文內(nèi)容目錄
部分一:Transformers 作為代碼
部分二:用 Transformers 編寫程序
Transformers 作為代碼
我們的目標(biāo)是定義一套計(jì)算形式來最小化 Transformers 的表達(dá)。我們將通過類比,描述每個(gè)語(yǔ)言構(gòu)造及其在 Transformers 中的對(duì)應(yīng)。(正式語(yǔ)言規(guī)范請(qǐng)?jiān)诒疚牡撞坎榭凑撐娜逆溄?。
這個(gè)語(yǔ)言的核心單元是將一個(gè)序列轉(zhuǎn)換成相同長(zhǎng)度的另一個(gè)序列的序列操作。我后面將其稱之為 transforms。
輸入
在一個(gè) Transformer 中,基本層是一個(gè)模型的前饋輸入。這個(gè)輸入通常包含原始的 token 和位置信息。
在代碼中,tokens 的特征表示最簡(jiǎn)單的 transform,它返回經(jīng)過模型的 tokens,默認(rèn)輸入序列是 "hello":
tokens
?
?
如果我們想要改變 transform 里的輸入,我們使用輸入方法進(jìn)行傳值。
tokens.input([5,?2,?4,?5,?2,?2])
作為 Transformers,我們不能直接接受這些序列的位置。但是為了模擬位置嵌入,我們可以獲取位置的索引:
indices
sop?=?indices sop.input("goodbye")
經(jīng)過輸入層后,我們到達(dá)了前饋網(wǎng)絡(luò)層。在 Transformer 中,這一步可以對(duì)于序列的每一個(gè)元素獨(dú)立的應(yīng)用數(shù)學(xué)運(yùn)算。
在代碼中,我們通過在 transforms 上計(jì)算表示這一步。在每一個(gè)序列的元素中都會(huì)進(jìn)行獨(dú)立的數(shù)學(xué)運(yùn)算。
tokens?==?"l"
?
?
結(jié)果是一個(gè)新的 transform,一旦重構(gòu)新的輸入就會(huì)按照重構(gòu)方式計(jì)算:
model?=?tokens?*?2?-?1 model.input([1,?2,?3,?5,?2])
?
?
該運(yùn)算可以組合多個(gè) Transforms,舉個(gè)例子,以上述的 token 和 indices 為例,這里可以類別 Transformer 可以跟蹤多個(gè)片段信息:
model?=?tokens?-?5?+?indices model.input([1,?2,?3,?5,?2])
(tokens?==?"l")?|?(indices?==?1)
?
?
我們提供了一些輔助函數(shù)讓寫 transforms 變得更簡(jiǎn)單,舉例來說,where 提供了一個(gè)類似 if 功能的結(jié)構(gòu)。
where((tokens?==?"h")?|?(tokens?==?"l"),?tokens,?"q")
map 使我們可以定義自己的操作,例如一個(gè)字符串以 int 轉(zhuǎn)換。(用戶應(yīng)謹(jǐn)慎使用可以使用的簡(jiǎn)單神經(jīng)網(wǎng)絡(luò)計(jì)算的操作)
atoi?=?tokens.map(lambda?x:?ord(x)?-?ord('0'))
atoi.input("31234")
函數(shù) (functions) 可以容易的描述這些 transforms 的級(jí)聯(lián)。舉例來說,下面是應(yīng)用了 where 和 atoi 和加 2 的操作
def?atoi(seq=tokens): ????return?seq.map(lambda?x:?ord(x)?-?ord('0'))? op?=?(atoi(where(tokens?==?"-",?"0",?tokens))?+?2) op.input("02-13")
?
注意力篩選器
到開始應(yīng)用注意力機(jī)制事情就變得開始有趣起來了。這將允許序列間的不同元素進(jìn)行信息交換。
我們開始定義 key 和 query 的概念,Keys 和 Queries 可以直接從上面的 transforms 創(chuàng)建。舉個(gè)例子,如果我們想要定義一個(gè) key 我們稱作 key。
key(tokens)
?
?
對(duì)于 query 也一樣
query(tokens)
?
標(biāo)量可以作為 key 或 query 使用,他們會(huì)廣播到基礎(chǔ)序列的長(zhǎng)度。
query(1)
?
我們創(chuàng)建了篩選器來應(yīng)用 key 和 query 之間的操作。這對(duì)應(yīng)于一個(gè)二進(jìn)制矩陣,指示每個(gè) query 要關(guān)注哪個(gè) key。與 Transformers 不同,這個(gè)注意力矩陣未加入權(quán)重。
eq?=?(key(tokens)?==?query(tokens)) eq
一些例子:
選擇器的匹配位置偏移 1:
offset?=?(key(indices)?==?query(indices?-?1)) offset
key 早于 query 的選擇器:
before?=?key(indices)?
key 晚于 query 的選擇器:after?=?key(indices)?>?query(indices) after選擇器可以通過布爾操作合并。比如,這個(gè)選擇器將 before 和 eq 做合并,我們通過在矩陣中包含一對(duì)鍵和值來顯示這一點(diǎn)。
before?&?eq使用注意力機(jī)制
給一個(gè)注意力選擇器,我們可以提供一個(gè)序列值做聚合操作。我們通過累加那些選擇器選過的真值做聚合。
(請(qǐng)注意:在原始論文中,他們使用一個(gè)平均聚合操作并且展示了一個(gè)巧妙的結(jié)構(gòu),其中平均聚合能夠代表總和計(jì)算。RASPy 默認(rèn)情況下使用累加來使其簡(jiǎn)單化并避免碎片化。實(shí)際上,這意味著 raspy 可能低估了所需要的層數(shù)?;谄骄档哪P涂赡苄枰@個(gè)層數(shù)的兩倍)
注意聚合操作使我們能夠計(jì)算直方圖之類的功能。
(key(tokens)?==?query(tokens)).value(1)
視覺上我們遵循圖表結(jié)構(gòu),Query 在左邊,Key 在上邊,Value 在下面,輸出在右邊
一些注意力機(jī)制操作甚至不需要用到輸入 token 。舉例來說,去計(jì)算序列長(zhǎng)度,我們創(chuàng)建一個(gè) " select all " 的注意力篩選器并且給他賦值。
length?=?(key(1)?==?query(1)).value(1) length?=?length.name("length") length
這里有更多復(fù)雜的例子,下面將一步一步展示。(這有點(diǎn)像做采訪一樣)
我們想要計(jì)算一個(gè)序列的相鄰值的和,首先我們向前截?cái)?
WINDOW=3 s1?=?(key(indices)?>=?query(indices?-?WINDOW?+?1))?? s1然后我們向后截?cái)?
s2?=?(key(indices)?<=?query(indices)) s2
兩者相交:sel?=?s1?&?s2 sel
最終聚合:sum2?=?sel.value(tokens)? sum2.input([1,3,2,2,2])
這里有個(gè)可以計(jì)算累計(jì)求和的例子,我們這里引入一個(gè)給 transform 命名的能力來幫助你調(diào)試。
def?cumsum(seq=tokens): ????x?=?(before?|?(key(indices)?==?query(indices))).value(seq) ????return?x.name("cumsum") cumsum().input([3,?1,?-2,?3,?1])
層
這個(gè)語(yǔ)言支持編譯更加復(fù)雜的 transforms。他同時(shí)通過跟蹤每一個(gè)運(yùn)算操作計(jì)算層。
這里有個(gè) 2 層 transform 的例子,第一個(gè)對(duì)應(yīng)于計(jì)算長(zhǎng)度,第二個(gè)對(duì)應(yīng)于累積總和。
x?=?cumsum(length?-?indices) x.input([3,?2,?3,?5])
用 transformers 進(jìn)行編程
使用這個(gè)函數(shù)庫(kù),我們可以編寫完成一個(gè)復(fù)雜任務(wù),Gail Weiss 給過我一個(gè)極其挑戰(zhàn)的問題來打破這個(gè)步驟,我們可以加載一個(gè)添加任意長(zhǎng)度數(shù)字的 Transformer 嗎?
例如:?給一個(gè)字符串 "19492+23919", 我們可以加載正確的輸出嗎?
如果你想自己嘗試,我們提供了一個(gè)版本你可以自己試試:
https://colab.research.google.com/github/srush/raspy/blob/main/Blog.ipynb挑戰(zhàn)一 : 選擇一個(gè)給定的索引
加載一個(gè)在索引 i 處全元素都有值的序列
def?index(i,?seq=tokens): ????x?=?(key(indices)?==?query(i)).value(seq) ????return?x.name("index") index(1)
?
?
挑戰(zhàn)二 :轉(zhuǎn)換
通過 i 位置將所有 token 移動(dòng)到右側(cè)。
def?shift(i=1,?default="_",?seq=tokens): ????x?=?(key(indices)?==?query(indices-i)).value(seq,?default) ????return?x.name("shift") shift(2)
挑戰(zhàn)三 :最小化
計(jì)算序列的最小值。(這一步開始變得困難,我們版本用了 2 層注意力機(jī)制)
def?minimum(seq=tokens): ????sel1?=?before?&?(key(seq)?==?query(seq)) ????sel2?=?key(seq)?
挑戰(zhàn)四:第一索引
計(jì)算有 token q 的第一索引 (2 層)
def?first(q,?seq=tokens): ????return?minimum(where(seq?==?q,?indices,?99)) first("l")
?
挑戰(zhàn)五 :右對(duì)齊
右對(duì)齊一個(gè)填充序列。例:"ralign().inputs('xyz___') ='—xyz'" (2 層)
def?ralign(default="-",?sop=tokens): ????c?=?(key(sop)?==?query("_")).value(1) ????x?=?(key(indices?+?c)?==?query(indices)).value(sop,?default) ????return?x.name("ralign") ralign()("xyz__")
挑戰(zhàn)六:分離
把一個(gè)序列在 token "v" 處分離成兩部分然后右對(duì)齊 (2 層):
def?split(v,?i,?sop=tokens): ????mid?=?(key(sop)?==?query(v)).value(indices) ????if?i?==?0: ????????x?=?ralign("0",?where(indices??mid,?sop,?"0") ????????return?x split("+",?1)("xyz+zyr")
split("+",?0)("xyz+zyr")
?
?
挑戰(zhàn)七:滑動(dòng)
將特殊 token "<" 替換為最接近的 "<" value (2 層):
def?slide(match,?seq=tokens): ????x?=?cumsum(match)? ????y?=?((key(x)?==?query(x?+?1))?&?(key(match)?==?query(True))).value(seq) ????seq?=??where(match,?seq,?y) ????return?seq.name("slide") slide(tokens?!=?"<").input("xxxh<<
?
挑戰(zhàn)八:增加
你要執(zhí)行兩個(gè)數(shù)字的添加。這是步驟。
add().input("683+345")分成兩部分。轉(zhuǎn)制成整形。加入
?
“683+345” => [0, 0, 0, 9, 12, 8]
計(jì)算攜帶條款。三種可能性:1 個(gè)攜帶,0 不攜帶,< 也許有攜帶。
?
[0, 0, 0, 9, 12, 8] => “00<100”
滑動(dòng)進(jìn)位系數(shù)
?
“00<100” => 001100"
完成加法
這些都是 1 行代碼。完整的系統(tǒng)是 6 個(gè)注意力機(jī)制。(盡管 Gail 說,如果你足夠細(xì)心則可以在 5 個(gè)中完成?。?。
def?add(sop=tokens): ????#?0)?Parse?and?add ????x?=?atoi(split("+",?0,?sop))?+?atoi(split("+",?1,?sop)) ????#?1)?Check?for?carries? ????carry?=?shift(-1,?"0",?where(x?>?9,?"1",?where(x?==?9,?"<",?"0"))) ????#?2)?In?parallel,?slide?carries?to?their?column????????????????????????????????????????? ????carries?=?atoi(slide(carry?!=?"<",?carry)) ????#?3)?Add?in?carries.?????????????????????????????????????????????????????????????????????????????????? ????return?(x?+?carries)?%?10 add()("683+345")
683?+?3451028本博客文章由 Sasha Rush 和 Gail Weiss 共同編寫:??
英文原文: Thinking Like Transformers:?
https://srush.github.io/raspy/中文譯者:?innovation64 (李洋)
編輯:黃飛
?
評(píng)論
查看更多