原來Scaling Law還能被優(yōu)化?Meta這招省token又提效
2017 年,一篇《Attention Is All You Need》論文成為 AI 發(fā)展的一個重要分水嶺,其中提出的 Transformer 依然是現(xiàn)今主流語言模型的基礎范式。尤其是在基于 Transformer 的語言模型的 Scaling Law 得到實驗驗證后,AI 領域的發(fā)展更是進入了快車道。

現(xiàn)如今,這篇論文的引用量正向 19 萬沖刺,而 Transformer 和注意力機制本身也已經(jīng)歷了很多改進和創(chuàng)新,比如我們前段時間報道過的「Multi-Token Attention」和「Multi-matrix Factorization Attention」等。
隨著 AI 的不斷發(fā)展,現(xiàn)如今的一個重要挑戰(zhàn)是如何獲得足夠多高質(zhì)量的 token。又或者,該如何更高效地利用這些 token?為此,還必須對 Transformer 進行進一步的升級改造。
近日,Meta 的一篇論文公布了他們在這方面取得的一個新進展,提出了一種旋轉(zhuǎn)不變型三線性注意力機制,并證明其表示能力與 2-simplicial Transformer 相當。更重要的是,它的表現(xiàn)甚至足以改變 Scaling Law 中的系數(shù)。Meta 也用 Triton 實現(xiàn)了這種注意力機制。

該研究基于 RoPE 向三線性函數(shù)的泛化;而 2-simplicial Transformer 則源自 2019 年 Clift et al. 的研究《Logic and the 2-Simplicial Transformer》,其中將 Transformer 的點積注意力機制泛化到了三線性形式。

- 論文標題:Fast and Simplex: 2-Simplicial Attention in Triton
- 論文地址:https://arxiv.org/pdf/2507.02754.pdf
他們進一步證明,在有限的 token 預算下,2-simplicial Transformer 的擴展性優(yōu)于 Transformer。
此外,他們的實驗還表明,2-simplicial Transformer 相對于 Transformer 具有更有利的參數(shù)數(shù)量 scaling 指數(shù)。這表明,與 Chinchilla scaling 不同,有可能以比 2-simplicial Transformer 的參數(shù)增長更慢的速度增加 token 數(shù)量。
研究結(jié)果表明,在 token 約束下運行時,與點積注意力機制 Transformer 相比,2-simplicial Transformer 可以更有效地逼近自然語言的不可約熵。
神經(jīng) Scaling Law 概述
要理解這項研究的意義,首先需要了解一下 Scaling Law。
簡單來說,就是損失 L 會隨模型參數(shù)總數(shù) N 和 token 數(shù)量 D 呈冪律衰減:

其中,第一項 E 通常被描述為不可約損失,對應于自然文本的熵。第二項描述了這樣一個事實:具有 N 個參數(shù)的模型的表現(xiàn)達不到理想的生成過程。第三項則對應于這樣一個事實:我們僅使用有限的數(shù)據(jù)樣本進行訓練,并且沒有將模型訓練到收斂。
理論上,當 N → ∞ 且 D → ∞ 時,大型語言模型應該接近底層文本分布的不可約損失 E。
對于給定的計算預算 C,其中 F LOP s (N, D) = C,可以將最佳參數(shù)數(shù)量表示為 Nopt ∝ C a,將最佳數(shù)據(jù)集大小表示為 Dopt ∝ C b。Hoffmann 等人 (2022) 的作者進行了多項實驗,并將參數(shù)函數(shù)擬合到損失函數(shù)中,以估計指數(shù) a 和 b:多種不同的方法證實,a 大約為 0.49,b 大約為 0.5。這引出了 Hoffmann 等人 (2022) 的核心論點:必須根據(jù)模型大小按比例縮放 token 數(shù)量。
對于給定的計算預算 C,其中 FLOPs (N, D) = C,可以將最佳參數(shù)數(shù)量表示為 N_opt ∝ C^a,將最佳數(shù)據(jù)集大小表示為 D_opt ∝ C^b。Hoffmann et al. (2022) 進行了多次實驗,并根據(jù)損失擬合了參數(shù)函數(shù),以估計指數(shù) a 和 b。
結(jié)果,通過多種不同方法發(fā)現(xiàn):a 約為 0.49,b 約為 0.5。
如此,便引出了 Hoffmann et al. (2022) 的一個核心論點:必須根據(jù)模型大小按比例擴展 token 數(shù)量。
但是,正如前面討論的那樣,足夠高質(zhì)量且足夠數(shù)量的 token 是預訓練擴展的新瓶頸,因此需要探索替代的訓練算法和架構(gòu)。另一方面,最近的研究表明,之前文獻中提出的大多數(shù)建模和優(yōu)化技術(shù)僅僅改變了誤差(偏移了 E),并沒有從根本上改變冪律中的指數(shù)。谷歌 DeepMind 的研究者 Katie Everett 對此進行過精彩的討論:
https://x.com/_katieeverett/status/1925665335727808651

