阿里云通義大模型新技術(shù):MoE模型訓(xùn)練專家平衡的關(guān)鍵細(xì)節(jié)
本周,在阿里云通義千問 Qwen 團(tuán)隊(duì)提交的一篇論文中,研究人員發(fā)現(xiàn)了目前最熱門的 MoE(混合專家模型)訓(xùn)練中存在的一個普遍關(guān)鍵問題,并提出一種全新的方法——通過輕量的通信將局部均衡放松為全局均衡,使得 MoE 模型的性能和專家特異性都得到了顯著的提升。

- 論文:《Demons in the Detail: On Implementing Load Balancing Loss for Training Specialized Mixture-of-Expert Models》
- 論文鏈接:https://arxiv.org/abs/2501.11873

MoE 模型訓(xùn)練中的關(guān)鍵問題
混合專家模型(MoEs)通過路由機(jī)制動態(tài)并稀疏地激活模型參數(shù),使得能高效地增大模型參數(shù)規(guī)模?;?TopK 機(jī)制的稀疏激活會在訓(xùn)練中會遇到專家激活不均衡的問題:少數(shù)被頻繁選擇的專家會被優(yōu)化得更多,進(jìn)一步使得這些專家被更頻繁地選擇,最終導(dǎo)致只選擇少數(shù)專家,造成剩余專家的冗余。因此,MoE 在訓(xùn)練中需要引入額外輔助的負(fù)載均衡損失(load balance loss,LBL)來鼓勵專家的選擇趨于均衡。
目前主流 MoE 訓(xùn)練框架中實(shí)現(xiàn)的 LBL 的優(yōu)化目標(biāo)是局部(micro-batch)的負(fù)載均衡,這使得模型需要將一個micro-batch的輸入都均勻分配給不同的專家。然而,一個micro-batch的輸入往往只來自個別領(lǐng)域,局部負(fù)載均衡會讓模型將每個領(lǐng)域的輸入都均勻分配。這種均勻分配會阻礙某些專家更多處理特定領(lǐng)域的數(shù)據(jù),也即阻礙專家出現(xiàn)領(lǐng)域?qū)哟蔚姆只卣?。我們發(fā)現(xiàn),將局部的負(fù)載均衡放松到全局的負(fù)載均衡,能顯著增強(qiáng)專家的特異化并提高模型性能。
背景
混合專家(Mixture-of-Experts,MoE)是一種高效的在訓(xùn)練時擴(kuò)展模型參數(shù)規(guī)模的技術(shù)。通常,一個MoE層由一個路由器(通常是一個線性層)和一組專家組成(對于Transformer的模型,每個專家是一個前饋神經(jīng)網(wǎng)絡(luò))。給定一個輸入,只有部分專家會被激活,然后它們的輸出會根據(jù)路由器分配的權(quán)重進(jìn)行聚合。具體來說:

負(fù)載均衡損失
負(fù)載均衡損失是訓(xùn)練 MoE 網(wǎng)絡(luò)中的一種重要正則化技術(shù),其核心思想是鼓勵所有專家的均衡激活。它可以通過以下公式計(jì)算:

其中, 是專家 的激活頻率, 是分配給專家 的平均路由分?jǐn)?shù)。
然而,大多數(shù)現(xiàn)有的MoE訓(xùn)練框架(例如Megatron-core)實(shí)現(xiàn)的是局部(micro-batch)層次的均衡,這意味著在每個 micro-batch 內(nèi)計(jì)算 LBL ,然后在全局(global-batch)層次上進(jìn)行平均,即:

