偷偷摘套内射激情视频,久久精品99国产国产精,中文字幕无线乱码人妻,中文在线中文a,性爽19p

大模型SFT暗藏大陷阱?梯度累計(jì)bug造成大范圍影響 原創(chuàng)

發(fā)布于 2024-11-5 13:15
瀏覽
0收藏

在LLM的訓(xùn)練時(shí),由于顯存不足以支撐起大batch訓(xùn)練,通常大家都會(huì)采用一種策略:梯度累計(jì)(gradient accumulate)。這種方法允許模型在多個(gè)batch的梯度回傳累計(jì)并求均值之后,再更新一次權(quán)重。這樣做相當(dāng)于模擬了一個(gè)更大的批量大小,而實(shí)際上并沒有一次性處理那么多數(shù)據(jù)。這樣做的好處是,它可以減少內(nèi)存的使用,因?yàn)椴恍枰淮涡约虞d所有數(shù)據(jù)到GPU上,同時(shí)也可以享受等價(jià)大batch帶來的訓(xùn)練的穩(wěn)定性和模型的泛化能力。

大模型SFT暗藏大陷阱?梯度累計(jì)bug造成大范圍影響-AI.x社區(qū)

但是近期大家發(fā)現(xiàn)了一個(gè)bug:對(duì)于幾乎所有使用了梯度累積策略的庫(kù),包括Huggingface的一系列庫(kù),都暗藏了一個(gè)bug,這個(gè)bug尤其在LLM的后訓(xùn)練階段影響顯著:使用梯度累計(jì)并不一定等價(jià)于大batch訓(xùn)練,會(huì)有非常明顯的精度損失!

???https://github.com/huggingface/trl/issues/2175???

大模型SFT暗藏大陷阱?梯度累計(jì)bug造成大范圍影響-AI.x社區(qū)

如同上述issue描述的情況,圖中bs表示batch size即梯度大小, gas表示 gradient accumulate step即多少次梯度回傳累計(jì)后再更新一次模型權(quán)重。

對(duì)于LLM訓(xùn)練而言,不像圖像任務(wù)有batch norm的影響,理論上,梯度累計(jì)在應(yīng)等同于全批量訓(xùn)練,但實(shí)際發(fā)現(xiàn)loss并不匹配。研究者通過公式和實(shí)驗(yàn)證明,罪魁禍?zhǔn)资情_源庫(kù)中使用基于平均交叉熵loss求和后進(jìn)行梯度累計(jì)的實(shí)現(xiàn)導(dǎo)致了bug,這在輸出等長(zhǎng)的訓(xùn)練任務(wù)中并不影響(這也是為什么在CV任務(wù)和LLM預(yù)訓(xùn)練階段,梯度累計(jì)沒有發(fā)生明顯性能損失,因?yàn)檩敵鐾ǔJ堑乳L(zhǎng)的)。 梯度累積后,過度重視短輸出序列的loss,而忽略長(zhǎng)輸出序列的loss。

這個(gè)bug的數(shù)學(xué)推導(dǎo)也非常簡(jiǎn)單:

我們首先注意到交叉熵?fù)p失的計(jì)算方法如下:

大模型SFT暗藏大陷阱?梯度累計(jì)bug造成大范圍影響-AI.x社區(qū)

請(qǐng)注意,分母計(jì)算了未填充或未忽略(賦值為-100)的token的數(shù)量。首先,我們把它們?cè)O(shè)置為整個(gè)文檔的平均長(zhǎng)度,以簡(jiǎn)化我們的計(jì)算。

假設(shè)兩個(gè)batch的平均序列長(zhǎng)度不等長(zhǎng),一個(gè)是m1,1個(gè)是m2,對(duì)于full batch情況:

大模型SFT暗藏大陷阱?梯度累計(jì)bug造成大范圍影響-AI.x社區(qū)

對(duì)于梯度累計(jì)情況:

大模型SFT暗藏大陷阱?梯度累計(jì)bug造成大范圍影響-AI.x社區(qū)

明顯看出在m1和m2不相等時(shí),兩者是明顯不等價(jià)的。尤其是在其中一個(gè)序列長(zhǎng)度明顯更長(zhǎng),另一個(gè)序列長(zhǎng)度很短時(shí),問題更加嚴(yán)重:比如m1=10,m2=1000時(shí),會(huì)發(fā)現(xiàn)l2的loss大小會(huì)被壓縮,而l1的loss大小相對(duì)于full batch情況下會(huì)被嚴(yán)重放大。

這是因?yàn)椴煌琤atch的文本長(zhǎng)度不同,導(dǎo)致的問題。在梯度累積中,我們需要將每個(gè)小批量梯度累積器按梯度累積步驟的數(shù)量進(jìn)行縮放,以便我們得到期望的結(jié)果。

修復(fù)分母問題后重新實(shí)驗(yàn):

大模型SFT暗藏大陷阱?梯度累計(jì)bug造成大范圍影響-AI.x社區(qū)

現(xiàn)在確實(shí)等價(jià)了,所有的訓(xùn)練損失曲線都匹配上了!分母就是罪魁禍?zhǔn)?!這意味著簡(jiǎn)單地對(duì)每個(gè)梯度累積步驟進(jìn)行平均是錯(cuò)誤的,相反,我們必須事先推導(dǎo)出分母。

目前,這個(gè)bug已經(jīng)引起了廣泛關(guān)注,不少開源庫(kù)包括huggingface系列正在針對(duì)這個(gè)問題進(jìn)行修復(fù)。如果近期遇到SFT效果不佳的問題,可以關(guān)注是否踩到了這個(gè)坑,短期不要使用梯度累計(jì),或在修復(fù)后及時(shí)更新,使用新版梯度累計(jì)算法。


本文轉(zhuǎn)載自公眾號(hào)思源數(shù)據(jù)科學(xué) 作者:思源Source

原文鏈接:??https://mp.weixin.qq.com/s/Za62RV9BDrbuoMERzodCUA??


?著作權(quán)歸作者所有,如需轉(zhuǎn)載,請(qǐng)注明出處,否則將追究法律責(zé)任
標(biāo)簽
收藏
回復(fù)
舉報(bào)
回復(fù)
相關(guān)推薦