2-simplicial Transformer
2-simplicial Transformer 由 Clift et al. (2019) 提出,他們將點積注意力機制從雙線性擴展為三線性形式,也就是從 1-simplex 擴展成了 2-simplex。
先來看看標準的注意力機制:

其中,每一項都是點積
。
然后,通過逐行 softmax 運算將注意力分數(shù)(logit)轉(zhuǎn)換為概率權(quán)重:

注意力層的最終輸出是根據(jù)這些注意力分數(shù)對這些值進行線性組合得到的
。
Clift et al. (2019) 的 2-simplicial Transformer 論文將其推廣到三線性積,其中有兩個額外的鍵和值投射矩陣 W_K′ 和 W_V′,從而得到 K′ = XW_K′ 和 V′ = XW_V′。然后,2-simplicial Transformer 的注意力 logit 由 Q、K 和 K′ 的三線性積給出,從而得到以下三階張量:

從而注意力張量變?yōu)椋?/span>

注意力運算的最終輸出定義為:

其中
表示兩個向量的元素級 Hadamard 積。2-simplicial Transformer 的偽代碼如算法 1 所示。注意,公式 5 不包含 RoPE 等任何位置編碼。

基于行列式的三線性形式
Su et al., 2024 提出 RoPE 時,是想將其作為一種用于 Transformer 語言模型的序列位置信息捕獲方法。RoPE 對查詢 q_i 和鍵 k_j 應用位置相關的旋轉(zhuǎn),使得點積 <q_i, K_j> 是相對距離 i-j 的函數(shù)。特別需要注意的是,點積對于正交變換 R 具有不變性:

這對于 RoPE 至關重要,因為對于同一位置 i 相同的查詢 q_i 和鍵 k_i,我們期望其點積不會因基于位置的旋轉(zhuǎn)而發(fā)生變化。請注意,(5) 式中定義的三線性形式并非是旋轉(zhuǎn)不變,并且對 q_i 、k_i 和 k′_i 進行相同的旋轉(zhuǎn)不再保留內(nèi)積。因此,為了將 RoPE 泛化到 2-simplicial 注意力模型,探索其他具有旋轉(zhuǎn)不變性的雙線性和三線性形式至關重要。
而 Meta 的這個團隊注意到,以下函數(shù)也具有旋轉(zhuǎn)不變性:

可以使用帶符號的行列式運算
來計算 A^(det) ∈ ?^n×n×n。對于任意向量 q,令 q^(l) = q = q [3 (l - 1) : 3l] 為其第 l 個大小為 3 的塊。其 logit 定義為:

由于公式 8 根據(jù) Sarrus 規(guī)則包含 2 個點積項,因此需要修改算法 1,使用 2 個 einsum 而不是第 2 行中的 1 個。最終的注意力權(quán)重 S 是通過對上述 logit 應用 softmax 函數(shù)來計算的,類似于公式 6。然后,token i 的輸出是值向量的加權(quán)和,如公式 7 所示。
定理:對于任意輸入大小 n 和輸入范圍 m = n^{O (1)},存在一個具有單個注意力頭的 Transformer 架構(gòu),其 logit 計算方式如公式 (9) 所示,注意力頭維度為 d = 7,使得對于所有 X ∈ [M]^N,如果
,則 Transformer 對元素 x_i 的輸出為 1,否則為 0。
對該定理的證明請見原論文附錄。
模型設計
由于 2-simplicial 注意力在序列長度 n 上的擴展復雜度為 O (n^3),因此將其應用于整個序列是不切實際的。該團隊的做法是將其參數(shù)化為 O (n× w_1 × w_2),其中 w_1 和 w_2 定義的是序列上滑動窗口的維度。每個查詢向量 Q_i 會關注 w_1 個 K 鍵和 w_2 個 K′ 鍵的局部區(qū)域,從而減輕計算負擔。該團隊系統(tǒng)地評估了 w_1 和 w_2 的各種配置,以確定計算效率和模型性能之間的最佳平衡點(見表 1)。

對于因果點積注意力機制,長度為 n 的序列的復雜度由下式給出:

其中 n 是序列長度。這涉及兩次矩陣乘法:一次用于 Q@K,一次用于 P@V,每次乘法每個元素都需要兩次浮點運算。因果掩碼使其能夠跳過 1/2 的計算。
相比之下,以 w_1 和 w_2 為參數(shù)的 2-simplicial 注意力機制的復雜度表示為:

其復雜度的增長來源是三線性 einsum 運算,與標準點積注意力機制相比,它需要進行一次額外的乘法運算。
該團隊選擇窗口大小為 (512, 32),以平衡延遲和質(zhì)量。在此配置下,2-simplicial 注意力機制的計算復雜度與 48k 上下文長度的點積注意力機制相當。
圖 2 給出了一個實現(xiàn)。因此,像在 Flash 注意力機制中那樣平鋪式查詢 Q 會導致計算吞吐量較低。受 Native Sparse Attention 的啟發(fā),Meta 該團隊采用的模型架構(gòu)利用了較高 (64) 的分組查詢注意力 (GQA) 比率。這種方法能夠沿著查詢頭高效地平鋪,確保密集計算,并消除昂貴的逐元素掩碼。

該團隊還引入了一系列針對 2-simplicial 注意力的核優(yōu)化,這些優(yōu)化基于使用在線 softmax 的 Flash Attention。詳見原論文。下面來重點看看實驗表現(xiàn)。

實驗與結(jié)果
這個團隊訓練了一系列 MoE 模型,其參數(shù)范圍從 1B 活動參數(shù)和 57B 總參數(shù)到 3.5B 活動參數(shù)和 176B 總參數(shù)。具體配置見原論文。

該團隊發(fā)現(xiàn),從 1B (活動)參數(shù)模型到 3.5B (活動)參數(shù)模型,負對數(shù)似然的擴展(?)出現(xiàn)了下降。
此外,在小于 2B (活動)參數(shù)的模型中,使用 2-simplicial 注意力機制沒有任何好處。
基于此,該團隊估算了 2-simplicial 注意力機制與點積注意力機制的冪律系數(shù)有何不同?;谇笆龇椒?,其損失可以表示為:

由于訓練這兩個模型使用的 token 數(shù)量相同,因此可以忽略第三項,將損失簡化為:

其中 β = - log E′′ - logA ,由于 E′ 較小,E′′ 是 E′ 的近似值。注意,這里使用了 log (a + b) = log (1 + a/b) + log (b) 來分離這兩個項,并將 1 + a/b 項隱藏在 E′′ 中。
因此,可以根據(jù)表 2 中的損失估算兩組模型的 α 和 β,其中 N 代表每個模型中的有效參數(shù)。
該團隊在表 3 中估計了 Transformer 和 2-simplicial Transformer 的斜率 α 和截距 β。

可以看到,與點積注意力 Transformer 相比,2-simplicial 注意力具有更陡的斜率 α,即其 Scaling Law 的指數(shù)更高。



