其中 為 micro-batch 數(shù), 是在第 個 micro-batch 上計(jì)算的負(fù)載均衡損失, 為在第 個 micro-batch 上統(tǒng)計(jì)出的激活頻率和路由分?jǐn)?shù)。
我們關(guān)注的關(guān)鍵點(diǎn)是,如果一個 micro-batch 中的數(shù)據(jù)不夠多樣化,這種實(shí)現(xiàn)方式可能會阻礙專家的特異化。例如,假設(shè)一個 micro-batch 中只包含代碼數(shù)據(jù),上述負(fù)載均衡損失仍然會推動路由器將這些代碼輸入均勻分配給所有專家。而理想狀況下,處理代碼數(shù)據(jù)的專家網(wǎng)絡(luò)應(yīng)該對代碼數(shù)據(jù)有更高的激活頻率。在訓(xùn)練基于 MoE 的大型語言模型時,這種情況更常見:一個較小的 micro-batch (通常為 1)中的數(shù)據(jù)通常來自同一領(lǐng)域。這在一定程度上解釋了為什么當(dāng)前大多數(shù)基于 MoE 的大語言模型中都沒有觀察到明顯的領(lǐng)域?qū)哟蔚膶<姨禺惢?/span>
這一缺點(diǎn)促使我們將當(dāng)前局部均衡的方法想辦法擴(kuò)展到全局(global-batch)均衡。
從局部均衡到全局均衡
得得益于 LBL 計(jì)算的格式,我們可以通過通信不同節(jié)點(diǎn)的 來將局部 轉(zhuǎn)化為全局的 :1)在所有 micro-batch 之間同步專家選擇頻率 ;2)在每個GPU上計(jì)算負(fù)載均衡損失;3)在所有 micro-batch 之間聚合損失。具體來說:

其中 是對全局統(tǒng)計(jì)的激活頻率和門控分?jǐn)?shù),第一個等式為 的計(jì)算方式,第二個等式為全局路由分?jǐn)?shù)可以由局部路由分?jǐn)?shù)平均而來,第三個等式表示用全局激活頻率參與局部計(jì)算后再平均聚合等價于全局均衡損失。因?yàn)?nbsp; 只是一個專家數(shù)大小的向量,即使是在全局通信的情況下也不會帶來明顯的開銷。此外由于 LBL 的計(jì)算與模型其它部分的計(jì)算相對獨(dú)立,還可以用計(jì)算掩蓋等策略進(jìn)一步消除同步 的通信開銷。
此外,對于需要梯度積累的情景,我們還提出了緩存機(jī)制來累積各個積累步統(tǒng)計(jì)的專家激活頻率,使得計(jì)算節(jié)點(diǎn)較少、只進(jìn)行一次通信達(dá)到的均衡范圍有限的情況下,也能逐漸近似全局統(tǒng)計(jì)的激活頻率。
擴(kuò)大均衡的范圍帶來穩(wěn)定的提升
我們在三種參數(shù)規(guī)模(3.4B 激活 0.6B, 15B 激活 2.54B,43B 激活 6.6B)下分別訓(xùn)練了 120B 和 400B tokens,對比了不同的均衡范圍(Balance BSZ)對模型性能的影響。所有模型都使用了細(xì)粒度專家、共享專家及 dropless 策略(專家不會拋棄超過容量的tokens)??梢钥吹?,將均衡范圍從一般框架實(shí)現(xiàn)的 4,8 或者 16 增大到 128 以上后模型在 Benchmark 指標(biāo)和 PPL 都有明顯提升。

我們在 3.4B 激活 0.6B 的模型訓(xùn)練 400B tokens 到設(shè)置上進(jìn)一步對比了模型效果隨著均衡范圍的變化,可以看到 balance BSZ 從 2 到 128 模型的 PPL 在快速降低,在 128 后逐漸飽和。目前主流 MoE 框架中即使是進(jìn)行了機(jī)內(nèi)通信,對于較大的模型 balance BSZ 也一般在 8 到 16 的,這進(jìn)一步體現(xiàn)了我們通信方法的意義。

