低內存占用也能實現(xiàn)滿血訓練?!北理北大港中文MMLab推出Fira訓練框架
內存占用小,訓練表現(xiàn)也要好……大模型訓練成功實現(xiàn)二者兼得。
來自北理、北大和港中文MMLab的研究團隊提出了一種滿足低秩約束的大模型全秩訓練框架——Fira,成功打破了傳統(tǒng)低秩方法中內存占用與訓練表現(xiàn)的“非此即彼”僵局。

展開來說——
為了突破內存瓶頸,許多低秩訓練方法應運而生,如LoRA(分解參數(shù)矩陣)和GaLore(分解梯度矩陣)。
△圖1:從宏觀層面分析三種內存高效低秩訓練方法
然而,如上圖所示,LoRA將訓練局限于參數(shù)的低秩子空間,降低了模型的表征能力,難以實現(xiàn)預訓練;GaLore將訓練局限于梯度的低秩子空間,造成了子空間外梯度的信息損失。
相較于全秩訓練,這兩種方法由于施加了低秩約束,會導致訓練表現(xiàn)有所下降。
但是,若提高秩值,則會相應地增加內存占用。
因此,在實際應用中,它們需要在確保訓練表現(xiàn)與降低內存消耗之間找到一個恰當?shù)钠胶恻c。
這引發(fā)了一個核心問題:
能否在維持低秩約束以確保內存高效的同時,實現(xiàn)全秩參數(shù)、全秩梯度的訓練以提升表現(xiàn)?
Fira即為最新答案,它有三大亮點:
- 即插即用:Fira簡單易用,其核心實現(xiàn)僅涉及兩行關鍵公式,現(xiàn)已封裝進Python庫,可直接融入現(xiàn)有的大模型訓練流程中,替換原有優(yōu)化器。代碼示例如下:
from fira import FiraAdamW, divide_params
param_groups = divide_params(model, target_modules_list = [“Linear”], rank=8)
optimizer = FiraAdamW(param_groups, lr=learning_rate)- 雙贏解決方案:在維持低秩約束的前提下,F(xiàn)ira實現(xiàn)了大模型的全秩訓練,打破了內存占用與訓練表現(xiàn)的取舍難題。與此同時,區(qū)別于系統(tǒng)方法(如梯度檢查點),F(xiàn)ira不以時間換內存;
- 實驗驗證:Fira在多種規(guī)模的模型(60M至7B參數(shù))以及預訓練和微調任務中均展現(xiàn)出卓越性能,優(yōu)于現(xiàn)有的LoRA和GaLore,甚至能達到或超越全秩訓練的效果。
打造Fira訓練框架
Fira訓練框架由兩部分組成:
1) 基于梯度模長的縮放策略:利用了團隊在大模型低秩和全秩訓練中發(fā)現(xiàn)的共通點——自適應優(yōu)化器對原始梯度的修正效應,實現(xiàn)了低秩約束下的全秩訓練。
2) 梯度模長限制器,通過限制梯度模長的相對增長比例,解決了大模型訓練中常出現(xiàn)的損失尖峰問題。
背景動機
大模型訓練常常面臨顯著的內存瓶頸,尤其是其中的優(yōu)化器狀態(tài)。
舉例來說,使用Adam優(yōu)化器從頭預訓練一個LLaMA 7B模型(batchsize為1,精度為BF16)可能需要至少58GB內存。
其中14GB用于加載參數(shù),14GB用于儲存梯度,28GB用于儲存優(yōu)化器狀態(tài),剩下2GB用于儲存激活值。
在這之中,優(yōu)化器狀態(tài)所占內存甚至要大于參數(shù)本身。
因此,使用低秩方法來減少這一部分內存,實現(xiàn)大模型的內存高效訓練十分重要。
而在現(xiàn)有的低秩方法中,LoRA通過分解參數(shù)矩陣,使用低秩適配器來減少內存占用;Galore通過分解梯度矩陣,在自適應優(yōu)化器中儲存低秩梯度來減少內存占用。
鑒于使用LoRA低秩適配器方法來實現(xiàn)全參數(shù)訓練的困難性,團隊選擇拓展Galore的梯度投影方法來實現(xiàn)全秩訓練。
在Galore中,全秩梯度G?? ? ?mxn,會被投影矩陣P?? ? ?mxr分解成兩項低秩梯度P??R??和(G??—P??R??),其中
。
為減少像Adam這樣的自適應優(yōu)化器在內存中對應的狀態(tài)占用,Galore僅在優(yōu)化器核心??中保留低秩梯度R??,而非全秩梯度G??。
而另一項梯度(G??—P??R??),則會因為缺少對應的優(yōu)化器狀態(tài),被Galore直接丟棄,從而造成嚴重的信息損失。
這也解釋了,為什么Galore的性能會在rank值減小時,顯著衰減。

△圖2:Fira與Galore及其變體的訓練損失對比
為了彌補上述信息損失,最直觀的方法是直接加上這一部分梯度(G??—P??R??):

其中,W是參數(shù)矩陣, ??是學習率。
然而,如圖所示,使用這種方法(Galore-add)不僅未能帶來性能提升,反而可能導致訓練過程更加不穩(wěn)定,且結果更差。
分析原因可歸結于這一部分的梯度缺乏優(yōu)化器狀態(tài),直接使用會退化為單純的SGD算法,并且可能與前面使用的Adam優(yōu)化器的梯度不匹配,導致效果不佳。
基于梯度模長的縮放策略
為了解決上述挑戰(zhàn),團隊提出了scaling factor概念,來描述Adam這樣的自適應優(yōu)化器對原始梯度的修正效應,并揭示了它在大模型的低秩訓練和全秩訓練之間的相似性。

