拯救Transformer推理能力!DeepMind新研究TransNAR:給模型嵌入「算法推理大腦」
如今的NLP領(lǐng)域,已然是Transformer架構(gòu)的天下。
從Bert到GPT,再到Llama、Claude,LLM模型使用Transformer已經(jīng)是再正常不過(guò)的事情。
Transformer的「大一統(tǒng)」局面正是由于其簡(jiǎn)單、高效的架構(gòu),以及在理解自然語(yǔ)言方面無(wú)與倫比的泛化能力。
然而,隨著研究的逐漸深入,Transformer的一個(gè)致命缺陷也逐漸暴露出來(lái)——無(wú)法勝任算法推理任務(wù),尤其是不能進(jìn)行精確、穩(wěn)健的推理。

這嚴(yán)重限制了模型在數(shù)學(xué)、代碼等領(lǐng)域下游任務(wù)的應(yīng)用,近年來(lái)對(duì)Transformer的各種調(diào)優(yōu)、修改似乎也收效甚微。
于是DeepMind的研究人員想到了混合架構(gòu)——將Transformers的語(yǔ)言理解能力與基于圖神經(jīng)網(wǎng)絡(luò)(GNN)的神經(jīng)算法推理器(NAR)的穩(wěn)健性結(jié)合起來(lái),提升其算法推理能力。
他們最近在arxiv上的一篇論文就提出了這個(gè)名為T(mén)ransNAR的架構(gòu),但遺憾的是,目前還沒(méi)有公布源代碼。

論文地址:https://arxiv.org/abs/2406.09308
神經(jīng)算法推理(NAR)由本文作者之一Petar Veleckovic在2021年與人合著的一篇論文中提出,并被接收為Patterns期刊的opinion paper。

論文地址:https://arxiv.org/abs/2105.02761
NAR被稱(chēng)為「構(gòu)建能執(zhí)行算法的神經(jīng)網(wǎng)絡(luò)的藝術(shù)」。作者提出,算法與深度學(xué)習(xí)的本質(zhì)不同,但如果神經(jīng)網(wǎng)絡(luò)能夠更好地模仿算法,它甚至可能具備算法的強(qiáng)泛化性。
更進(jìn)一步,神經(jīng)網(wǎng)絡(luò)若能表示出算法中連續(xù)空間內(nèi)的元素,就會(huì)使已知算法更接近現(xiàn)實(shí)世界的問(wèn)題,提出的解決方案可能超過(guò)人類(lèi)科學(xué)家。

如上圖所示,NAR的整體想法是訓(xùn)練出一個(gè)高維隱空間中的處理器網(wǎng)絡(luò)P(processor network),旨在不斷逼近算法的運(yùn)行結(jié)果A(x)。
但由于算法的輸入和輸出一般是圖、樹(shù)、矩陣等抽象、結(jié)構(gòu)化的形式,這與深度學(xué)習(xí)模型高維、嘈雜且多變的輸入很不兼容,因此還需要訓(xùn)練編碼器f和解碼器g,將抽象形式轉(zhuǎn)換為自然形式。
NAR發(fā)布后,有多項(xiàng)研究證實(shí)了它有同時(shí)執(zhí)行多種算法的能力,也能部署在各種下游任務(wù)中。更重要的是,它的泛化能力似乎遠(yuǎn)遠(yuǎn)優(yōu)于Transformer架構(gòu)。
原則上,NAR可以擴(kuò)展到比訓(xùn)練數(shù)據(jù)的分布大幾個(gè)數(shù)量級(jí)的系統(tǒng)上,有時(shí)這個(gè)數(shù)量級(jí)能達(dá)到1.8萬(wàn)倍。
在使用適當(dāng)?shù)臍w納偏差(inductive biases)時(shí),即使輸入比訓(xùn)練集大6倍,NAR也能在高度復(fù)雜的算法任務(wù)中保持完美的泛化能力。
找到了Transformer和NAR這兩種十分強(qiáng)大且各有所長(zhǎng)的架構(gòu),下面最關(guān)鍵的問(wèn)題就是如何進(jìn)行相應(yīng)的調(diào)整和修改,使這兩個(gè)似乎完全不相容的模型真正實(shí)現(xiàn)溝通和Embedding交換。
TransNAR:用預(yù)訓(xùn)練NAR增強(qiáng)Transformer
如何實(shí)現(xiàn)NAR+Transformer的有效溝通?作者從多模態(tài)LLM中找到了靈感。
多模態(tài)LLM可以同時(shí)接收文本和圖像兩種模態(tài)的輸入,TransNAR也是如此。一邊是算法運(yùn)行需要的圖結(jié)構(gòu),一邊是描述問(wèn)題的自然語(yǔ)言。
作者的設(shè)想是,將預(yù)訓(xùn)練的NAR作為T(mén)ransformer中編碼的調(diào)制器(modulator),二者通過(guò)embedding溝通,同時(shí)借鑒VLM和Flamingo模型中所用的交叉注意算子,融合不同模態(tài)的信息。

