再戰(zhàn)Transformer!原作者帶隊(duì)的Mamba 2來(lái)了,新架構(gòu)訓(xùn)練效率大幅提升
自 2017 年被提出以來(lái),Transformer 已經(jīng)成為 AI 大模型的主流架構(gòu),一直穩(wěn)居語(yǔ)言建模方面 C 位。
但隨著模型規(guī)模的擴(kuò)展和需要處理的序列不斷變長(zhǎng),Transformer 的局限性也逐漸凸顯。一個(gè)很明顯的缺陷是:Transformer 模型中自注意力機(jī)制的計(jì)算量會(huì)隨著上下文長(zhǎng)度的增加呈平方級(jí)增長(zhǎng)。
幾個(gè)月前,Mamba 的出現(xiàn)打破了這一局面,它可以隨上下文長(zhǎng)度的增加實(shí)現(xiàn)線性擴(kuò)展。隨著 Mamba 的發(fā)布,這些狀態(tài)空間模型 (SSM) 在中小型規(guī)模上已經(jīng)實(shí)現(xiàn)了與 Transformers 匹敵,甚至超越 Transformers。
Mamba 的作者只有兩位,一位是卡內(nèi)基梅隆大學(xué)機(jī)器學(xué)習(xí)系助理教授 Albert Gu,另一位是 Together.AI 首席科學(xué)家、普林斯頓大學(xué)計(jì)算機(jī)科學(xué)助理教授 Tri Dao。
Mamba 面世之后的這段時(shí)間里,社區(qū)反應(yīng)熱烈??上У氖?,Mamba 的論文卻慘遭 ICLR 拒稿,讓一眾研究者頗感意外。
僅僅六個(gè)月后,原作者帶隊(duì),更強(qiáng)大的 Mamba 2 正式發(fā)布了。
- 論文地址:???https://arxiv.org/pdf/2405.21060???
- GitHub 地址:???https://github.com/state-spaces/mamba???
- 論文標(biāo)題:Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
總體而言,本文提出了 SSD(state space duality)框架,基于此,研究者設(shè)計(jì)了一個(gè)新的體系架構(gòu) Mamba-2,其核心層是對(duì) Mamba 的選擇性 SSM 的改進(jìn),速度提高了 2-8 倍,同時(shí)在語(yǔ)言建模方面繼續(xù)與 Transformers 競(jìng)爭(zhēng)。
Tri Dao 表示,他們構(gòu)建了一個(gè)豐富的 SSD 理論框架,許多線性注意力變體和 SSM 是等效的,由此產(chǎn)生的模型 Mamba-2 比 Mamba-1 更好、更快。
Mamba-2 的新算法使其能夠利用更大的狀態(tài)維度 (16 → 256),同時(shí)訓(xùn)練速度更快。在需要更大狀態(tài)容量的任務(wù)上,例如 MQAR 任務(wù),它比 Mamba-1 有了顯著的改進(jìn)。
此外研究者還發(fā)現(xiàn),最近新出的混合模型(Jamba、Zamba)增加了一些注意力層來(lái)提高模型質(zhì)量?;谶@些發(fā)現(xiàn),研究者將 4-6 個(gè)注意力層與 Mamba-2 層混合,其表現(xiàn)優(yōu)于 Transformer++ 和純 Mamba-2,因而得出注意力和 SSM 是互補(bǔ)的。
這項(xiàng)研究的貢獻(xiàn)概括為:
本文展示了狀態(tài)空間模型與一類(lèi)稱為半可分矩陣的結(jié)構(gòu)化矩陣族之間的等價(jià)性。這一聯(lián)系是 Mamba-2 框架的核心,揭示了狀態(tài)空間模型的新屬性和算法。
本文顯著改進(jìn)了線性注意力理論,首先通過(guò)張量收縮的語(yǔ)言對(duì)其循環(huán)形式提供了一個(gè)明確的證明,然后將其推廣到一種新的結(jié)構(gòu)化掩碼注意力(SMA)家族。
本文將 SSM(狀態(tài)空間模型)和 SMA(結(jié)構(gòu)化掩碼注意力)聯(lián)系起來(lái),顯示它們有一個(gè)很大的交集,彼此是對(duì)偶的,同時(shí)具有 SSM 式的線性形式和類(lèi)似注意力的二次方形式。本文還證明了任何具有快速循環(huán)形式的核注意方法都是 SSM。
除了內(nèi)在的理論價(jià)值外,研究者所提出的框架為理解和改進(jìn)序列模型開(kāi)辟了廣闊的方向。
在算法層面。所提框架為計(jì)算 SSM 提供了新的高效且易于實(shí)現(xiàn)的算法。本文提出了一種基于半可分離矩陣塊分解的 SSD 算法,該算法利用了 SSM 線性遞推和二次對(duì)偶形式,在所有主要效率軸上獲得了最優(yōu)的權(quán)衡。基于 SSD 的實(shí)現(xiàn)比 Mamba 的優(yōu)化選擇性掃描實(shí)現(xiàn)快 2 到 8 倍,同時(shí)允許使用更大的循環(huán)狀態(tài)大?。ㄊ?Mamba 的 8 倍甚至更高,且?guī)缀醪挥绊懰俣龋SD 與優(yōu)化過(guò)的 softmax 注意力實(shí)現(xiàn)(FlashAttention-2)具有高度競(jìng)爭(zhēng)力,在序列長(zhǎng)度 2k 時(shí)性能相當(dāng),在序列長(zhǎng)度 16K 時(shí)速度快 6 倍。
架構(gòu)設(shè)計(jì)。采用 SSM 等新架構(gòu)的一個(gè)主要障礙是針對(duì) Transformers 量身定制的生態(tài)系統(tǒng),例如用于大規(guī)模訓(xùn)練的硬件高效優(yōu)化和并行技術(shù)。本文框架允許使用已建立的慣例和技術(shù)來(lái)構(gòu)建 SSM 的架構(gòu)設(shè)計(jì)選擇詞匯表,并進(jìn)一步改進(jìn)它們。
本文還對(duì) Mamba 塊做了一些修改,這些修改允許實(shí)現(xiàn)張量并行,其主要思想包括引入分組值注意力 (GVA,grouped-value attention) 頭結(jié)構(gòu)。
將修改后的并行 Mamba 塊與作為內(nèi)部 SSM 層的 SSD 結(jié)合使用,形成了 Mamba-2 架構(gòu)。研究者在與 Mamba 相同的設(shè)置中研究了 Mamba-2 的 Chinchilla 擴(kuò)展法則,發(fā)現(xiàn)它在困惑度和實(shí)際運(yùn)行時(shí)間方面均優(yōu)于 Mamba 和 Transformer++。研究者還在 Pile 數(shù)據(jù)集上訓(xùn)練了一系列 Mamba-2 模型,結(jié)果顯示 Mamba-2 在標(biāo)準(zhǔn)下游評(píng)估中匹配或超過(guò) Mamba 和開(kāi)源的 Transformers。例如,在 Pile 上訓(xùn)練了 3000 億 token 的 2.7B 參數(shù)的 Mamba-2 在性能上超過(guò)了在同一數(shù)據(jù)集上訓(xùn)練的 2.8B 參數(shù)的 Mamba 和 Pythia 以及 6.9B 參數(shù)的 Pythia。
系統(tǒng)優(yōu)化:SSD 框架連接 SSM 和 transformer,允許利用為 transformer 開(kāi)發(fā)的豐富的系統(tǒng)優(yōu)化工作。
SSD 層
Mamba-2 的核心貢獻(xiàn)是新的 SSD(state space dual)層。SSD 層可以被定義為選擇性 SSM 的特例。與 Mamba 相比,Mamba-2 的改動(dòng)會(huì)略微降低表達(dá)能力,但卻顯著提高了訓(xùn)練效率,特別是允許在現(xiàn)代加速器上使用矩陣乘法單元。
SSD 層的對(duì)偶注意力:
除了最新的 SSD 層,研究者也對(duì) Mamba 的神經(jīng)網(wǎng)絡(luò)架構(gòu)做了一些小的改變,Mamba-2 架構(gòu)如下所示。
Mamba-2 在網(wǎng)絡(luò)架構(gòu)上的主要變化是從順序生成變?yōu)椴⑿猩?SSM 參數(shù),并且 Mamba-2 更適合張量并行等擴(kuò)展方法。
通過(guò)提供狀態(tài)空間模型的顯式矩陣變換形式,研究團(tuán)隊(duì)揭示了理解和使用它們的新方法。從計(jì)算的角度來(lái)看,任何計(jì)算狀態(tài)空間模型前向傳播的方法都可以看作是半可分離矩陣上的矩陣乘法算法。半可分離矩陣視角為 SSD 提供了一個(gè)視角,其中雙重模式分別指的是線性時(shí)間半可分離矩陣乘法算法和二次時(shí)間樸素矩陣乘法。
研究團(tuán)隊(duì)定義了結(jié)構(gòu)化狀態(tài)空間模型和結(jié)構(gòu)化注意力,討論了它們的屬性,并表明它們都有二次算法和線性算法。
自最初的 Mamba 論文研究了合成任務(wù) —— 如:合成復(fù)制和歸納 Head 以來(lái),許多后續(xù)工作開(kāi)始研究更難的關(guān)聯(lián)回憶任務(wù)。由 Zoology 和 Based 系列工作引入的 MQAR(multi-query associative recall)任務(wù)已成為事實(shí)上的標(biāo)準(zhǔn)。
通過(guò)運(yùn)行一個(gè)比文獻(xiàn)中通常報(bào)告的版本要難得多的任務(wù),該團(tuán)隊(duì)發(fā)現(xiàn) Mamba-2 明顯優(yōu)于 Mamba-1,而改善性能的一個(gè)原因是狀態(tài)大?。ū?Mamba-1 大約 16 倍)。
在這篇文章中,作者深入探討了模型背后的理論。
從兩個(gè)完全不同的角度推導(dǎo)出 SSD 的「對(duì)偶性」:
- 一個(gè)從 SSM 的角度出發(fā);
- 另一個(gè)從注意力機(jī)制的角度出發(fā)。
SSD 框架提供了狀態(tài)空間模型、注意力機(jī)制和結(jié)構(gòu)化矩陣之間豐富的聯(lián)系。
雖然 SSD 模型可以被視為框架內(nèi)每個(gè)分支的具體實(shí)例,但 SSD 框架本身更加通用,為未來(lái)的工作開(kāi)辟了許多方向。
SSD 框架(紅色,藍(lán)色):狀態(tài)空間模型(即半可分矩陣)和結(jié)構(gòu)化掩碼注意力機(jī)制包含了大量高效的序列模型。它們的交集是 SSD 模型(紫色)。
SSD 算法
通常,矩陣乘法(matmul)的 FLOPs 速度要比非矩陣乘法 FLOPs 快得多(高達(dá) 16 倍):A100 GPU 具有 312 TFLOPS 的 BF16 矩陣乘法性能,但只有 19 TFLOPS 的 FP32 算術(shù)性能,而 H100 具有 989 TFLOPS 的 BF16 矩陣乘法性能,但只有 67 TFLOPS 的 FP32 算術(shù)性能。
Mamba-2 的主要目標(biāo)之一是「利用張量核心加速 SSM」。
在綁定參數(shù)并引入 Head 結(jié)構(gòu)后,Mamba-1 中的 SSM 變成了 SSD,這是一種更具限制性的形式,具有類(lèi)似注意力的公式。并且由于 SSD 連接 SSM 和結(jié)構(gòu)化矩陣,計(jì)算 SSM 的高效算法直接對(duì)應(yīng)于「token-mixing」或「sequence-mixing」矩陣 M 的不同分解。
因此,可以通過(guò)尋找替代的矩陣乘法方式,例如通過(guò)各種方式對(duì)其進(jìn)行分解,從而創(chuàng)建計(jì)算 SSM 的新算法。
通過(guò)精心選擇塊大小,對(duì)這個(gè)矩陣進(jìn)行簡(jiǎn)單塊分解,就可以集 SSD 線性遞歸和二次注意力對(duì)偶形式的兩種優(yōu)勢(shì)于一身。
而這也就是 SSD 算法的起源,它有 4 個(gè)步驟,并且對(duì)于這個(gè)算法有兩種完全不同的詮釋。
SSD 算法:分塊矩陣分解
首先將半可分 SSM 矩陣劃分為大小為 Q×Q 的塊,然后,利用半分矩陣的性質(zhì)來(lái)分解每個(gè)低秩的非對(duì)角塊:
- (橙色)每個(gè)對(duì)角塊是一個(gè)更小的半可分矩陣,可以以喜歡的方式計(jì)算這個(gè)乘法,特別是使用 SSD 的二次(類(lèi)似注意力機(jī)制)形式。
- (綠色)總共有 T/Q 個(gè)不同的綠色塊,通過(guò)批處理矩陣乘法來(lái)計(jì)算。
- (黃色)注意,黃色項(xiàng)本身是一個(gè) 1 - 半可分矩陣,這一步等價(jià)于對(duì)某些修改后的 A 因子的 SSM 掃描。
- (藍(lán)色)與綠色類(lèi)似,通過(guò)批處理矩陣乘法來(lái)計(jì)算。
SSD 算法:分塊和狀態(tài)傳遞
該算法的另一種詮釋涉及「推理 SSM 如何在實(shí)際序列上進(jìn)行操作」。
首先將輸入序列分割成大小為 Q 的塊,步驟可以分為:
- 分塊內(nèi)部輸出:計(jì)算每個(gè)塊的局部輸出(假設(shè)初始狀態(tài)(對(duì)于塊)為 0,則每個(gè)塊的輸出是多少?)
- 塊狀態(tài):計(jì)算每個(gè)塊的最終狀態(tài)(假設(shè)初始狀態(tài)(對(duì)于塊)為 0,則每個(gè)塊的最終狀態(tài)是多少?)
- 傳遞狀態(tài):計(jì)算所有塊的最終狀態(tài)的遞歸 - 使用任何所需的算法,例如并行或順序掃描(考慮到所有先前輸入,每個(gè)塊的實(shí)際最終狀態(tài)是多少?)
- 輸出狀態(tài):對(duì)于每個(gè)塊,根據(jù)其真實(shí)的初始狀態(tài)(在步驟 3 中計(jì)算),僅從初始狀態(tài)得出的輸出計(jì)算貢獻(xiàn)
可以看到,大部分算法(步驟 1、2 和 4)利用了矩陣乘法(因此利用了張量核心),而且可以并行計(jì)算。
只有步驟 3 需要掃描,但它只操作一個(gè)非常短的序列,通常只需要很少時(shí)間。
系統(tǒng)及擴(kuò)展優(yōu)化
張量并行
使用張量并行對(duì) Mamba-1 進(jìn)行大規(guī)模訓(xùn)練的一項(xiàng)困難是,每層都需要 2 次 all-reduce,而在 Transformer 中,每個(gè)注意力或 MLP 層只需 1 次 all-reduce。這是因?yàn)?SSM 的一些參數(shù)是內(nèi)部激活的函數(shù),而不是層的輸入函數(shù)。在 Mamba-2 中,由于采用了「并行投影」結(jié)構(gòu),所有 SSM 參數(shù)都是層輸入的函數(shù),因此可以輕松地將張量并行應(yīng)用于輸入投影:將輸入投影和輸出投影矩陣分割成 2、4、8 個(gè)碎片,具體取決于張量并行度。研究者使用 grouped norm,分組數(shù)除以張量并行度,這樣每個(gè) GPU 都能單獨(dú)完成歸一化。這些變化導(dǎo)致每層只需 1 次 all-reduce,而不是 2 次。
序列并行
在對(duì)超長(zhǎng)序列進(jìn)行訓(xùn)練時(shí),可能需要沿著序列長(zhǎng)度進(jìn)行分割,并將不同部分分配給不同的設(shè)備。序列并行主要有兩種形式:對(duì)于殘差和歸一化操作,用 reduce-scatter、殘差 + 歸一化、然后 all-gather,取代張量并行中的 all-reduce。由于 Mamba-2 使用與 Transformer 相同的殘差和歸一化結(jié)構(gòu),因此這種形式的序列并行無(wú)需修改即可直接應(yīng)用。對(duì)于注意力或 SSM 操作,又稱上下文并行(CP)。對(duì)于注意力,可以使用環(huán)形注意力沿序列維度進(jìn)行分割。對(duì)于 Mamba-2,SSD 框架再次提供了幫助:使用相同的蒯分解,可以讓每個(gè) GPU 計(jì)算其本地輸出和最終狀態(tài),然后在更新每個(gè) GPU 的最終輸出之前,在 GPU 之間傳遞狀態(tài)(使用發(fā)送 / 接收通信原語(yǔ))。
實(shí)驗(yàn)結(jié)果
該研究在 MQAR 的一種具有挑戰(zhàn)性的版本上,使用更難的任務(wù)、更長(zhǎng)的序列和更小的模型進(jìn)行了對(duì)比實(shí)驗(yàn)?;€包括標(biāo)準(zhǔn)的多頭 softmax 注意力以及 Based 架構(gòu),實(shí)驗(yàn)結(jié)果如圖 8 所示。
下表顯示了 Mamba-2 在一系列下游零樣本評(píng)估任務(wù)上的性能:
感興趣的讀者可以閱讀論文原文,了解更多研究?jī)?nèi)容。
本文轉(zhuǎn)自 機(jī)器之心 ,作者:機(jī)器之心
