AI解數(shù)學(xué)題只靠最后一個(gè)token
大語言模型在解心算題時(shí),只依賴最后一個(gè)token?

最近,來自加州大學(xué)圣克魯茲分校、喬治·梅森大學(xué)和Datadog的研究人員發(fā)現(xiàn):在心算任務(wù)中,幾乎所有實(shí)際的數(shù)學(xué)計(jì)算都集中在序列的最后一個(gè)token上完成,而不是分散在所有token中。
這意味著,相較于在Transformer和多層感知機(jī)(MLP)中常見的全局信息訪問——即每個(gè)token在預(yù)測(cè)時(shí)都能查詢并利用整個(gè)上文信息——在諸如心算這樣的特定任務(wù)中,全局訪問其實(shí)并不是必需的。
這是怎么一回事?
心算只要最后一個(gè)token?!
總的來說,研究人員采用了上下文感知平均消融(Context-Aware Mean Ablation, CAMA)和基于注意力的窺視(attention-based peeking)技術(shù)對(duì)Llama-3-8B等Transformer架構(gòu)的模型進(jìn)行了一系列的消融實(shí)驗(yàn)。
這些實(shí)驗(yàn)通過系統(tǒng)性地移除或改變模型的一部分,探究能讓模型依然表現(xiàn)良好的“最少計(jì)算量”。
在這一過程中,研究人員發(fā)現(xiàn)模型內(nèi)部會(huì)形成一個(gè)稀疏子圖(sparse subgraph)——他們把它稱為“人人為我”(All-for-One, AF1)。
這個(gè)子圖通過最少的計(jì)算層和最有限的信息傳遞,讓模型高效完成運(yùn)算。

在“人人為我”中,輸入Transformer前幾層(L_wait)的token并沒有做跟“自己數(shù)值”相關(guān)的計(jì)算,而是“等待”,并主要承擔(dān)一些通用的準(zhǔn)備工作(比如識(shí)別token、結(jié)構(gòu)編碼、預(yù)測(cè)下一步所需的通用表示)。
然后,在中間的兩層(L_transfer)里,它們就將信息傳遞給最后一個(gè)token。
之后,最后一個(gè)token獨(dú)自完成計(jì)算并給出答案。
這一過程表明,模型內(nèi)部將任務(wù)通用型計(jì)算(如 token 識(shí)別、數(shù)值與結(jié)構(gòu)編碼)與輸入特定型計(jì)算(如實(shí)際算術(shù)運(yùn)算)是分開的。
(注:這篇研究聚焦于心算任務(wù),即涉及兩個(gè)或三個(gè)操作數(shù)的算術(shù)問題(例如42+20?15),這些問題可以通過單個(gè)token的輸出解決,而無需模型進(jìn)行顯式的鏈?zhǔn)剿季S推理。)
接下來,我們具體來看。
眾所周知,大語言模型在許多計(jì)算任務(wù)上表現(xiàn)出色,而其中一個(gè)重要原因是其采用了Transformer架構(gòu)。
與RNN不同,Transformer允許任意token通過自注意力機(jī)制立即訪問所有先前的token以傳遞信息,并使每個(gè)token能夠通過多層感知機(jī)(MLP)并行執(zhí)行各自的獨(dú)立計(jì)算。
但即便如此,模型內(nèi)部的信息流和計(jì)算過程仍然是不透明的。
因此,為了揭開大語言模型的“黑箱”,研究人員采用了以下三個(gè)步驟來進(jìn)行探索。
首先,在模型的初始層抑制token針對(duì)特定輸入的計(jì)算。
研究人員發(fā)現(xiàn),在傳統(tǒng)Transformer的每一層中,token都能訪問所有之前的token,但對(duì)于簡(jiǎn)單的心算任務(wù),每個(gè)token可能未必從一開始就要獲得全局信息。
由此,研究人員引入了等待期(L_wait):讓在前面的L_wait層中的token獨(dú)立計(jì)算,只執(zhí)行任務(wù)通用操作(如理解數(shù)字、識(shí)別算術(shù)結(jié)構(gòu)),而不訪問其他token。

為了實(shí)現(xiàn)這一點(diǎn),他們使用了上下文感知平均消融(CAMA)。
CAMA的作用是屏蔽掉token之間的輸入特定信息,同時(shí)保留每個(gè)token的普遍計(jì)算能力,使模型能夠在不依賴具體輸入的情況下完成基礎(chǔ)準(zhǔn)備工作。