TransNAR接受雙重輸入,包括文本形式的算法問(wèn)題規(guī)范(T個(gè)token)及其對(duì)應(yīng)的圖表征(N個(gè)節(jié)點(diǎn)),并輸出問(wèn)題的文本答案。其中輸入的圖表征遵循算法推理基準(zhǔn)CLRS-30的格式。
我們可以假設(shè),編碼完成后,文本輸入存儲(chǔ)在T ∈ R^(T×k)中,圖輸入存儲(chǔ)在G ∈ R^(N×l)中。

TransNAR的前向傳播過(guò)程如下:
首先,我們通過(guò)設(shè)置T^(0) = T和G^(0) = G來(lái)正確初始化輸入。
接下來(lái),為了計(jì)算第(t+1)步的表征,文本(token)表征被輸入到Transformer的當(dāng)前層:

其中,Qt,Kt ∈ Rk×d_k,Vt ∈ Rk×k分別是鍵、查詢(xún)和值矩陣的變換,F(xiàn)FN是一個(gè)前饋神經(jīng)網(wǎng)絡(luò)。
以類(lèi)似的方式,圖表征被輸入到NAR層,例如實(shí)現(xiàn)一個(gè)標(biāo)準(zhǔn)的max-MPNN:

其中,ψ,? : Rk × Rk → Rk分別是可學(xué)習(xí)的消息函數(shù)和更新函數(shù),max是逐元素最大值聚合。
需要注意的是,方程2僅簡(jiǎn)要提供了節(jié)點(diǎn)之間的成對(duì)交互——實(shí)際上,這里的NAR是一個(gè)Triplet-GMPNN,它還包含三元組交互和一個(gè)門(mén)控機(jī)制。
此外,還需注意,NAR的可學(xué)習(xí)部分沒(méi)有時(shí)間步索引——每一步都應(yīng)用相同的共享函數(shù)。這很好地契合了圖算法計(jì)算的迭代和重復(fù)性質(zhì)。
一旦兩個(gè)流都準(zhǔn)備好它們的表征Θt+1和Gt+1,圖中的節(jié)點(diǎn)嵌入將對(duì)Transformer的token嵌入進(jìn)行條件設(shè)置,從而產(chǎn)生Transformer流中TransNAR塊的最終結(jié)果:

其中,Qt×,Kt× ∈ Rk×d_k, Vtx ∈ Rk×k分別是交叉注意力的鍵、查詢(xún)和值變換。在結(jié)束這一層之前,對(duì)Gt+1不進(jìn)行額外的變換。
這個(gè)過(guò)程會(huì)一直重復(fù),直到最后的第Nl層,在這一層中,從TN_l讀取最終的文本輸出。
最終輸出通過(guò)最后一層生成的預(yù)測(cè)頭轉(zhuǎn)換為token logits,并通過(guò)標(biāo)準(zhǔn)的下一個(gè)token預(yù)測(cè)來(lái)監(jiān)督訓(xùn)練。
在開(kāi)始TransNAR微調(diào)之前,首先預(yù)訓(xùn)練NAR,使其能夠穩(wěn)健地執(zhí)行CLRS-30覆蓋的三十個(gè)算法。這種方法已知可以在圖空間中實(shí)現(xiàn)高達(dá)4倍輸入規(guī)模的分布外泛化。
在微調(diào)過(guò)程中,NAR的參數(shù)通常保持凍結(jié)狀態(tài),因?yàn)轭~外的梯度會(huì)削弱模型的原有穩(wěn)健性特性。同樣的原因,圖嵌入不會(huì)執(zhí)行交叉注意力。
LLM本身可以在大規(guī)模數(shù)據(jù)集上進(jìn)行預(yù)訓(xùn)練,以建立其一般語(yǔ)言先驗(yàn),即使在開(kāi)始時(shí)隨機(jī)初始化LM,也能獲得相同的實(shí)驗(yàn)結(jié)果。
實(shí)驗(yàn)設(shè)置
在實(shí)驗(yàn)中,作者展示了TransNAR為大語(yǔ)言模型架構(gòu)中的分布外推理帶來(lái)的顯著優(yōu)勢(shì)。
Transformer架構(gòu)和初始化
論文使用Chinchilla家族的一個(gè)decoder-only架構(gòu)、6層的Transformer模型,首先在MassiveText上進(jìn)行了預(yù)訓(xùn)練,參數(shù)量有70M,上下文大小為2048。
為了探究初始化設(shè)置的影響,作者設(shè)計(jì)了兩個(gè)變體進(jìn)行消融實(shí)驗(yàn)。
第一個(gè)變體中,Transformer權(quán)重用預(yù)訓(xùn)練的結(jié)果初始化,模擬微調(diào)場(chǎng)景;第二個(gè)變體則是完全隨機(jī)的初始化。這兩個(gè)模型分別被標(biāo)記為「預(yù)訓(xùn)練」和「未訓(xùn)練」。
隨機(jī)位置編碼
之前DeepMind的一篇論文論證過(guò),隨機(jī)位置編碼可以增強(qiáng)Transformer的長(zhǎng)度泛化與推理穩(wěn)健性。

