Meta 新作:FlashAttention 的數(shù)值偏差有多大?
一、背景
最近 Meta 的研究員開發(fā)了一個(gè)新的框架來了解 LLM 訓(xùn)練中數(shù)值偏差的影響,并基于該框架評估了 LLM 中廣泛采用的 FlashAttention 的數(shù)值偏差。
對應(yīng)的論文為:[2405.02803] Is Flash Attention Stable?
PS:其實(shí)論文很簡單,結(jié)論也很簡單:使用 FlashAttention 相比 Baseline Attention 確實(shí)會帶來數(shù)值偏差。但帶來的數(shù)值偏差比從 FP32 到 FP16 的數(shù)值偏差小得多,甚至小于不同初始化方法帶來的偏差。吐槽一下,論文中的圖都比較模糊。
二、摘要
LLM 預(yù)訓(xùn)練的代價(jià)很高,也更加的復(fù)雜。很多 LLM 在預(yù)訓(xùn)練中都遇到了訓(xùn)練過程不穩(wěn)定的情況,通常表示為損失的毛刺(Spike)。數(shù)值偏差(Numeric Deviation)被認(rèn)為是導(dǎo)致這種訓(xùn)練不穩(wěn)定的潛在原因,但由于訓(xùn)練的成本很高,量化這一點(diǎn)非常有挑戰(zhàn)性。
本文中,作者開發(fā)了一種系統(tǒng)性的方法來理解數(shù)值偏差的影響,并使用廣泛采用的 FlashAttention 來驗(yàn)證了該框架。作者發(fā)現(xiàn),與 Baseline Attention 相比,在單個(gè)前向傳播中,BF16 下的 FlashAttention 會有超過一個(gè)數(shù)量級的數(shù)值偏差。然而,使用基于 Wasserstein 距離的數(shù)據(jù)驅(qū)動分析來提供數(shù)值偏差對訓(xùn)練過程中模型權(quán)重影響的上限,發(fā)現(xiàn) FlashAttention 中的數(shù)值偏差比低精度訓(xùn)練的影響小 2-5 倍。
三、引言
3.1 數(shù)值精度
如下圖為常見的浮點(diǎn)數(shù)值精度,其中 sign 表示符號位,exponent 表示指數(shù)位,fraction 表示尾數(shù)位。相比 float32,float16 的指數(shù)位和尾數(shù)位都更小,而 bfloat16 的指數(shù)位和 float32 相同,只是尾數(shù)位更少。因此,通常 float32 轉(zhuǎn) float16 時(shí)通常會帶來較大的精度損失,而 float32 轉(zhuǎn) bfloat16 通常只需要做小數(shù)位的截?cái)?,損失相對較小。現(xiàn)在的 LLM 預(yù)訓(xùn)練中通常都會使用 bfloat16。
- Float32:指數(shù)位 8 位,尾數(shù)位 23 位,數(shù)據(jù)范圍為[1.18e-38, 3.40e+38]
- float16:指數(shù)位 5 位,尾數(shù)位 10 位,數(shù)據(jù)范圍為[6.10e-05, 6.55e+04]
- bfloat16:指數(shù)位 8 位,尾數(shù)位 7 位,數(shù)據(jù)范圍為[1.18e-38, 3.39e+38]
3.2 數(shù)值誤差
在浮點(diǎn)數(shù)的計(jì)算中會存在兩種常見的誤差:
- 溢出誤差(Overflow Error):浮點(diǎn)都有一個(gè)有限的表示范圍,當(dāng)計(jì)算結(jié)果超出這個(gè)表示范圍時(shí)就會產(chǎn)生溢出錯誤,往往表現(xiàn)為無窮大。比如,令 float a = FLT_MAX * 2,此時(shí) a 的值為正無窮大。
- 舍入誤差(Rounding Error):浮點(diǎn)數(shù)有固定的有效位數(shù),當(dāng)一個(gè)數(shù)值不能被精確表示時(shí),就會被舍入到最接近的可表示的浮點(diǎn)數(shù)。這種輸入在數(shù)值計(jì)算中是不可避免的,因?yàn)榇蠖鄶?shù)實(shí)數(shù)在計(jì)算機(jī)中無法被精確表示。比如在 C 中打印 0.1f,printf("a = %.20f\n", 0.1f),其輸出結(jié)果為 0.10000000149011611938,是一個(gè)近似值。
除此之外,有時(shí)也會提到下溢誤差(Underflow Error):當(dāng)一個(gè)非常小的非零結(jié)果小于浮點(diǎn)數(shù)表示范圍下限時(shí)發(fā)生,通常導(dǎo)致結(jié)果被舍入為零。
由于 float16 和 bfloat16 的不同指數(shù)位和尾數(shù)位,也就導(dǎo)致它們出現(xiàn)誤差的場景不太一樣。
- float16:指數(shù)位較少,尾數(shù)位較多,表示范圍有限,但表示精度更高,因此更容易發(fā)生溢出誤差。
- bfloat16:指數(shù)位較多,尾數(shù)位較少,表示范圍更大,但表示精度有限,因此更容易發(fā)生舍入誤差。下溢誤差也更多一些。
3.3 訓(xùn)練損失毛刺
在 Meta OPT、BigScience Bloom、Google PaLM、TII Falcon 以及智源 GLM 訓(xùn)練中都出現(xiàn)了訓(xùn)練損失出現(xiàn)毛刺的情況,也有一些有效的手段可以緩解,但依舊不知道其根因。比如 Google PaLM 中驗(yàn)證了其并非是單個(gè)樣本導(dǎo)致的。
如下圖所示,是 [2211.05100] BLOOM: A 176B-Parameter Open-Access Multilingual Language Model 中遇到的毛刺現(xiàn)象:
3.4 評估指標(biāo)
Wasserstein 距離,也稱為 Earth Mover’s Distance (EMD),是一種衡量兩個(gè)概率分布之間差異的方法。這種距離的直觀含義是,將一個(gè)概率分布轉(zhuǎn)變成另一個(gè)概率分布所需要的“工作量”或“成本”,其中“工作量”可以理解為將一堆形狀不同的沙子(一個(gè)概率分布)鏟動并重塑為另一堆沙子(另一個(gè)概率分布)所需要的努力。
Wasserstein 距離基于最優(yōu)運(yùn)輸理論。給定兩個(gè)概率分布 P 和 ??,以及一個(gè)成本函數(shù) ??(??,??),Wasserstein 距離定義為將分布 P 轉(zhuǎn)變?yōu)?Q 所需的最小成本。數(shù)學(xué)上,它表示為:
這里的 π 是 P 和 ?? 之間的所有可能的聯(lián)合分布的集合,而 Π(P,Q) 表示所有這些聯(lián)合分布中,邊際分布分別是 P 和 Q 的集合。
相比其他距離度量(如歐氏距離或 KL 散度),Wasserstein 距離的一個(gè)主要優(yōu)勢在于其能夠更加有效地處理概率分布之間的微小變化,特別是當(dāng)這些分布不重疊或僅部分重疊時(shí)。這使得 Wasserstein 距離在數(shù)據(jù)稀疏或異構(gòu)的情況下特別有用。
四、方法&實(shí)驗(yàn)
4.1 方法
作者開發(fā)了一個(gè) microbenchmark 來隔離和研究 FlashAttention 引起的數(shù)值偏差。其設(shè)計(jì)如下圖 Fig 2 所示,在原始的 FlashAttention 中只支持 FP16 和 BF16 格式,因此作者重新實(shí)現(xiàn)了 FlashAttention,以便分析不同的數(shù)值精度的影響。作者進(jìn)一步修改模型,可以在每次調(diào)用 Attention 時(shí)計(jì)算 Baseline Attention 和 FlashAttention 的注意力矩陣輸出,從而可以使用最大差異(max difference)以及 Wasserstein 距離來度量差異。作者也進(jìn)行了一系列訓(xùn)練來度量整個(gè)訓(xùn)練過程中模型權(quán)重的差異。
4.2 數(shù)據(jù)類型的影響
如下圖 Fig.3 所示,作者對比了不同數(shù)據(jù)類型下 Baseline Attention 和 FlashAttention 的數(shù)值偏差,可以看出,數(shù)值精度越高,偏差越?。?/p>
為了進(jìn)一步分析這種數(shù)值偏差,作者探索了序列長度對數(shù)值偏差的影響,其中會保持 FlashAttention 的 tile 大小和 SRAM 大小相同。如下圖所示,隨著序列長度的增加,數(shù)值偏差也會適當(dāng)增加。其中左圖(a)表示最大誤差,右圖(b)表示誤差的均值。由于序列變長,也就需要更多的 tile,相應(yīng)也有更多的 resaling,這也就可能產(chǎn)生更多的誤差:
4.3 算法配置的影響
如下圖 Fig 6 所示,作者進(jìn)一步探索了 FlashAttention 中不同配置的影響:
- (a)和(c)針對不同的 Block/tile Area 大小的影響,使用比較大的 Block 后 Baseline Attention 和 FlashAttention 的差異很小,主要是因?yàn)?rescaling 計(jì)算更少一些。
- (b)使用 Square Block 對 Baseline Attention 和 FlashAttention 的影響不大。?
4.4 模型權(quán)重的變化
作者進(jìn)一步驗(yàn)證了訓(xùn)練中模型權(quán)重的變化(對比 Baseline Attention 和 FlashAttention),如下圖 Fig 7 所述,不管是最大誤差還是 Wasserstein 距離都會隨著訓(xùn)練的迭代而逐漸變大,并且趨勢類似:
如下圖 Fig.8 所示,作者進(jìn)一步驗(yàn)證了整個(gè)訓(xùn)練中其他變量帶來的模型權(quán)重的偏差??梢钥闯?,雖然 Baseline Attention 和 FlashAttention 會導(dǎo)致權(quán)重產(chǎn)生誤差,但是其甚至比不同初始化方法帶來的誤差還小,更是遠(yuǎn)小于 FP16 vs BF16 和 FP16 vs FP32 帶來的誤差:
五、參考鏈接
本文轉(zhuǎn)載自 ??AI閑談??,作者: AI閑談