接下來,在少數(shù)幾層中限制跨token位置的信息傳遞路徑。只讓最后token在L_transfer層中訪問所有token,其余層只關(guān)注自己。
最后,在剩余的層中強(qiáng)制所有計(jì)算都在最后一個(gè)token上發(fā)生。
由于CAMA只能從第一層開始,因此,研究人員引入了一種可以在任意層控制信息訪問的方法——基于注意力的窺視(ABP)。
它通過修改注意力掩碼(attention mask),精確指定每個(gè)“查詢”(query)token可以關(guān)注哪些“鍵”(key)。
在這篇論文中,研究人員主要使用了以下兩種模式:
- 完全窺探 (Full-peeking): token可以關(guān)注所有在它之前的token,這是標(biāo)準(zhǔn)的因果注意力。在AF1的傳遞階段,最后一個(gè)token使用此模式來收集信息。
 - 自我窺探 (Self-peeking): token只能關(guān)注它自己,在傳遞和計(jì)算階段,所有非末尾的token都使用此模式;在計(jì)算階段,最后一個(gè)token也切換到此模式。
 
實(shí)驗(yàn)驗(yàn)證
在完成方法和操作流程的構(gòu)建后,研究者進(jìn)行了一系列實(shí)驗(yàn)來發(fā)現(xiàn)、驗(yàn)證和分析AF1子圖。這里主要涉及到Llama-3-8B和Llama-3.1-8B,以及在Pythia和GPT-J模型上的驗(yàn)證。
首先,通過三階段消融與窺視實(shí)驗(yàn),研究人員發(fā)現(xiàn)Llama-3-8B在A+B+C任務(wù)中只需前14層做任務(wù)通用計(jì)算(CAMA 層),然后通過2層信息傳輸讓最后的token獲取全局信息,剩余層僅進(jìn)行最后token的自計(jì)算。

這個(gè)幾乎保留全部性能的子圖被命名為AF1_llama。
接下來,研究人員又進(jìn)一步驗(yàn)證了AF1_llama在Llama-3-8B和Llama-3.1-8B上的表現(xiàn)。
實(shí)驗(yàn)表明,AF1_llama在八個(gè)任務(wù)中總體表現(xiàn)出高忠實(shí)度。

更進(jìn)一步,實(shí)驗(yàn)進(jìn)一步驗(yàn)證了第15和16層的信息傳輸在Llama-3-8B中的重要性。
研究表明,僅少數(shù)注意力頭對(duì)算術(shù)計(jì)算關(guān)鍵,即使移除近60個(gè)頭部,模型仍能保持約95%的準(zhǔn)確率,表明大部分注意力頭冗余,而關(guān)鍵頭集中在少數(shù)層。

此外,為了探究AF1_llama是否可以在Llama-3-8B上泛化到表示A+B和A?B運(yùn)算的其他算術(shù)形式,研究進(jìn)一步將口頭描述運(yùn)算以及將運(yùn)算嵌入到應(yīng)用題或Python代碼中。
實(shí)驗(yàn)表明,AF1_llama在不包含額外語義上下文的直接算術(shù)任務(wù)中仍保持了相當(dāng)高的準(zhǔn)確率。
然而,它在需要語義理解的任務(wù)上,如應(yīng)用題和Python代碼,完全失敗了,這表明它需要額外的組件來處理其他能力,比如理解自然語言或Python程序輸入。

最后,研究人員在Pythia和GPT-J中也發(fā)現(xiàn)了類似AF1的子圖,但與Llama不同,這些模型的等待期更短(L_wait ≈ 9–11)、信息傳輸層更長(zhǎng),且性能邊界不如Llama清晰。
盡管忠實(shí)度普遍低于Llama,但對(duì)二元運(yùn)算任務(wù)的子圖仍能恢復(fù)超過一半的原始模型準(zhǔn)確率。

總體而言,這項(xiàng)工作為大語言模型中的算術(shù)推理和跨token計(jì)算的機(jī)制理解做出了貢獻(xiàn)。此外,它通過CAMA和ABP提供了方法論上的創(chuàng)新,可服務(wù)于算術(shù)任務(wù)之外的更廣泛應(yīng)用。















 
 
 












 
 
 
 