論文地址:https://arxiv.org/abs/2305.16843
作者也提到,隨機(jī)位置嵌入確實(shí)在基線(xiàn)模型和TransNAR上都帶來(lái)了顯著增益,因此本文中的所有實(shí)驗(yàn)也都使用隨機(jī)位置嵌入。
預(yù)訓(xùn)練NAR
論文使用CLRS-30基準(zhǔn)中的問(wèn)題預(yù)訓(xùn)練了一個(gè)多任務(wù)、基于MPNN的NAR,輸入問(wèn)題規(guī)模最多達(dá)16個(gè)。
由于CLRS-30的標(biāo)準(zhǔn)圖結(jié)構(gòu)表達(dá),這樣訓(xùn)練出來(lái)的NAR有很強(qiáng)的分布外(OOD)泛化能力,有時(shí)在4倍大小的圖上仍保持競(jìng)爭(zhēng)力,這種豐富的知識(shí)表達(dá)正是文本模型可資利用的。
結(jié)合節(jié)點(diǎn)和邊緣的跨注意力貢獻(xiàn)
在上述的算法描述中,我們將NAR模型的圖輸入限于N個(gè)節(jié)點(diǎn),但作者注意到了之前的研究曾嘗試過(guò),同時(shí)對(duì)圖的節(jié)點(diǎn)和邊生成隱變量表達(dá),也許可以添加有用的互補(bǔ)信息。
于是實(shí)驗(yàn)中引入圖中邊的特征E(t) ∈ RN×N×k,并再次應(yīng)用公式3讓?duì)?t)對(duì)E(t)進(jìn)行交叉注意力。
作者也嘗試其他方法,希望將E(t)和G(t)結(jié)合起來(lái),比如拼接后加線(xiàn)性層組合、向量求和、2層MLP,或者用Gram-Schmidt過(guò)程使二者的貢獻(xiàn)正交化,但這些都沒(méi)有給原始方法帶來(lái)提升。
數(shù)據(jù)集
訓(xùn)練數(shù)據(jù)使用CLRS-Text基準(zhǔn),即CLRS-30基準(zhǔn)的文本版本,以確定性的方式直接從基于圖的CLRS-30中派生,因此這兩個(gè)數(shù)據(jù)集傳達(dá)的是完全相同的信息。
表1展示了該數(shù)據(jù)集的幾個(gè)樣本,以及它們的輸入大小和token數(shù)量。
由于語(yǔ)言模型上下文長(zhǎng)度的限制,實(shí)驗(yàn)選擇用規(guī)模為4、8、12的問(wèn)題訓(xùn)練,并在規(guī)模為110、12、14的問(wèn)題上評(píng)估。
值得注意的是,與當(dāng)前的評(píng)估環(huán)境相比,CLRS-Text是對(duì)LM最具挑戰(zhàn)性的長(zhǎng)程推理任務(wù)之一——相比小學(xué)數(shù)學(xué),復(fù)雜度顯著提高。
CLRS-Text的挑戰(zhàn)性主要源于它允許顯式控制分布外泛化。然而,每個(gè)問(wèn)題都有清晰的多項(xiàng)式時(shí)間解法,這意味當(dāng)今典型LLM的參數(shù)量應(yīng)該足以解決這些問(wèn)題。
該數(shù)據(jù)集每種算法的每種輸入規(guī)模包含一萬(wàn)個(gè)樣本,總共240萬(wàn)個(gè)數(shù)據(jù)點(diǎn),其中70%用于訓(xùn)練、30%用于驗(yàn)證。

