一文看懂Mamba,Transformer最強(qiáng)競(jìng)爭(zhēng)者
深度學(xué)習(xí)架構(gòu)有很多,但近些年最成功的莫過(guò)于 Transformer,其已經(jīng)在多個(gè)應(yīng)用領(lǐng)域確立了自己的主導(dǎo)地位。
如此成功的一大關(guān)鍵推動(dòng)力是注意力機(jī)制,這能讓基于 Transformer 的模型關(guān)注與輸入序列相關(guān)的部分,實(shí)現(xiàn)更好的上下文理解。但是,注意力機(jī)制的缺點(diǎn)是計(jì)算開(kāi)銷大,會(huì)隨輸入規(guī)模而二次增長(zhǎng),也因此就難以處理非常長(zhǎng)的文本。
好在前段時(shí)間誕生了一種頗具潛力的新架構(gòu):結(jié)構(gòu)化的狀態(tài)空間序列模型(SSM)。該架構(gòu)能高效地捕獲序列數(shù)據(jù)中的復(fù)雜依賴關(guān)系,并由此成為 Transformer 的一大強(qiáng)勁對(duì)手。
這類模型的設(shè)計(jì)靈感來(lái)自經(jīng)典的狀態(tài)空間模型 —— 我們可以將其看作是循環(huán)神經(jīng)網(wǎng)絡(luò)和卷積神經(jīng)網(wǎng)絡(luò)的融合模型。它們可使用循環(huán)或卷積運(yùn)算進(jìn)行高效地計(jì)算,從而讓計(jì)算開(kāi)銷隨序列長(zhǎng)度而線性或近線性地變化,由此大幅降低計(jì)算成本。
更具體而言,SSM 最成功的變體之一 Mamba 的建模能力已經(jīng)可以比肩 Transformer,同時(shí)還能維持隨序列長(zhǎng)度的線性可擴(kuò)展性。
Mamba 首先引入了一個(gè)簡(jiǎn)單卻有效選擇機(jī)制,其可根據(jù)輸入對(duì) SSM 進(jìn)行重新參數(shù)化,從而可讓模型在濾除不相關(guān)信息的同時(shí)無(wú)限期地保留必要和相關(guān)的數(shù)據(jù)。然后,Mamba 還包含一種硬件感知型算法,可使用掃描(scan)而非卷積來(lái)循環(huán)地計(jì)算模型,這在 A100 GPU 上能讓計(jì)算速度提升 3 倍。
如圖 1 所示,憑借強(qiáng)大的建模復(fù)雜長(zhǎng)序列數(shù)據(jù)的能力和近乎線性的可擴(kuò)展性,Mamba 已經(jīng)崛起成為一種基礎(chǔ)模型,并有望變革計(jì)算機(jī)視覺(jué)、自然語(yǔ)言處理和醫(yī)療等多個(gè)研究和應(yīng)用領(lǐng)域。
因此,研究和應(yīng)用 Mamba 的文獻(xiàn)迅速增長(zhǎng),讓人目不暇接,一篇全面的綜述報(bào)告必定大有裨益。近日,香港理工大學(xué)的一個(gè)研究團(tuán)隊(duì)在 arXiv 上發(fā)布了他們的貢獻(xiàn)。
- 論文標(biāo)題:A Survey of Mamba
- 論文地址:https://arxiv.org/pdf/2408.01129
這份綜述報(bào)告從多個(gè)角度對(duì) Mamba 進(jìn)行了總結(jié),既能幫助初學(xué)者學(xué)習(xí) Mamba 的基礎(chǔ)工作機(jī)制,也能助力經(jīng)驗(yàn)豐富的實(shí)踐者了解最新進(jìn)展。
Mamba 是一個(gè)熱門研究方向,也因此有多個(gè)團(tuán)隊(duì)都在嘗試編寫綜述報(bào)告,除了本文介紹的這一篇,還有另一些關(guān)注狀態(tài)空間模型或視覺(jué) Mamba 的綜述,詳情請(qǐng)參閱相應(yīng)論文:
- Mamba-360: Survey of state space models as transformer alternative for long sequence modelling: Methods, applications, and challenges. arXiv:2404.16112
- State space model for new-generation network alternative to transformers: A survey. arXiv:2404.09516
- Vision Mamba: A Comprehensive Survey and Taxonomy. arXiv:2405.04404
- A survey on vision mamba: Models, applications and challenges. arXiv:2404.18861
- A survey on visual mamba. arXiv:2404.15956
預(yù)備知識(shí)
Mamba 集中了循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)的循環(huán)框架、Transformer 的并行計(jì)算和注意力機(jī)制、狀態(tài)空間模型(SSM)的線性特性。因此,為了透徹地理解 Mamba,就必需先理解這三種架構(gòu)。
循環(huán)神經(jīng)網(wǎng)絡(luò)
循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)具有保留內(nèi)部記憶的能力,因此很擅長(zhǎng)處理序列數(shù)據(jù)。
具體來(lái)說(shuō),在每個(gè)離散時(shí)間步驟 k,標(biāo)準(zhǔn) RNN 在處理一個(gè)向量時(shí)會(huì)連同前一時(shí)間步驟的隱藏狀態(tài)一起處理,之后輸出另一個(gè)向量并更新隱藏狀態(tài)。這個(gè)隱藏狀態(tài)就可作為 RNN 的記憶,其能保留過(guò)去已見(jiàn)過(guò)的輸入的信息。這種動(dòng)態(tài)記憶讓 RNN 可處理不同長(zhǎng)度的序列。
也就是說(shuō),RNN 是一種非線性的循環(huán)模型,可通過(guò)使用存儲(chǔ)在隱藏狀態(tài)中歷史知識(shí)來(lái)有效地捕獲時(shí)間模式。
Transformer
Transformer 的自注意力機(jī)制有助于捕獲輸入之中的全局依賴。其實(shí)現(xiàn)方式是基于每個(gè)位置相對(duì)于其它位置的重要程度為它們分配權(quán)重。更具體而言,首先對(duì)原始輸入進(jìn)行線性變換,將輸入向量的序列 x 轉(zhuǎn)換成三類向量:查詢 Q、鍵 K 和值 V。
然后計(jì)算歸一化的注意力分?jǐn)?shù) S 并計(jì)算注意力權(quán)重。
除了可以執(zhí)行單個(gè)注意力函數(shù),我們還可以執(zhí)行多頭注意力。這讓模型可以捕獲不同類型的關(guān)系,并從多個(gè)視角理解輸入序列。多頭注意力會(huì)使用多組自注意力模塊并行地處理輸入序列。其中每個(gè)頭都獨(dú)立運(yùn)作,執(zhí)行的計(jì)算與標(biāo)準(zhǔn)自注意力機(jī)制一樣。
之后,將每個(gè)頭的注意力權(quán)重匯聚組合,得到值向量的加權(quán)和。這個(gè)聚合步驟可讓模型使用來(lái)自多個(gè)頭的信息并捕獲輸入序列中的多種不同模式和關(guān)系。
狀態(tài)空間
狀態(tài)空間模型(SSM)是一種傳統(tǒng)的數(shù)學(xué)框架,可用于描述系統(tǒng)隨時(shí)間變化的動(dòng)態(tài)行為。近些年來(lái),人們已將 SSM 廣泛應(yīng)用于控制論、機(jī)器人學(xué)和經(jīng)濟(jì)學(xué)等多個(gè)不同領(lǐng)域。
究其核心,SSM 是通過(guò)一組名為「狀態(tài)」的隱藏變量來(lái)體現(xiàn)系統(tǒng)的行為,使其能有效捕獲時(shí)間數(shù)據(jù)的依賴關(guān)系。不同于 RNN,SSM 是一種具有關(guān)聯(lián)(associative)屬性的線性模型。具體來(lái)說(shuō),經(jīng)典的狀態(tài)空間模型會(huì)構(gòu)建兩個(gè)關(guān)鍵方程(狀態(tài)方程和觀察方程),以通過(guò)一個(gè) N 維的隱藏狀態(tài) h (t) 建模當(dāng)前時(shí)間 t 時(shí)輸入 x 與輸出 y 之間的關(guān)系。
- 離散化
為了滿足機(jī)器學(xué)習(xí)的需求,SSM 必需經(jīng)歷一個(gè)離散化過(guò)程 —— 將連續(xù)參數(shù)轉(zhuǎn)變成離散參數(shù)。通常來(lái)說(shuō),離散化方法的目標(biāo)是將連續(xù)時(shí)間劃分為具有盡可能相等積分面積的 K 個(gè)離散區(qū)間。為了實(shí)現(xiàn)這一目標(biāo),SSM 采用的最具代表性的解決方案之一是 Zero-Order Hold(ZOH),其假設(shè)區(qū)間 Δ = [??_{???1}, ??_?? ] 上的函數(shù)值保持不變。離散 SSM 與循環(huán)神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)相似,因此離散 SSM 能比基于 Transformer 的模型更高效地執(zhí)行推理過(guò)程。
- 卷積計(jì)算
離散 SSM 是一個(gè)具有結(jié)合屬性的線性系統(tǒng),因此可以與卷積計(jì)算無(wú)縫整合。
RNN、Transformer 和 SSM 之間的關(guān)系
圖 2 展示了 RNN、Transformer 和 SSM 的計(jì)算算法。
一方面,常規(guī) RNN 的運(yùn)作基于一種非線性的循環(huán)框架,其中每個(gè)計(jì)算都僅依賴于之前的隱藏狀態(tài)和當(dāng)前輸入。
盡管這種形式可讓 RNN 在自回歸推理時(shí)快速生成輸出,但它也讓 RNN 難以充分利用 GPU 的并行計(jì)算能力,導(dǎo)致模型訓(xùn)練速度變慢。
另一方面,Transformer 架構(gòu)是在多個(gè)「查詢 - 鍵」對(duì)上并行執(zhí)行矩陣乘法,而矩陣乘法可以高效地分配給硬件資源,從而更快地訓(xùn)練基于注意力的模型。但是,如果要讓基于 Transformer 的模型生成響應(yīng)或預(yù)測(cè),則推理過(guò)程會(huì)非常耗時(shí)。
不同于僅支持一類計(jì)算的 RNN 和 Transformer,離散 SSM 靈活性很高;得益于其線性性質(zhì),它既能支持循環(huán)計(jì)算,也可支持卷積計(jì)算。這種特性讓 SSM 不僅能實(shí)現(xiàn)高效推理,也能實(shí)現(xiàn)并行訓(xùn)練。但是,需要指出,最常規(guī)的 SSM 是時(shí)不變的,也就是說(shuō)其 A、B、C 和 Δ 與模型輸入 x 無(wú)關(guān)。這會(huì)限制其上下文感知型建模的能力,導(dǎo)致 SSM 在選擇性復(fù)制等一些特定任務(wù)上表現(xiàn)不佳。
Mamba
為了解決上述傳統(tǒng) SSM 的缺點(diǎn),實(shí)現(xiàn)上下文感知型建模,Albert Gu 和 Tri Dao 提出了可用作通用序列基礎(chǔ)模型主干網(wǎng)絡(luò)的 Mamba,參閱機(jī)器之心報(bào)道《五倍吞吐量,性能全面包圍 Transformer:新架構(gòu) Mamba 引爆 AI 圈》。
之后,他們倆又進(jìn)一步提出了 Mamba-2,其中的結(jié)構(gòu)化空間狀態(tài)對(duì)偶(SSD/Structured Space-State Duality)構(gòu)建了一個(gè)將結(jié)構(gòu)化 SSM 與多種形式的注意力連接起來(lái)的穩(wěn)健的理論框架,讓我們可將原本為 Transformer 開(kāi)發(fā)的算法和系統(tǒng)優(yōu)化技術(shù)遷移用于 SSM,也可參閱機(jī)器之心報(bào)道《再戰(zhàn) Transformer!原作者帶隊(duì)的 Mamba 2 來(lái)了,新架構(gòu)訓(xùn)練效率大幅提升》。
Mamba-1:使用硬件感知型算法的選擇式狀態(tài)空間模型
Mamba-1 基于結(jié)構(gòu)化狀態(tài)空間模型引入了三大創(chuàng)新技術(shù),即基于高階多項(xiàng)式投影算子(HiPPO)的內(nèi)存初始化、選擇機(jī)制和硬件感知型計(jì)算。如圖 3 所示。這些技術(shù)的目標(biāo)是提升 SSM 的長(zhǎng)程線性時(shí)間序列建模能力。
具體來(lái)說(shuō),其中的初始化策略可構(gòu)建一個(gè)連貫的隱藏狀態(tài)矩陣,以有效地促進(jìn)長(zhǎng)程記憶。
然后,選擇機(jī)制可讓 SSM 有能力獲取可感知內(nèi)容的表征。
最后,為了提升訓(xùn)練效率,Mamba 還包含兩種硬件感知型計(jì)算算法:Parallel Associative Scan(并行關(guān)聯(lián)掃描)和 Memory Recomputation(內(nèi)存重新計(jì)算)。
Mamba-2:狀態(tài)空間對(duì)偶
Transformer 啟發(fā)了多種不同技術(shù)的發(fā)展,比如參數(shù)高效型微調(diào)、災(zāi)難性遺忘緩解、模型量化。為了讓狀態(tài)空間模型也能受益于這些原本為 Transformer 開(kāi)發(fā)的技術(shù),Mamba-2 引入了一個(gè)新框架:結(jié)構(gòu)化狀態(tài)空間對(duì)偶(SSD)。該框架在理論上將 SSM 和不同形式的注意力連接到了一起。
本質(zhì)上講,SSD 表明,Transformer 使用的注意力機(jī)制和 SSM 中使用的線性時(shí)不變系統(tǒng)都可被視為半可分離的矩陣變換。
此外,Albert Gu 和 Tri Dao 還證明選擇式 SSM 等價(jià)于使用一種半可分離掩碼矩陣實(shí)現(xiàn)的結(jié)構(gòu)化線性注意力機(jī)制。
Mamba-2 基于 SSD 設(shè)計(jì)了一種能更高效使用硬件的計(jì)算方法,這要用到一種塊分解矩陣乘法算法。
具體來(lái)說(shuō),通過(guò)這種矩陣變換將狀態(tài)空間模型視為半可分離矩陣,Mamba-2 能將該計(jì)算分解為矩陣塊,其中對(duì)角塊表示塊內(nèi)計(jì)算。而非對(duì)角塊則表示通過(guò) SSM 的隱藏狀態(tài)分解的塊間計(jì)算。該方法可讓 Mamba-2 的訓(xùn)練速度超過(guò) Mamba-1 的并行關(guān)聯(lián)掃描的 2-8 倍,同時(shí)性能還能媲美 Transformer。
Mamba 塊
下面來(lái)看看 Mamba-1 和 Mamba-2 的塊設(shè)計(jì)。圖 4 比較了這兩種架構(gòu)。
Mamba-1 的設(shè)計(jì)是以 SSM 為中心,其中選擇式 SSM 層的任務(wù)是執(zhí)行從輸入序列 X 到 Y 的映射。在這種設(shè)計(jì)中,經(jīng)過(guò)了初始的創(chuàng)建 X 的線性投射之后,會(huì)使用 (A, B, C) 的線性投射。然后,輸入 token 和狀態(tài)矩陣會(huì)通過(guò)選擇式 SSM 單元,利用并行關(guān)聯(lián)掃描,從而得到輸出 Y。之后,Mamba-1 采用了一個(gè) skip 連接,以鼓勵(lì)特征復(fù)用和緩解常在模型訓(xùn)練過(guò)程中發(fā)生的性能下降問(wèn)題。最后,通過(guò)交錯(cuò)地堆疊該模塊與標(biāo)準(zhǔn)歸一化和殘差連接,便可構(gòu)建出 Mamba 模型。
至于 Mamba-2,則是引入了 SSD 層來(lái)創(chuàng)建從 [X, A, B, C] 到 Y 的映射。其實(shí)現(xiàn)方式是在塊的起點(diǎn)處使用單個(gè)投射來(lái)同時(shí)處理 [X, A, B, C],這類似于標(biāo)準(zhǔn)注意力架構(gòu)以并行方式生成 Q、K、V 投射的方式。
也就是說(shuō),通過(guò)移除序列線性投射,Mamba-2 塊是在 Mamba-1 塊的基礎(chǔ)上進(jìn)行了簡(jiǎn)化。這能讓 SSD 結(jié)構(gòu)的計(jì)算速度超過(guò) Mamba-1 的并行選擇式掃描。此外,為了提升訓(xùn)練穩(wěn)定性,Mamba-2 還在 skip 連接之后添加了一個(gè)歸一化層。
Mamba 模型正在發(fā)展進(jìn)步
狀態(tài)空間模型和 Mamba 近來(lái)發(fā)展迅猛,已經(jīng)成為了一大極具潛力的基礎(chǔ)模型骨干網(wǎng)絡(luò)選擇。盡管 Mamba 在自然語(yǔ)言處理任務(wù)上表現(xiàn)不俗,但也仍具有一些難題,比如記憶丟失、難以泛化到不同任務(wù)、在復(fù)雜模式方面的表現(xiàn)不及基于 Transformer 的語(yǔ)言模型。為了解決這些難題,研究社區(qū)為 Mamba 架構(gòu)提出了諸多改進(jìn)方案?,F(xiàn)有的研究主要集中于修改塊設(shè)計(jì)、掃描模式和記憶管理。表 1 分類總結(jié)了相關(guān)研究。
塊設(shè)計(jì)
Mamba 塊的設(shè)計(jì)和結(jié)構(gòu)對(duì) Mamba 模型的總體性能有很大的影響,也因此這成為了一大研究熱點(diǎn)。
如圖 5 所示,基于構(gòu)建新 Mamba 模塊的不同方法,現(xiàn)有研究可以分為三類:
- 集成方法:將 Mamba 塊與其它模型集成到一起,實(shí)現(xiàn)效果與效率的平衡;
- 替換方法:用 Mamba 塊替換其它模型框架中的主要層;
- 修改方法:修改經(jīng)典 Mamba 塊內(nèi)的組件。
掃描模式
并行關(guān)聯(lián)掃描是 Mamba 模型內(nèi)的一大關(guān)鍵組件,其目標(biāo)是解決由選擇機(jī)制導(dǎo)致的計(jì)算問(wèn)題、提升訓(xùn)練過(guò)程速度以及降低內(nèi)存需求。其實(shí)現(xiàn)方式是利用時(shí)變的 SSM 的線性性質(zhì)來(lái)在硬件層級(jí)上設(shè)計(jì)核融合和重新計(jì)算。但是,Mamba 的單向序列建模范式不利于全面學(xué)習(xí)多樣化的數(shù)據(jù),比如圖像和視頻。
為緩解這一問(wèn)題,一些研究者探索了新的高效掃描方法,以提升 Mamba 模型的性能以及促進(jìn)其訓(xùn)練過(guò)程。如圖 6 所示,在開(kāi)發(fā)掃描模式方面,現(xiàn)有的研究成果可以分為兩類:
- 展平式掃描方法:以展平的視角看待 token 序列,并基于此處理模型輸入;
- 立體式掃描方法:跨維度、通道或尺度掃描模型輸入,這又可進(jìn)一步分為三類:分層掃描、時(shí)空掃描、混合掃描。
記憶管理
類似于 RNN,在狀態(tài)空間模型內(nèi),隱藏狀態(tài)的記憶有效地存儲(chǔ)了之前步驟的信息,因此對(duì) SSM 的整體性能有著至關(guān)重要的影響。盡管 Mamba 引入了基于 HiPPO 的方法來(lái)進(jìn)行記憶初始化,但管理 SSM 單元中的記憶依然難度很大,其中包括在層之前轉(zhuǎn)移隱藏信息以及實(shí)現(xiàn)無(wú)損記憶壓縮。
為此,一些開(kāi)創(chuàng)性研究提出了一些不同的解決方案,包括記憶的初始化、壓縮和連接。
讓 Mamba 適應(yīng)多樣化的數(shù)據(jù)
Mamba 架構(gòu)是選擇式狀態(tài)空間模型的一種擴(kuò)展,其具備循環(huán)模型的基本特性,因而非常適合作為處理文本、時(shí)間序列、語(yǔ)音等序列數(shù)據(jù)的通用基礎(chǔ)模型。
不僅如此,近期一些開(kāi)創(chuàng)性研究更是擴(kuò)展了 Mamba 架構(gòu)的應(yīng)用場(chǎng)景,使其不僅能處理序列數(shù)據(jù),還能用于圖像和圖譜等領(lǐng)域,如圖 7 所示。
這些研究的目標(biāo)是既充分利用 Mamba 能獲取長(zhǎng)程依賴關(guān)系的出色能力,也讓其發(fā)揮學(xué)習(xí)和推理過(guò)程中的效率優(yōu)勢(shì)。表 2 簡(jiǎn)單總結(jié)了這些研究成果。
序列數(shù)據(jù)
序列數(shù)據(jù)是指以特定順序收集和整理的數(shù)據(jù),其中數(shù)據(jù)點(diǎn)的順序具有重要意義。這份綜述報(bào)告全面總結(jié)了 Mamba 在多種序列數(shù)據(jù)上的應(yīng)用,包括自然語(yǔ)言、視頻、時(shí)間序列、語(yǔ)音和人體運(yùn)動(dòng)數(shù)據(jù)。詳見(jiàn)原論文。
非序列數(shù)據(jù)
不同于序列數(shù)據(jù),非序列數(shù)據(jù)并不遵循特定的順序。其數(shù)據(jù)點(diǎn)可以任意順序進(jìn)行組織而不會(huì)對(duì)數(shù)據(jù)的含義造成顯著影響。對(duì)于專門設(shè)計(jì)用于捕獲數(shù)據(jù)中時(shí)間依賴關(guān)系的循環(huán)模型(RNN 和 SSM 等)來(lái)說(shuō),這種缺乏固有順序的數(shù)據(jù)會(huì)很難處理。
令人驚訝的是,近期的一些研究成功讓 Mamba(代表性的 SSM)實(shí)現(xiàn)了對(duì)非序列數(shù)據(jù)的高效處理,包括圖像、圖譜和點(diǎn)云數(shù)據(jù)。
多模態(tài)數(shù)據(jù)
為了提升 AI 的感知和場(chǎng)景理解能力,可以整合多個(gè)模態(tài)的數(shù)據(jù),比如語(yǔ)言(序列數(shù)據(jù))和圖像(非序列數(shù)據(jù))。這樣的整合能提供非常有價(jià)值和補(bǔ)充性的信息。
近段時(shí)間來(lái),多模態(tài)大型語(yǔ)言模型(MLLM)是最受關(guān)注的研究熱點(diǎn);這類模型繼承了大型語(yǔ)言模型(LLM)的強(qiáng)大能力,包括強(qiáng)大的語(yǔ)言表達(dá)和邏輯推理能力。盡管 Transformer 已經(jīng)成為該領(lǐng)域的主導(dǎo)方法,但 Mamba 也正在崛起成為一大強(qiáng)勁競(jìng)爭(zhēng)者,其在對(duì)齊混合源數(shù)據(jù)和實(shí)現(xiàn)序列長(zhǎng)度的線性復(fù)雜度擴(kuò)展方面表現(xiàn)出色,這使 Mamba 有望在多模態(tài)學(xué)習(xí)方面替代 Transformer。
應(yīng)用
下面介紹基于 Mamba 的模型的一些值得注意的應(yīng)用。該團(tuán)隊(duì)將這些應(yīng)用分為了以下類別:自然語(yǔ)言處理、計(jì)算機(jī)視覺(jué)、語(yǔ)音分析、藥物發(fā)現(xiàn)、推薦系統(tǒng)以及機(jī)器人和自主系統(tǒng)。
這里我們不再過(guò)多介紹,詳見(jiàn)原論文。
挑戰(zhàn)與機(jī)遇
Mamba 雖然已經(jīng)在一些領(lǐng)域取得了出色表現(xiàn),但總體而言,Mamba 研究仍還處于起步階段,前方仍還有一些挑戰(zhàn)有待克服。當(dāng)然,這些挑戰(zhàn)同時(shí)也是機(jī)遇。
- 如何開(kāi)發(fā)和改進(jìn)基于 Mamba 的基礎(chǔ)模型;
- 如何充分實(shí)現(xiàn)硬件感知型計(jì)算,以盡可能利用 GPU 和 TPU 等硬件,提升模型效率;
- 如何提升 Mamba 模型的可信度,這需要安全和穩(wěn)健性、公平性、可解釋性以及隱私方面的進(jìn)一步研究;
- 如何將 Transformer 領(lǐng)域的新技術(shù)用于 Mamba,如參數(shù)高效型微調(diào)、災(zāi)難性遺忘緩解、檢索增強(qiáng)式生成(RAG)。