不是RNN的鍋!清華團隊深入分析長上下文建模中的狀態(tài)崩潰,Mamba作者點贊
與Transformer相比,RNN模型的一大優(yōu)勢是應(yīng)對長序列的能力。
比如Mamba,內(nèi)部狀態(tài)大小始終保持不變,計算隨序列長度線性增長,吃得多,消化快。
理論雖如此,但實際情況卻是,目前的這些RNN模型在長上下文中的有效性并不能令人滿意。
為啥會這樣?空有效率但實際上能力不行?
近日,來自清華的研究團隊對此進行了深入的實驗研究:
論文地址:https://arxiv.org/pdf/2410.07145v1
文章表明,Mamba這類RNN模型在長上下文中主要面臨兩個問題:
一是無法推斷比訓(xùn)練長度更長的輸入,原因是較短的訓(xùn)練數(shù)據(jù)導(dǎo)致了循環(huán)狀態(tài)過擬合;
二是內(nèi)存容量的上限,由于模型無法有效遺忘很久以前的信息,導(dǎo)致新的信息存不進來了。
——這倆問題明顯不是RNN的鍋。
而經(jīng)過研究人員的對癥下藥,Mamba-2(370M)在256K上下文長度上達到了近乎完美的密鑰檢索精度。
所以結(jié)論就是,Mamba yes!「RNN神教」前景一片光明!
對此,Mamba的作者Albert Gu點贊轉(zhuǎn)發(fā),并發(fā)表了相當(dāng)詳細的見解:
「這是一篇很棒的論文(名字也很棒)—— 關(guān)于狀態(tài)空間模型(SSM)的狀態(tài)容量和長上下文能力的巧妙實驗?!?/span>
令人驚訝的是,對于每個狀態(tài)大小 M,當(dāng)訓(xùn)練上下文長度達到或超過某個臨界值 K 時,都會出現(xiàn)一個轉(zhuǎn)折點,在這個點上 SSM 就能夠穩(wěn)健地實現(xiàn)長度泛化。
這是因為當(dāng)上下文長度小于 K 時,循環(huán)狀態(tài)沒有被充分利用,導(dǎo)致模型在訓(xùn)練期間會「過擬合」。但一旦通過足夠長序列的訓(xùn)練使模型的狀態(tài)容量得到充分利用,它就會自動獲得泛化能力。
值得注意的是,K 與 M 竟然呈線性關(guān)系!—— 這表明每個 token 可能存在某種固有的信息含量(即存在一個值 B,使得上下文中的每個 token 對應(yīng) B 字節(jié)的循環(huán)狀態(tài))。這個 B 值可能是由模型架構(gòu)決定的?
「反過來說,過分擔(dān)心循環(huán)模型的長度泛化問題可能是一個誤區(qū)。我們無需設(shè)計新機制或特殊的緩解措施:只需要在更長的序列上訓(xùn)練(因為是線性時間復(fù)雜度,所以不會增加計算開銷?。?,就能獲得更好的泛化效果?!?/span>
最后,Albert Gu用一句話總結(jié):要讓你的Mamba吃得飽飽的,它就能發(fā)揮出最佳狀態(tài)!
喂飽你的Mamba
先來復(fù)習(xí)一下基礎(chǔ)知識。
本文以Mamba2作為主要研究對象,內(nèi)部的計算表示為下圖中的并行結(jié)構(gòu):
整體的輸入輸出遵循SSM(也即RNN)的形式:
而把上圖中模塊內(nèi)部所有的計算寫出來,就是下面這一坨公式:
之前提到的兩個問題,核心在于模型的內(nèi)部狀態(tài),也就是ht的表現(xiàn)。
所以下面在探索問題和解決方案時,咱們可以重點關(guān)注這些公式中,與ht計算相關(guān)的參數(shù)。
之前有研究表明,當(dāng)上下文長度超過其訓(xùn)練長度時,Mamba-1和RWKV-4的性能會嚴重下降。
順著這個思路,研究人員在兩個方向上進行了實驗分析:狀態(tài)崩潰(STATE COLLAPSE)和容量上限(STATE CAPACITY)。
狀態(tài)崩潰
狀態(tài)崩潰(SC)指的是,RNN模型在輸入上表現(xiàn)出異常行為的時間比訓(xùn)練期間看到的時間更長的現(xiàn)象。
上圖展示了Mamba-2和RWKV-6在訓(xùn)練長度之外的語言建模損失。為了可控性和合成任意長度的提示,這個損失是在僅由「\n」字符組成的提示上計算的(稱為「newlines」提示)。
結(jié)果表明,當(dāng)上下文長度遠大于其訓(xùn)練長度時,兩個RNN的性能都會嚴重下降,最后就跟瞎猜差不多了。
語言建??赡軣o法反映下游能力,上圖給出了Mamba-2(在8K上下文窗口上訓(xùn)練)在密鑰檢索任務(wù)上的評估結(jié)果。
我們可以發(fā)現(xiàn),Mamba-2在8K上下文中具有近乎完美的檢索準確性,但在序列長度超過16K后就沒法看了,無論模型參數(shù)量大小。
從上面的公式來看,這種結(jié)果可能出人意料,因為內(nèi)部狀態(tài)ht的更新應(yīng)該具有穩(wěn)定的指數(shù)內(nèi)存衰減,即對于最后k個token具有良好的檢索準確性。
問題出在哪里?
由于遞歸狀態(tài)的維度不會隨時間而變化,因此狀態(tài)崩潰期間行為的急劇變化一定是狀態(tài)值變化的結(jié)果。
作者對Mamba-2 370M中每一層的遞歸狀態(tài)進行了統(tǒng)計,發(fā)現(xiàn)當(dāng)上下文長度超過訓(xùn)練長度時,一些頭部的平均值和方差會急劇變化:
圖5顯示了模型第38層第2個頭的狀態(tài),在t=20K時方差爆炸。從中可以發(fā)現(xiàn)這種方差爆炸在很大程度上可以歸因于少數(shù)異常通道,其余大多數(shù)通道則相對穩(wěn)定。
分析一下公式,與ht計算有關(guān)的?t、Bt和xt:
如上圖所示,雖然三者都是輸入的函數(shù),但xt相對穩(wěn)定,而Bt比?t更早發(fā)生爆炸,進一步探索還能發(fā)現(xiàn)生成?t和Bt的卷積權(quán)重明顯更大。
作者認為,產(chǎn)生SC的原因是,對于訓(xùn)練長度來說,狀態(tài)容量過大,模型能夠?qū)崿F(xiàn)強大的語言建模性能,而無需學(xué)習(xí)如何忘記。
上圖顯示了第一個token在不同時間步的內(nèi)存強度,作者發(fā)現(xiàn)爆炸的頭(第38層的第2、4、7個頭)強烈傾向于在訓(xùn)練長度內(nèi)保留所有信息,在t=8K時內(nèi)存強度超過0.8。
解決方案
為了緩解SC,使模型沿序列長度更好地泛化,作者提出了3種解決方案,總的思想是修改狀態(tài)的update規(guī)則來避免其溢出。
Method 1: Forget More and Remember Less
通過增加狀態(tài)衰減量(忘記更多)或減少輸入信息的數(shù)量(記住更少)來減少SC,作者選擇干預(yù)Bt和αt(分別控制輸入強度和內(nèi)存衰減強度)。
Method 2: State Normalization
在每次更新后對狀態(tài)進行歸一化,以確保狀態(tài)的范數(shù)始終低于閾值:
PS:這種方式會將模型轉(zhuǎn)換為非線性RNN,無法以與原始模型相同的方式并行化,預(yù)填充速度要慢得多。
Method 3: Sliding Window by State Difference
利用狀態(tài)ht可以寫為加權(quán)和的形式,來模擬滑動窗口機制,無需在每一步都從窗口的開頭重新處理。
此方法適用于所有可以寫成加權(quán)和的RNN,包括RWKV 5和6、RetNet、GLA等。盡管會使生成的計算和內(nèi)存成本翻倍,但仍然是一個可以接受的權(quán)衡,因為RNN的生成成本比Transformer低很多。
以上3個是不需要訓(xùn)練的方案,而基于SC是由狀態(tài)參數(shù)過擬合引起的假設(shè),我們也可以嘗試使用超過狀態(tài)容量的序列長度來訓(xùn)練模型。
容量上限
根據(jù)以上的討論,當(dāng)且僅當(dāng)訓(xùn)練長度包含的信息少于狀態(tài)容量時,才會發(fā)生SC,所以我們可以通過實驗間接估計模型的狀態(tài)容量。
研究人員訓(xùn)練了多個具有不同狀態(tài)大小和訓(xùn)練長度的Mamba-2,并將SC未發(fā)生的最小訓(xùn)練長度視為狀態(tài)容量。
實驗數(shù)據(jù)選擇RedPajama-V2,一個從CommonCrawl中提取的30T token的開放數(shù)據(jù)集,進行去重以確保數(shù)據(jù)質(zhì)量。
在評估過程中,對長度超過16K token的文檔進行抽樣,如果不夠長,則對其進行拼接。
研究人員試驗了具有不同狀態(tài)大小的模型配置,包括來自Mamba-2官方checkpoint的三個預(yù)訓(xùn)練模型,大小分別為130M、370M和780M,另外3個模型(36M、47M、85M)則從頭開始訓(xùn)練。
實驗結(jié)果
上圖展示了在Mamba-2 780M上無訓(xùn)練長度泛化方法的結(jié)果。我們可以看到,雖然LongMamba大大提高了模型的長度泛化性(3倍以上),但它在較短的序列上會導(dǎo)致明顯更大的困惑度,并且仍然不可避免地表現(xiàn)出SC。
相比之下,本文的所有的方法都成功地抑制了SC,使模型能夠泛化到超過64K個token。
三種方案中,狀態(tài)歸一化在較短序列上的性能大大低于其他方法,這可能是因為歸一化折疊狀態(tài)會改變heads之間的規(guī)范比率,破壞了學(xué)習(xí)機制。
上圖顯示了Mamba-2在語言建模和密鑰檢索方面的狀態(tài)容量。兩個圖中最右邊的數(shù)據(jù)點對應(yīng)于Mamba-2 370M。
左邊的圖可以擬合出一個線性關(guān)系,而右邊的圖則表明Mamba-2在密鑰檢索方面的容量與狀態(tài)大小呈指數(shù)級關(guān)系。
這是因為上下文中的信息量不會隨著其長度的增加而增加。換句話說,模型存儲了恒定數(shù)量的信息,而狀態(tài)的組合數(shù)量隨著元素數(shù)量呈指數(shù)增長。