分析實(shí)驗(yàn)
假設(shè)驗(yàn)證
前文提到,這篇工作的出發(fā)點(diǎn)是在一個 micro-batch 中,數(shù)據(jù)的來源較為單一的,進(jìn)而導(dǎo)致 MoE 模型需要將類似來源的數(shù)據(jù)均勻分配到所有expert上,我們改進(jìn)了這一點(diǎn)進(jìn)而得到了提升。
然而,我們也可以假設(shè) global batch 是因?yàn)槭褂昧烁嗟?token 來統(tǒng)計(jì) expert 激活頻率進(jìn)而減少了方差,使得負(fù)載均衡損失更加穩(wěn)定,進(jìn)而提升訓(xùn)練洗哦啊過。位了更加嚴(yán)謹(jǐn)?shù)貙Ρ冗@兩種假設(shè),我們引入了一種對比的實(shí)驗(yàn)設(shè)置:Shffuled batch balance, 即我們從global batch中隨機(jī)抽取一個子集(這個子集的大小等于micro batch的大?。┙y(tǒng)計(jì)專家激活頻率,進(jìn)而計(jì)算負(fù)載均衡損失。Shuffled batch balance 和 micro-batch balance擁有相同的token數(shù)目,和 global-batch balance擁有相同的token分布。

我們發(fā)現(xiàn),shuffled batch balance 和 global batch balance 的表現(xiàn)幾乎一致,都顯著好于 micro batch balance。說明,引入 global-batch 獲得提升的首要原因是在一個更加通用、多樣的 token 集合上計(jì)算損失。進(jìn)而驗(yàn)證了我們的出發(fā)點(diǎn)和假設(shè)。
添加少量局部均衡損失
能提高模型效率
只使用全局均衡會導(dǎo)致局部均衡狀況有所降低,這會一定程度影響 MoE 的計(jì)算效率。我們進(jìn)一步實(shí)驗(yàn)了在主要使用全局均衡的情況下,在訓(xùn)練過程中添加局部均衡(默認(rèn)實(shí)現(xiàn)的 LBL,損失權(quán)重為全局 LBL 的 1%)限制對于模型性能和效率的影響??梢钥吹?,添加局部均衡能提升模型的速度(每個更新步耗時從 1.64秒提升到1.59秒),同時模型的效果也幾乎不受影響。

同期相關(guān)工作以及討論
已有工作 GRIN 也提出了 Global Load Balance Loss Adaptations,然而更多將這一均衡方法作為訓(xùn)練框架只使用張量并行、不使用專家并行的優(yōu)勢。GRIN 中并沒有從 specialization 或是對模型 performance 影響等方面討論使用 Global Load Balance 的動機(jī),也沒有展示單一使用 Global Load Balance 的影響。
Wang et al. 提出在基于MoE的大語言模型訓(xùn)練中,負(fù)載均衡損失和語言模型損失如同杠桿一樣需要權(quán)衡,因?yàn)閮烧叩膬?yōu)化目標(biāo)并不一致。因此,他們提出了一種基于專家選擇頻率更新的偏差項(xiàng)(bais term),在不改變路由分?jǐn)?shù)的情況下平衡專家選擇,從而去掉了用來輔助訓(xùn)練的負(fù)載均衡損失(auxiliary-loss free)?;趯<疫x擇頻率更新的偏置項(xiàng),以在不改變路由評分的情況下平衡專家選擇。但是,他們沒有比較該方法在專家選擇頻率是根據(jù) micro-batch 計(jì)算和根據(jù) global-batch 計(jì)算時的性能差異。
這項(xiàng)工作也被應(yīng)用到 deepseek-v3 的訓(xùn)練中。deepseek-v3 的技術(shù)報(bào)告(同期工作)中強(qiáng)調(diào)了這項(xiàng)技術(shù)的專家選擇頻率是基于 global-batch 進(jìn)行計(jì)算,并在小規(guī)模上討論了基于global batch 使用 LBL 的結(jié)果,也發(fā)現(xiàn)這兩種方法結(jié)果相似。
而我們的工作不僅在大規(guī)模上系統(tǒng)驗(yàn)證了這種方法的有效性,還詳細(xì)析了均衡范圍對性能的影響,并消融證明了 global-batch 是通過納入更多樣化的領(lǐng)域信息從而顯著提性能。
結(jié)論
我們回顧了目前 MoE 訓(xùn)練框架中均衡損失,發(fā)現(xiàn)目前的實(shí)現(xiàn)方式會將所有來自相同領(lǐng)域的局部輸入都均勻分配,限制了專家的分化。通過輕量的通信將局部均衡放松為全局均衡,MoE 模型的性能和專家特異性都得到了顯著的提升。我們認(rèn)為這一進(jìn)展解決了現(xiàn)有MoE訓(xùn)練中的一個關(guān)鍵問題,為MoE模型的優(yōu)化提供了新的視角,并有助于構(gòu)建更加可解釋的模型。盡管我們的實(shí)驗(yàn)主要集中在基于語言的任務(wù)上,我們希望我們的工作能夠?yàn)樵诓煌I(lǐng)域訓(xùn)練更大規(guī)模、更有效的 MoE 模型提供幫助。


































