清華稀疏Attention,無需訓(xùn)練加速一切模型!
在當(dāng)今各類大語言模型以及視頻模型中,長序列場景越來越普遍,而 Attention 的計(jì)算復(fù)雜度隨著序列長度呈平方增長,成為長序列任務(wù)下的主要計(jì)算瓶頸。此前,清華大學(xué)陳鍵飛團(tuán)隊(duì)提出的即插即用量化的 SageAttention 系列工作已實(shí)現(xiàn) 3 倍加速于 FlashAttention,且在各類大模型上均保持了端到端的精度,已被業(yè)界和社區(qū)廣泛使用。為了進(jìn)一步加速 Attention,清華大學(xué)陳鍵飛團(tuán)隊(duì)進(jìn)一步提出了無需訓(xùn)練可直接使用的稀疏 Attention(SpargeAttn)可用來加速任意模型。實(shí)現(xiàn)了 4-7 倍相比于 FlashAttention 的推理加速,且在語言,視頻、圖像生成等大模型上均保持了端到端的精度表現(xiàn)。

論文標(biāo)題:SpargeAttn: Accurate Sparse Attention Accelerating Any Model Inference
下圖展示了 SpargeAttn 的速度,可以發(fā)現(xiàn)在 RTX4090 上,SpargeAttn 在 60% 稀疏度的情況下可以達(dá)到 900TOPS 的速度,甚至是使用 A100 顯卡速度的 4.5 倍(A100 上 FlashAttention 只有 200TOPS)。

在 SpargeAttn 的 Github 倉庫中可以發(fā)現(xiàn),SpargeAttn 的使用方法比較簡潔,只需要進(jìn)行一次簡單的超參數(shù)搜索過程,就可以永久地對任意的模型輸入進(jìn)行推理加速。
接下來,將從前言,挑戰(zhàn),方法,以及實(shí)驗(yàn)效果四個(gè)方面介紹 SpargeAttn。
前言
隨著大模型需要處理的序列長度越來越長,Attention 的速度優(yōu)化變得越來越重要。這是因?yàn)橄啾扔诰W(wǎng)絡(luò)中其它操作的 O (N) 的時(shí)間復(fù)雜度,Attention 的時(shí)間復(fù)雜度是 O (N^2)。盡管 Attention 的計(jì)算復(fù)雜度為 O (N^2),但幸運(yùn)的是 Attention 具備很好的稀疏性質(zhì),即 P 矩陣的很多值都接近 0。如何利用這種稀疏性來節(jié)省計(jì)算就成為了 attention 加速的一個(gè)重要方向。大多數(shù)現(xiàn)有的工作都集中在利用 P 矩陣在語言模型中表現(xiàn)出來的固定的稀疏形狀(如滑動(dòng)窗口)來節(jié)省計(jì)算,或是需要重新訓(xùn)練模型,比如 DeepSeek 的 NSA 以及 Kimi 的 MoBA。此外,現(xiàn)有稀疏 Attention 通常需要較大的上下文窗口(如 64K~1M)才能有明顯加速。SpargeAttn 的目標(biāo)是開發(fā)一個(gè)無需訓(xùn)練、對各種模型(語言 / 視頻 / 圖像)通用、精度無損、對中等長度的上下文(如 4-32K)也有加速效果的注意力機(jī)制。

圖 1: 不同的模型表現(xiàn)出不同的稀疏形狀
實(shí)現(xiàn)通用的,無需訓(xùn)練的稀疏 Attenion 有哪些挑戰(zhàn)?
挑戰(zhàn) 1
通用性:Attention 雖然具備稀疏性質(zhì),但是其稀疏形狀在不同的模型甚至同一模型的不同層中都是不同的,體現(xiàn)出很強(qiáng)的動(dòng)態(tài)性。如圖 1 所示,前兩種模型分別為視頻模型和圖像生成模型,這兩個(gè)模型中的 Attention 的稀疏形狀相比語言模型更加沒有規(guī)律。設(shè)計(jì)一種各種模型通用的稀疏 Attention 是困難的。
挑戰(zhàn) 2
可用性:對于各種 Attention 的輸入,很難同時(shí)實(shí)現(xiàn)準(zhǔn)確且高效的稀疏 Attention。這是因?yàn)闇?zhǔn)確性要求了完全精確地預(yù)測 P 中的稀疏區(qū)域,高效性則要求了此預(yù)測的時(shí)間開銷極短。在一個(gè)極短的時(shí)間內(nèi)完全精準(zhǔn)地預(yù)測 P 的稀疏形狀是困難的。
方法
為了解決上述的兩個(gè)挑戰(zhàn),研究團(tuán)隊(duì)提出了對應(yīng)的解決辦法。
- 研究團(tuán)隊(duì)提出了一種各模型通用的快速的對 P 矩陣稀疏部分進(jìn)行預(yù)測的算法。該方法選擇性地對 Q, K 矩陣進(jìn)行壓縮并預(yù)測 P 矩陣,接著使用 TopCdf 操作省略 P 中稀疏部分對應(yīng)的 QK^T 與 PV 的矩陣乘法。
 - 研究團(tuán)隊(duì)提出了在 GPU Warp 級別上的稀疏 Online Softmax 算法,該算法通過利用 Online Softmax 中全局最大值與局部最大值之間的差異,進(jìn)一步省略了一些 PV 的矩陣乘法計(jì)算。
 - 可選的,針對視頻和圖像模型,研究團(tuán)隊(duì)充分利用圖像以及視頻中的 Token 局部相似性質(zhì),使用希爾伯特重排的方法對 Attention 前的 Token 進(jìn)行重新排列,進(jìn)一步提高稀疏度。
 - 最后,研究團(tuán)隊(duì)將這種稀疏方法與基于量化的 SageAttention 融合到一起,進(jìn)一步加速 Attention。
 

圖 2: SpargeAttn 的算法流程圖
SpargeAttn 的算法流程如下所示:

實(shí)驗(yàn)效果
總的來說,SpargeAttn 在視頻、圖像、文本生成等大模型均可以實(shí)現(xiàn)無需訓(xùn)練的加速效果,同時(shí)保證了各任務(wù)上的端到端的精度。
下表展示了 SpargeAttn 在各模型上的稀疏度,Attention 速度,以及各任務(wù)上的端到端精度,可以發(fā)現(xiàn) SpargeAttn 在保證了加速的同時(shí)沒有影響模型精度:(注:此論文中的所有實(shí)驗(yàn)都是基于 SageAttention 實(shí)現(xiàn),目前 Github 倉庫中已有基于 SageAttention2 的實(shí)現(xiàn),進(jìn)一步提供了 30% 的加速。

值得一提的是,此前的稀疏 Attention 工作很多無法實(shí)際使用的原因之一是稀疏預(yù)測部分的 Overhead 較大,而 SpargeAttn 團(tuán)隊(duì)還將稀疏預(yù)測部分的代碼進(jìn)行了極致優(yōu)化,將 Overhead 壓縮到了幾乎在各種長度的序列下都可以忽略的地步:

下表展示了對于各模型的端到端的加速效果,以視頻生成模型 Mochi 為例,SpargeAttn 提供了近兩倍的端到端加速效果:(注:此論文中的所有實(shí)驗(yàn)都是基于 SageAttention 實(shí)現(xiàn),目前 Github 倉庫中已有基于 SageAttention2 的實(shí)現(xiàn),進(jìn)一步提供了 30% 的加速)
















 
 
 













 
 
 
 