NVIDIA:Blackwell GPU MXFP8 預(yù)訓(xùn)練最佳實(shí)踐
一、背景
筆者之前寫(xiě)過(guò) FP8 訓(xùn)練的綜述文章以及 FP4 訓(xùn)練和推理的綜述文章,本文對(duì)其進(jìn)一步補(bǔ)充,介紹 NVIDIA 最新的使用 MXFP8 預(yù)訓(xùn)練的方案。
對(duì)應(yīng)的論文:[2506.08027] Recipes for Pre-training LLMs with MXFP8 [1]
二、摘要
精度縮放——即在預(yù)訓(xùn)練過(guò)程中使用更少的比特來(lái)表示模型參數(shù)及相關(guān) Tensor——已成為一種在不犧牲精度前提下提升 GPU 效率的有效技術(shù)。NVIDIA 最新 Blackwell GPU 中引入 Microscaling (MX) 格式,為 Tensor 量化提供了細(xì)粒度解決方案。
盡管 MX 格式相較于其他低精度表示法有望提升數(shù)值穩(wěn)定性,但在實(shí)際應(yīng)用中仍需謹(jǐn)慎使用。本文研究表明:當(dāng)采用 OCP 規(guī)范建議的舍入模式進(jìn)行 LLM 預(yù)訓(xùn)練時(shí),會(huì)導(dǎo)致模型不收斂。為此,作者提出了一種改進(jìn)的舍入模式——通過(guò)采用"向無(wú)窮大舍入"方式計(jì)算縮放因子,成功實(shí)現(xiàn)了 8B 參數(shù)模型在 15T Token 上采用 MXFP8 格式的預(yù)訓(xùn)練。
PS:可能是因?yàn)椴捎?Hopper GPU 模擬的方式,而不是真實(shí)的 Blackwell 訓(xùn)練,因此論文并沒(méi)有提供相應(yīng)的效率提升數(shù)據(jù)。
三、引言
3.1 MXFormat
2023 年,OCP(Open Compute Project) 在 AMD, Arm, Intel, Meta, Microsoft, NVIDIA, Qualcomm 的參與下提出 Microscaling(MX)Format 規(guī)范(OCP Microscaling Formats (MX) Specification Version 1.0 [2]),主要是為了對(duì)跨硬件/軟件平臺(tái)可實(shí)施的新功能及格式進(jìn)行標(biāo)準(zhǔn)化,有效減少軟件與基礎(chǔ)設(shè)施成本,并消除定制化解決方案帶來(lái)的各類(lèi)附加費(fèi)用或管理負(fù)擔(dān),推動(dòng)硬件性能與效率的提升。
如下圖所示,MX 最主要的特點(diǎn)是其包含三部分內(nèi)容(很類(lèi)似于常見(jiàn)的 Per-Block 細(xì)粒度量化方式,只不過(guò)這里是制定了一個(gè)統(tǒng)一的規(guī)范):
- P:規(guī)定了 d 個(gè) bit 數(shù)據(jù)的表示(編碼)方式,比如 FP8 的 E5M2 是怎么表示的。
- k:k 個(gè)元組作為一個(gè) Block。
- X:上述 k 個(gè)元素的 Block 會(huì)對(duì)應(yīng)一個(gè)共享的 Scale 值。
我們?cè)谥暗奈恼轮刑岬竭^(guò),即使都是 E5M2 或者 E4M3,不同公司的硬件可能采用不同的格式。比如 NVIDIA Hopper GPU 上的 E5M2 符合 IEEE 754 Style,而 E4M3 卻不符合 IEEE 754 Style。如下圖所示,IEEE 754 Style 的 E4M3 的范圍為 [-240, 240],而 ARM-Intel-Nvidia Style 的 E4M3 的范圍是 [-448, 448]:
在 MX 中也對(duì)上述問(wèn)題進(jìn)行了規(guī)范化,以 MXFP8 為例,其規(guī)定的 E4M3 和 E5M2 編碼方式如下圖 Table 1 和 Table 2 所示:
3.2 細(xì)粒度量化
更低的精度通常意味著更難量化,為了維持精度需要更細(xì)力度的 Scaling Factor,比如:
- FP16:早期使用 FP16 進(jìn)行混合精度訓(xùn)練時(shí)通常整個(gè)模型一個(gè) Scaling Factor 即可。
- FP8:在 Inference 時(shí)通常 Per-Tensor 的 Scaling Factor 即可比較好的維持精度;而 Training 時(shí)往往需要 Per-Block 或 Per-Channel,不過(guò) Block 通常比較大,比如 128x128 或 128x1。
- FP4:需要 Per-Block 量化,并且 Block 需要比較小,比如 32 或 16。
更細(xì)粒度的量化也意味著更高的額外成本,Block 越?。6仍郊?xì)),額外成本越高。如下圖所示,對(duì)于一個(gè)常見(jiàn)的內(nèi)積操作:
- Per-Tensor 量化:需要額外執(zhí)行 1 次 Scaling Factor 處理。
- Per-Block 量化:需要額外執(zhí)行很多次 Scaling Factor 處理。
為了更好的解決上述問(wèn)題,NVIDIA 在新的 Blackwell Tensor Core 中支持了新的 Block-Scaled 類(lèi)型,原生支持 Microscaling Formats,如下圖 Table 1 所示,其支持 MXFP8、MXFP6、MXFP4:
如下圖所示,在 Tensor Core 計(jì)算時(shí),可以將數(shù)據(jù) A/B 及它們對(duì)應(yīng)的 Scaling Factor A/B 一起輸入,并全部在 Tensor Core 內(nèi)完成。
當(dāng)然,其對(duì)數(shù)據(jù)類(lèi)型也有一定的要求,如下圖所示:
四、MXFP8 預(yù)訓(xùn)練
4.1 轉(zhuǎn)換 FP32 到 MXFP8
在訓(xùn)練的 Forward 與 Backward 過(guò)程中,模型 Weight、Activation 及 Gradient Tensor 均由 FP32 量化到 MXFP8 格式。量化后的 MXFP8 Tensor 隨后存儲(chǔ)于硬件中并執(zhí)行運(yùn)算。作者首先闡述了轉(zhuǎn)換過(guò)程 Quantize_to_fp8(Vi/2X)。下文所述的量化方法統(tǒng)一適用于所有 MX 格式(包括 E4M3、E5M2、E2M3、E3M2 及 E2M1),僅 MX 數(shù)據(jù)類(lèi)型存在差異。
計(jì)算 X 值:通常情況下,Tensor 中各 Block 內(nèi)的大部分?jǐn)?shù)值會(huì)超出目標(biāo) MX 格式的可表示范圍,既可能低于最小可表示數(shù)(下溢),也可能高于最大可表示數(shù)(上溢)。為解決這一問(wèn)題,需將 Block 內(nèi)所有數(shù)值乘以一個(gè) Scale 因子,使絕大多數(shù)數(shù)值被調(diào)整至可表示范圍內(nèi)。
該 Scale 因子 X 的計(jì)算基于 32 個(gè)(MX 規(guī)范,k)高精度輸入值中的絕對(duì)最大值(amax),即amax = max(‖Vi‖); 1≤i≤32。其核心目標(biāo)是將輸入中的 amax 映射為 MX 格式中的最大可表示值。當(dāng)輸入數(shù)據(jù)包含無(wú)窮大(Infinity)或非數(shù)值(NaN)時(shí)需特殊處理:若某 Block 的 amax 為 0,則設(shè)定 X=-127,此時(shí) X 為 2-127,且該情況下所有 Qi’ 均置為 0。
根據(jù) OCP 規(guī)范,當(dāng) X 不為 Inf、NaN 或 0 時(shí),X 應(yīng)設(shè)定為不超過(guò) “amax 除以 MX 格式類(lèi)型最大可表示二次冪” 的最大二次冪。以 E4M3 類(lèi)型為例,由于 448 是其最大幅值,故 X 的計(jì)算公式為X=floor(log?(amax))/floor(log?(448))。值得注意的是,OCP 規(guī)范在此計(jì)算過(guò)程中忽略了該比率浮點(diǎn)數(shù)尾數(shù)部分的影響。
作者觀察到遵循 OCP 規(guī)范時(shí)存在精度下降現(xiàn)象。如下圖 Figure 2 所示為兩種 Token 規(guī)模(300B 和 1T)下訓(xùn)練的 843M 參數(shù)量 Transformer 模型的訓(xùn)練損失曲線(xiàn)。其采用兩種不同配置方案:
- cfg1:所有 Tensor(Weight W、Activation A、Gradient G)均采用 E4M3 格式。
- cfg2:Weight W、Activation A 采用 E4M3 格式,Gradient G 采用 E5M2 格式。
E5M2 格式相較于 E4M3 具有約 1.8 倍的 binades 優(yōu)勢(shì)。鑒于 Gradient 通常具有更大的動(dòng)態(tài)范圍,早期的工作 [2209.05433] FP8 Formats for Deep Learning [3] 主張采用 E5M2 格式進(jìn)行 Tensor 縮放。實(shí)驗(yàn)結(jié)果表明,在 cfg1 與 cfg2 中,使用 OCP 方法計(jì)算縮放因子均會(huì)導(dǎo)致訓(xùn)練發(fā)散(如下圖 Figure 2a)或相對(duì)于 BF16 基準(zhǔn)的損失差距擴(kuò)大(如下圖 Figure 2b)。
如下圖 Algorithm 1 概述了作者計(jì)算 Scale 因子的方法。其核心改進(jìn)在于:當(dāng)處理 amax 與 MX 格式最大可表示值 destmax 的比值指數(shù)時(shí),采用向正無(wú)窮方向的 round-up 策略(同時(shí)飽和至 UE8M0 格式的極值邊界)。這與 OCP 方案形成鮮明對(duì)比,后者實(shí)質(zhì)上是建議對(duì) Scale 值執(zhí)行 round-down 操作。由于高精度值 Vi 需通過(guò) Scale 因子 2X 進(jìn)行縮放,對(duì)分?jǐn)?shù)項(xiàng) (Vi/2X) 分母實(shí)施 round-up 操作,會(huì)更傾向于將 amax 映射至 destmax 以下;反之,OCP 方法則傾向于使 amax 超過(guò) destmax(后續(xù)必須通過(guò)截?cái)嗵幚硎蛊淇杀硎荆W髡咄茰y(cè) OCP 取整方法帶來(lái)的飽和效應(yīng)會(huì)影響模型精度。
如上圖 Figure 2 所示,采用提出的舍入方案后,Gradient 位寬配置為 E4M3 的 MXFP8(藍(lán)色曲線(xiàn))與 E5M2 的 MXFP8(紫色曲線(xiàn))在 300B 和 1T Token 的訓(xùn)練過(guò)程中,其損失曲線(xiàn)均與 BF16 完全重合。
FP32 數(shù)值到 MX 格式的量化過(guò)程:當(dāng)縮放因子 X 確定后,Tensor Vi 通過(guò)乘以 2X 進(jìn)行尺度變換,隨后量化至最接近的 FP8 可表示數(shù)值(即 Quantize_to_fp8())。該量化步驟采用“就近取偶(Round-to-nearest-ties-to-even,RN)”舍入法,且轉(zhuǎn)換過(guò)程具有飽和特性——若舍入結(jié)果超出 FP8 最大值或低于最小值,則將結(jié)果截取至相應(yīng)的極值。
這種轉(zhuǎn)換機(jī)制在低精度 LLM 預(yù)訓(xùn)練中的典型應(yīng)用場(chǎng)景是:矩陣乘積累加運(yùn)算(MMA)的輸出(通常以 FP32 格式存儲(chǔ))需要映射為 MXFP8 格式,相比存儲(chǔ) FP32 數(shù)值可顯著節(jié)省寫(xiě)入帶寬和存儲(chǔ)容量。模型后續(xù)運(yùn)算讀取 MXFP8 數(shù)值時(shí),相較加載 FP32 數(shù)據(jù)也能減少讀取帶寬消耗。此外,由于 Tensor Core 可直接處理 MX 格式輸入,低精度 MMA 操作不僅能降低能耗,還能獲得更高的計(jì)算吞吐量。
4.2 所有 Tensor 采用 E4M3
在 Blackwell 架構(gòu)中,F(xiàn)P8 浮點(diǎn)格式包含兩種變體:E4M3 與 E5M2。實(shí)驗(yàn)研究表明:
Weight 與 Activation 量化性能對(duì)比:采用 E4M3 格式量化 Weight 和 Activation 時(shí)展現(xiàn)出更優(yōu)的訓(xùn)練收斂性。如下圖 Figure 3a 所示(測(cè)試模型與 Figure 2b 相同,參數(shù)量 843M),當(dāng) Activation(紫色曲線(xiàn))或 Weight(藍(lán)色曲線(xiàn))采用 E5M2 格式時(shí),其損失函數(shù)收斂性顯著差于所有 Tensor 采用 E4M3 量化方案(橙色曲線(xiàn))。值得注意的是,僅 Gradient 采用 E5M2 時(shí)(黃色曲線(xiàn))仍能維持較好的收斂特性。
Gradient Tensor 量化分析:E4M3 格式在 Gradient 量化中能保持與 BF16 預(yù)訓(xùn)練相當(dāng)?shù)膿p失,這一優(yōu)勢(shì)在參數(shù)量 ≥2B 的模型中尤為顯著。如下圖 Figure 3c 展示了 8B LLM(1T Token 訓(xùn)練)的對(duì)比結(jié)果:E4M3 Gradient 量化(橙色曲線(xiàn))的最終損失值顯著低于 E5M2 方案(黃色曲線(xiàn)),且該差距隨訓(xùn)練 Token 數(shù)量增加而擴(kuò)大。這一現(xiàn)象揭示了模型參數(shù)量對(duì)數(shù)值格式選擇的敏感性,強(qiáng)調(diào)需在不同規(guī)模模型中系統(tǒng)評(píng)估格式的數(shù)值特性(PS:這也是為什么筆者一直提到之前很多文章只在小規(guī)模模型、數(shù)據(jù)量下做實(shí)驗(yàn)不夠有說(shuō)服力的原因)。
既往研究采用的 Tensor 級(jí)縮放 FP8 方案及 DeepSeek V3([2412.19437] DeepSeek-V3 Technical Report [4]) 提出的粗粒度 Block-Scaled 方案均默認(rèn)選用 E5M2 格式處理 Gradient Tensor(PS:論文這里表述有問(wèn)題,DeepSeek V3 中其實(shí)所有 Tensor 都已經(jīng)采用 E4M3 格式;此外,這里的粗粒度是相對(duì) 32 的 Block 大小而言,在 DeepSeek 中為了效率采用的是 128 或者 128x128 的 Block 大?。1狙芯堪l(fā)現(xiàn):當(dāng)采用細(xì)粒度縮放(32 元素 Block)時(shí),E4M3 格式的 17.8 個(gè) binades 可充分滿(mǎn)足動(dòng)態(tài)范圍需求。在滿(mǎn)足動(dòng)態(tài)范圍前提下,量化精度成為關(guān)鍵因素——E4M3 每個(gè)指數(shù)區(qū)間包含 8 個(gè)量化樣本,其采樣密度是 E5M2(4樣本/區(qū)間)的 2 倍。因此,提出的 MXFP8 預(yù)訓(xùn)練方案對(duì)所有三類(lèi) Tensor(Weight、Activation、Gradient)均采用 E4M3 數(shù)據(jù)類(lèi)型進(jìn)行量化。
MXFP8 實(shí)例化層級(jí)及訓(xùn)練流程:論文所有研究均采用基于語(yǔ)言的 Transformer 模型,未來(lái)工作將探索該方案在語(yǔ)音與視覺(jué)模型中的應(yīng)用。研究表明,量化策略建議:
- 將模型中所有 Transformer Block 的 QKV、Proj 以及 FFN 的 Up-proj 和 Down-proj 轉(zhuǎn)換為 MXFP8 格式。
- Self-Attention 中的批量矩陣乘法(BMM1:Query-Key 點(diǎn)積和 BMM2:Attention Score-Value 乘積)以及 Softmax、激活函數(shù)和殘差相加等運(yùn)算仍保持高精度計(jì)算。
- 輸入 Embedding 層和最終輸出 LM-head 同樣采用 BF16 或 FP16 格式。
如下圖 Figure 4 所示,這種配置能最可靠地維持與 BF16 預(yù)訓(xùn)練相當(dāng)?shù)木人?,論文所有?shí)驗(yàn)均遵循此準(zhǔn)則。在 MXFP 量化訓(xùn)練過(guò)程中,框架需為 Tensor(Weight、Activation 和 Gradient)保持兩個(gè)副本:每個(gè)副本沿點(diǎn)積歸約(dot-product reduction)軸(行與列)分別量化。Figure 4 展示了訓(xùn)練迭代中各 Tensor 在 Forward(FPROP)、Weight Gradient(WGRAD)和 Activation Gradient(DGRAD)計(jì)算中的使用方式。由于每個(gè) Tensor 需以原始和轉(zhuǎn)置兩種形態(tài)參與運(yùn)算,量化需沿行列兩個(gè)獨(dú)立軸向分別執(zhí)行。
當(dāng)前研究總結(jié):提出的 MX Scale 因子舍入方案解決了基于 OCP 方法導(dǎo)致的不收斂問(wèn)題,在 843M 參數(shù)模型上實(shí)現(xiàn)了 1T Token 訓(xùn)練下與 BF16 相當(dāng)?shù)木?。結(jié)合 Algorithm 1 中的 E4M3 格式及 Scale 因子計(jì)算方法,該方案可擴(kuò)展至 8B 參數(shù)模型(15T Token 訓(xùn)練)——作者聲稱(chēng),這是目前采用 MXFP 格式的最大規(guī)模 LLM 預(yù)訓(xùn)練案例。
如下圖 Figure 6 所示,W16A2.5 規(guī)模 MoE 模型,1T Token 預(yù)訓(xùn)練也能實(shí)現(xiàn)同樣的效果:
4.3 15T Token MXFP8 預(yù)訓(xùn)練結(jié)果
作者采用 Megatron-LM 框架預(yù)訓(xùn)練了一個(gè) 8B 參數(shù)的 Nemotron 模型。該模型包含 32 個(gè) Transformer Block,每個(gè) Block 32 個(gè) Attention Head,隱層維度為 4096,采用 GQA 且 Group 大小為 8,KV 通道數(shù)為 128,預(yù)訓(xùn)練階段序列長(zhǎng)度為 8192。共訓(xùn)練 15T Token,Batch Size 為 768。初始學(xué)習(xí)率設(shè)為 6×10??,并通過(guò) cosine decay 到 6×10??。如下圖 Table 2 為幾個(gè)模型的詳細(xì)配置:
采用分階段數(shù)據(jù)混合策略進(jìn)行訓(xùn)練:第一階段使用促進(jìn)數(shù)據(jù)多樣性的混合數(shù)據(jù)集,第二階段則轉(zhuǎn)向高質(zhì)量數(shù)據(jù)集(如維基百科),在訓(xùn)練進(jìn)度達(dá)到 60% 時(shí)切換至第二階段。此類(lèi)混合策略在其他大規(guī)模預(yù)訓(xùn)練框架中亦有應(yīng)用。
模型預(yù)訓(xùn)練在 3072 Hopper GPU 上完成(實(shí)驗(yàn)周期內(nèi)缺乏支持 MX 格式的 Bloackwell 硬件平臺(tái))。通過(guò)在 Hopper GPU 上模擬 MX 格式實(shí)現(xiàn):輸入矩陣乘法加速器(MMA)的 Tensor 先量化為 MX 格式,在執(zhí)行 BF16 MMA 運(yùn)算前轉(zhuǎn)換回 BF16 格式。為驗(yàn)證模擬方案的數(shù)值保真度,作者與 Blackwell 平臺(tái)上采用真實(shí) MXFP8 格式訓(xùn)練的 2B 參數(shù) LLM 進(jìn)行對(duì)比實(shí)驗(yàn),確認(rèn)二者輸出結(jié)果完全一致。
如下圖 Figure 5 展示了 8B 預(yù)訓(xùn)練模型的訓(xùn)練損失及任務(wù)級(jí)準(zhǔn)確率??梢钥闯?,兩組下游任務(wù)的評(píng)估分?jǐn)?shù):
- MMLU 上的 5-shot 分?jǐn)?shù)。
- 9 個(gè)通用 Reasoning 基準(zhǔn)(ARC-Challenge 與ARC-Easy、Race、PIQA、Winogrande、Hellaswag、OpenBookQA、Social IQA 和 Commonsense QA)上 1-shot 分?jǐn)?shù)的平均值。
主要結(jié)果如下:
- 采用 MXFP8 預(yù)訓(xùn)練時(shí),模型的驗(yàn)證困惑度與 BF16 預(yù)訓(xùn)練結(jié)果持平(Figure 5 左圖)。在整個(gè)預(yù)訓(xùn)練過(guò)程中,MXFP8 與 BF16 的驗(yàn)證困惑度差異始終小于 0.50%。
- Figure 5 中、右兩圖顯示了兩組下游任務(wù)的評(píng)估分?jǐn)?shù)。MXFP8 訓(xùn)練模型的得分與 BF16 訓(xùn)練模型完全匹配,證明 MXFP8 可作為 LLM 預(yù)訓(xùn)練的有效候選方案。
MXFP8 與 FP8 對(duì)比:除 MXFP8 和 BF16 外,F(xiàn)igure 5 還展示了傳統(tǒng) FP8 精度訓(xùn)練同模型的任務(wù)級(jí)分?jǐn)?shù)。FP8 方案采用軟件模擬的分塊縮放技術(shù),通過(guò)整體 Tensor 縮放使多數(shù) Tensor 值落入量化格式的可表示范圍。遵循 [2504.03624] Nemotron-H: A Family of Accurate and Efficient Hybrid Mamba-Transformer Models [5] 的 FP8 預(yù)訓(xùn)練設(shè)置建議:模型首尾 Transformer Block 保持 BF16 精度,其余 Block 的線(xiàn)性層量化為 FP8,該配置適用于 20T Token 規(guī)模的 8B 和 56B 參數(shù) LLM 預(yù)訓(xùn)練。但保留部分 BF16 層會(huì)影響端到端加速比,并增加預(yù)訓(xùn)練復(fù)雜性——需額外決策哪些層維持高精度。實(shí)驗(yàn)表明,MXFP8 在這兩組任務(wù)上無(wú)需任何 BF16 層即可達(dá)到與 FP8 相當(dāng)?shù)木取?/p>
MXFP8 與分塊 FP8 的對(duì)比:進(jìn)一步地,諸如 Deepseek-V3 等研究表明,在使用 FP8 時(shí)需要縮小 Block 規(guī)模。在此配置下,部分 Tensor 需采用 1x128 向量級(jí)縮放,而其他 Tensor 則需實(shí)施分塊(如 128x128)縮放,這增加了 GEMM Kernel 函數(shù)設(shè)計(jì)的復(fù)雜度。MXFP8 的原生支持則簡(jiǎn)化了這一過(guò)程——其細(xì)粒度縮放機(jī)制提供了更優(yōu)的數(shù)值魯棒性,同時(shí)規(guī)避了小 Block 尺寸與硬件速度之間的權(quán)衡問(wèn)題。
綜上所述,相比于 BF16 或 FP8 預(yù)訓(xùn)練,MXFP8 能保持同等精度。在 GB200 Blackwell 系統(tǒng)上,MXFP8 的吞吐量是 BF16 的 2 倍,這使得端到端 MXFP8 預(yù)訓(xùn)練速度超越 BF16 預(yù)訓(xùn)練。與 FP8 相比,MXFP8 方案還更加簡(jiǎn)便(所有層均可量化且縮放由硬件處理),同時(shí)保持同等或更優(yōu)的吞吐性能。
五、參考鏈接:
- [1] https://arxiv.org/abs/2506.08027
- [2] https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
- [3] https://arxiv.org/abs/2209.05433
- [4] https://arxiv.org/abs/2412.19437
- [5] https://arxiv.org/abs/2504.03624
本文轉(zhuǎn)載自??AI閑談??,作者:AI閑談