訓(xùn)練細(xì)節(jié)
實(shí)驗(yàn)將batch大小設(shè)置為256訓(xùn)練了7個(gè)epoch,并使用Adam優(yōu)化器,學(xué)習(xí)率為10-4。
如前所述,在所有Chinchilla Transformer的旋轉(zhuǎn)位置編碼(RoPE)之上應(yīng)用隨機(jī)位置編碼,最大長(zhǎng)度為8192,且訓(xùn)練期間保持NAR凍結(jié)。
評(píng)估指標(biāo)
作者提出,合適的評(píng)估指標(biāo)應(yīng)該反映模型在特定樣本上失敗的原因,且需要度量型輸出與正確答案的接近程度。因此,使用精確字符串匹配來(lái)計(jì)算模型準(zhǔn)確性是絕對(duì)不可行的。
論文選擇的性能指標(biāo)包括以下三個(gè):
1. 形狀分?jǐn)?shù):一個(gè)二元指標(biāo),用于判斷輸出是否具有正確的形狀。例如,在排序任務(wù)中,輸出應(yīng)與輸入有完全相同的元素?cái)?shù)量?;蛘?,如果輸出是一個(gè)矩陣,我們需要確保其形狀與輸入和任務(wù)一致。
2. 解析分?jǐn)?shù):一個(gè)二元指標(biāo),用于判斷輸出是否不含任何非法字符。例如,在對(duì)數(shù)字列表進(jìn)行排序的任務(wù)中,輸出不應(yīng)包含任何字母。
3. CLRS分?jǐn)?shù):輸出中與真實(shí)答案匹配的元素百分比,也常用于CLRS-30測(cè)試。形狀分?jǐn)?shù)為0時(shí),CLRS分?jǐn)?shù)也會(huì)自動(dòng)置零。
這種多方面的指標(biāo)設(shè)計(jì)能夠捕捉到LLM在文本上進(jìn)行推理任務(wù)的各種失敗模式。
比如在某個(gè)問(wèn)題規(guī)模上過(guò)度專(zhuān)門(mén)化訓(xùn)練(導(dǎo)致輸出的形狀不正確)、無(wú)法處理看不見(jiàn)的數(shù)字組合(導(dǎo)致解析錯(cuò)誤),由于推理錯(cuò)誤造成的答案不一致則由CLRS分?jǐn)?shù)反映。
結(jié)果
實(shí)驗(yàn)結(jié)果顯示,TransNAR整體上顯著優(yōu)于Transformer模型,在動(dòng)態(tài)規(guī)劃、幾何、圖、貪心算法、排序、字符串等任務(wù)上的OOD推理能力都有大幅提升。

并且在大多數(shù)單個(gè)算法上,無(wú)論是在分布內(nèi)還是分布外都表現(xiàn)更佳。
特別值得注意的是,這種方法不僅增強(qiáng)了Transformer原有的OOD泛化能力,還激發(fā)了一些模型先前完全不具備的能力。
比如Graham掃描(graham_scan)、最長(zhǎng)公子串長(zhǎng)度(lcs_length)、強(qiáng)連通分量(scc)這些經(jīng)典問(wèn)題中,基線(xiàn)模型得分為零或接近零,但TransNAR卻實(shí)現(xiàn)了突破。

分析形狀分?jǐn)?shù)可以進(jìn)一步解釋?zhuān)瑸槭裁碩ransNAR表現(xiàn)如此出色。

首先,回顧一下,如果形狀不匹配,CLRS得分必然為零。
從形狀得分來(lái)看,將Transformer的輸出建立在NAR嵌入基礎(chǔ)上顯著提高了答案中形狀正確的比例——這表明TransNAR緩解了一種特定的LLM故障模式。
此外,通過(guò)對(duì)比「預(yù)訓(xùn)練」和「未訓(xùn)練」兩種初始化方式的分?jǐn)?shù),可以看到模型較好的穩(wěn)定性和可用性。在隨機(jī)初始化時(shí),也能訓(xùn)練到與微調(diào)相當(dāng)?shù)乃疁?zhǔn)。
然而,在一些算法中,TransNAR仍未能超越基線(xiàn),且在分布內(nèi)和分布外都是如此。
這些算法包括二分搜索、尋找最大子數(shù)組、最小值和快速選擇等,都涉及在輸入列表中按照索引搜索特定元素。
這暗示了TransNAR的一種故障模式:模型無(wú)法泛化到訓(xùn)練數(shù)據(jù)中未見(jiàn)過(guò)的新索引邊界。因此,使用索引提示或許是一條有前景的改進(jìn)途徑。
另一種可能的解釋是,NAR最終計(jì)算出的隱藏狀態(tài)難以在交叉注意力層以可泛化的方式被解碼。如果原因在此,解決途徑可以是增加交叉注意力的容量,或者采用漸進(jìn)式解碼。
此外,TransNAR在架構(gòu)上有一個(gè)本質(zhì)的局限性,就是必需一個(gè)能得出ground truth的模擬器或者數(shù)據(jù)標(biāo)簽,用于將輸入的文本轉(zhuǎn)換為圖結(jié)構(gòu),再作為模型輸入。
但是作者強(qiáng)調(diào),TransNAR的概念對(duì)于未來(lái)研究是有借鑒意義的??梢钥紤]將這種混合架構(gòu)的想法移植到單模態(tài)LLM,或者將TransNAR訓(xùn)練后獲得的知識(shí)提煉出來(lái)注入到普通的Transformer中。




































