1行代碼改進大模型訓(xùn)練,Llama訓(xùn)練速度提升至1.47倍,全華人團隊出品
只要改一行代碼,就能讓大模型訓(xùn)練效率提升至1.47倍。
擁有得州大學(xué)奧斯汀分校背景四名華人學(xué)者,提出了大模型訓(xùn)練優(yōu)化器Cautious Optimizers。
在提速的同時,Cautious能夠保證訓(xùn)練效果不出現(xiàn)損失,而且語言和視覺模型都適用。
該優(yōu)化器以哈密頓量和下降動力學(xué)為理論基礎(chǔ),在加速的同時不影響收斂特性。
作者在600M到1B不同參數(shù)規(guī)模的Llama模型上進行了試驗,獲得了最高47%的加速率。
該研究相關(guān)代碼已經(jīng)開源,在GitHub上有使用方法的詳細講解。
一行代碼改進大模型訓(xùn)練
Cautious Optimizers在PyTorch當中增加的一行代碼,核心思路是引入實現(xiàn)一種掩蔽機制,從而避免參數(shù)更新的方向與當前梯度方向相悖。
因為這兩個方向一旦不一致,就有可能導(dǎo)致?lián)p失函數(shù)暫時增加,造成收斂速度的減緩。
不過作者并未在方向不一致的來源問題上過度糾結(jié),而是引入了一種判斷機制,在參數(shù)更新之前增加一步計算,從而過濾掉方向不一致的情形。
這也正是上面代碼的直接作用。
△GD:梯度下降,GDM:帶動量的梯度下降,C-GDM:本項目
具體來說,加入的兩行代會對u和g兩個向量求內(nèi)積,u向量對應(yīng)優(yōu)化器給出的參數(shù)更新方向,而g向量對應(yīng)當前時刻的梯度方向。
作者設(shè)計了一個對齊掩碼函數(shù)?,當u和g的內(nèi)積小于0時(即方向不一致),?的輸出為0向量;當內(nèi)積大于等于0時,?的輸出為全1向量。
而一旦?為零向量時,w_t計算式中含u的項也會變?yōu)榱阆蛄?,?dǎo)致此項更新被跳過。
這樣就可以判斷參數(shù)更新和梯度方向是否一致,如果不一致則不會用于參數(shù)更新,避免了訓(xùn)練過程中損失函數(shù)的回升。
訓(xùn)練效率提升47%
為了評估Cautious Optimizers的具體效果,作者分別在語言模型Llama和視覺模型MAE上進行了試驗。
作者選取了60M、100M、350M和1B四種參數(shù)規(guī)模的Llama模型,在C4語料庫上進行預(yù)訓(xùn)練。
優(yōu)化器選用了AdamW和Lion,以及它們對應(yīng)的Cautious版本:C-AdamW和C-Lion,每個實驗中進行1萬步迭代。
結(jié)果C-AdamW和C-Lion在所有規(guī)模上都表現(xiàn)出明顯的收斂加速效果。
尤其是在1B規(guī)模上,相比原版的AdamW和Lion,它們的樣本效率分別提高了47%和28%,這表明Cautious Optimizer能有效減少訓(xùn)練震蕩,使收斂更平穩(wěn)高效。
并且,Cautious Optimizer在所有情況下都取得了更低的困惑度,印證了其出色的泛化性能。
為了評估模型的實際效果,研究者在語句匹配、文本蘊含、情感分類等6個GLUE下游任務(wù)上測試了AdamW和C-AdamW優(yōu)化后1B模型的表現(xiàn),
結(jié)果表明,C-AdamW的平均得分比AdamW高出2%,在大多數(shù)任務(wù)上都取得了進步,說明Cautious跳過部分參數(shù)更新的方式不會引起模型性能下降。
對于視覺模型,作者以ViT為骨干網(wǎng)絡(luò),在ImageNet-1K數(shù)據(jù)集上預(yù)訓(xùn)練了MAE模型。
由于視覺任務(wù)的特殊性,訓(xùn)練過程采用了隨機遮擋圖像塊并重建的范式,因此優(yōu)化目標是最小化重建誤差,而非通常的分類損失。
作者對比了AdamW和C-AdamW的表現(xiàn),即訓(xùn)練50輪后的最終重建誤差,結(jié)果C-AdamW的誤差為0.5926,低于AdamW的0.6085。
一作曾在一周內(nèi)復(fù)刻o1
本項目是由四名華人學(xué)者共同打造的。
第一作者Kaizhao Liang,是AI推理加速服務(wù)商SambaNova公司的一名高級ML工程師。
在o1模型發(fā)布一周內(nèi),該公司就推出了一個類似o1模型思考過程的開源平替,主要作者正是Liang。
其他三名作者是得州大學(xué)奧斯汀分校CS助理教授Qiang Liu,以及他的兩名博士生,Lizhang Chen和Bo Liu。
此外,Liang的人工智能碩士學(xué)位也是從該校獲得。
論文地址:https://arxiv.org/abs/2411.16085
GitHub:https://github.com/kyleliang919/C-Optim