其中,?? 就是scaling factor,代表經(jīng)過優(yōu)化器修正過的梯度與原始梯度的模長比例。
如下圖,如果根據(jù)scaling factor的平均值對參數(shù)矩陣進行排序,可以發(fā)現(xiàn)低秩和全秩之間的排序非常相似。

△圖3:scaling factor在大模型低秩和全秩訓練間的相似性
基于這個觀察,團隊就嘗試在矩陣層面用低秩梯度R??的scaling factor,作為全秩梯度G??的scaling factor的替代,從而近似地修正(G??—P??R??),彌補其缺少的優(yōu)化器狀態(tài):

這樣團隊就在低秩約束下成功實現(xiàn)了全秩訓練。
進一步來說,剛才是從矩陣層面來考慮scaling factor。
順理成章地,團隊可以從更細粒度的角度——列的層面,來考慮scaling factor,實現(xiàn)更加精細地修正。

其中R??,:,?? 是低秩梯度R??的第i列,
是scaling factor的第i項。
梯度模長限制器
在訓練過程中,梯度常常會突然增大,導致?lián)p失函數(shù)出現(xiàn)尖峰,從而影響訓練的表現(xiàn)。
經(jīng)過分析,可能原因是Galore在切換投影矩陣時存在不穩(wěn)定性,以及維持(G??—P??R??)這種原始梯度的方向的方式,無法像Adam這樣的自適應算法,有效應對大模型訓練中存在的陡峭損失景觀。
△圖4:3種Fira變體的訓練損失與梯度模長
然而,常見的梯度裁剪方法(如圖中的Fira-gradient-clipping)由于采用絕對裁剪,難以適應不同參數(shù)矩陣間梯度的較大差異,從而可能導致次優(yōu)的訓練結果。
為此,團隊提出了一種新的梯度模長限制器,它通過限制梯度模長的相對增長比例,來更好地適應不同梯度的變化:

其中??是比例增長的上限,S??=????(R??)(G??—P??R??)是原始梯度(G??—P??R??)修正后的結果。
通過提出的控制梯度相對增長比例的方法,能夠將梯度的驟然增大轉化為平緩的上升,從而有效穩(wěn)定訓練過程。
如圖2和圖3所示,團隊的限制器成功避免了損失函數(shù)的尖峰情況,并顯著提升了訓練表現(xiàn)。
實驗結果
如下表所示,在預訓練任務中,F(xiàn)ira在保持內存高效的前提下,驗證集困惑度(↓)顯著超過各類基線方法,甚至超越全秩方法。
具體來說,在預訓練LLaMA 1B模型時,F(xiàn)ira節(jié)約了61.1%優(yōu)化器狀態(tài)所占內存,并且取得了比全秩訓練更加好的結果。
△使用C4數(shù)據(jù)集預訓練不同大小的LLaMA模型驗證集困惑度(↓)對比
在預訓練LLaMA 7B模型時,F(xiàn)ira在使用了比Galore小8倍的秩rank的情況下,訓練表現(xiàn)遠超Galore。
這展現(xiàn)了Fira在大規(guī)模大模型上的有效性,以及相較Galore更高的內存減少能力。

△使用C4數(shù)據(jù)集預訓練LLaMA 7B的驗證集困惑度(↓)對比
在八個常識推理數(shù)據(jù)集微調LLaMA 7B的任務中,相較其他基線方法,F(xiàn)ira在一半的數(shù)據(jù)集下表現(xiàn)最好,平均準確率最高的同時實現(xiàn)了內存高效。

△在八個常識推理數(shù)據(jù)集微調LLaMA 7B準確率對比
另外,消融實驗也顯示了:
- Fira-w.o.-scaling說明了Fira使用基于梯度模長的縮放策略的有效性;
- Fira-matrix說明了從更細粒度的列級別,而不是矩陣級別,考慮scaling factor的有效性;
- Fira-w.o.-limiter說明了Fira中梯度模長限制器的有效性;
- Fira-gradient-clipping說明了梯度裁剪可能無法完全解決損失尖峰問題,導致結果次優(yōu)。

△消融實驗
與GaLore相比,F(xiàn)ira的表現(xiàn)幾乎不受秩rank值減少的影響。
在低秩的情況下(rank=16, rank=4),F(xiàn)ira仍然能與全秩訓練相當,相較Galore更加內存高效。

△不同rank下的預訓練驗證集困惑度(↓)
最后,團隊在不同模型大小,以及低秩和全秩條件下,訓練10,000步,并對得到的矩陣和列級別上Scaling factor做平均。
接著,使用了斯皮爾曼(Spearman)和肯德爾(Kendall)相關系數(shù)分析了Scaling factor在矩陣和列級別上大小順序的相關性。
其中,Coefficient中1代表完全正相關,-1代表完全負相關,而P-value越小越好(通常小于0.05為顯著)。
在所有規(guī)模的LLaMA模型中,Scaling factor在矩陣和列的級別上都表現(xiàn)出很強的正相關關系,并且所有的P-value小于0.05,非常顯著,為Fira中基于梯度模長的縮放策略提供了堅實的實驗基礎。

△矩陣和列級別上的Scaling factor低秩與全秩相似性分析
更多細節(jié)歡迎查閱原論文。
論文鏈接:https://arxiv.org/abs/2410.01623
代碼倉庫:https://github.com/xichen-fy/Fira


































