人民大學(xué)&字節(jié)Seed:利用μP實(shí)現(xiàn)Diffusion Transformers高效擴(kuò)展
本文由中國(guó)人民大學(xué)高瓴人工智能學(xué)院李崇軒團(tuán)隊(duì)和字節(jié)跳動(dòng)Seed團(tuán)隊(duì)共同完成。第一作者鄭晨宇是中國(guó)人民大學(xué)高瓴人工智能學(xué)院二年級(jí)博士生,主要研究方向?yàn)榛A(chǔ)模型的優(yōu)化、泛化和可擴(kuò)展性理論,導(dǎo)師為李崇軒副教授,論文為其在字節(jié)跳動(dòng)Seed實(shí)習(xí)期間完成。第二作者張新雨是字節(jié)跳動(dòng)研究員,主要研究方向?yàn)橐曈X(jué)生成模型。李崇軒副教授為唯一通訊作者。
近年來(lái),diffusion Transformers已經(jīng)成為了現(xiàn)代視覺(jué)生成模型的主干網(wǎng)絡(luò)。隨著數(shù)據(jù)量和任務(wù)復(fù)雜度的進(jìn)一步增加,diffusion Transformers的規(guī)模也在快速增長(zhǎng)。然而在模型進(jìn)一步擴(kuò)大的過(guò)程中,如何調(diào)得較好的超參(如學(xué)習(xí)率)已經(jīng)成為了一個(gè)巨大的問(wèn)題,阻礙了大規(guī)模diffusion Transformers釋放其全部的潛能。
為此,人大高瓴李崇軒團(tuán)隊(duì)和字節(jié)跳動(dòng)Seed團(tuán)隊(duì)的研究員引入了大語(yǔ)言模型訓(xùn)練中的μP理論,并將其擴(kuò)展到diffusion Transformers的訓(xùn)練中。μP通過(guò)調(diào)整網(wǎng)絡(luò)不同模塊的初始化和學(xué)習(xí)率,實(shí)現(xiàn)不同大小diffusion Transformers共享最優(yōu)的超參,使得小模型上搜到的超參可以直接遷移到最終大模型上進(jìn)行訓(xùn)練,從而極大地減小了超參搜索的耗費(fèi)。
團(tuán)隊(duì)在DiT,PixArt和MMDiT(Stable Diffusion的基座)上進(jìn)行了系統(tǒng)的大規(guī)模實(shí)驗(yàn)驗(yàn)證。在MMDiT的實(shí)驗(yàn)中,0.18B小模型上搜得的超參成功被用在18B大模型的訓(xùn)練中,并擊敗了人工專家的手調(diào)基線。其中,小模型超參搜索的計(jì)算量(FLOPs)僅是專家手調(diào)的3%左右。
團(tuán)隊(duì)已在近期開(kāi)放在線論文,并開(kāi)源代碼。
- 論文鏈接:https://arxiv.org/abs/2505.15270
- 代碼倉(cāng)庫(kù):https://github.com/ML-GSAI/Scaling-Diffusion-Transformers-muP
μP的背景和問(wèn)題
μP全稱為最大更新參數(shù)化(Maximal Update Parametrization),是Tensor Program無(wú)窮寬網(wǎng)絡(luò)理論系列中的里程碑之作,相關(guān)結(jié)果已被理論證明適用于標(biāo)準(zhǔn)的Transformer架構(gòu)。μP的算法實(shí)現(xiàn)簡(jiǎn)潔,對(duì)于應(yīng)用最為廣泛的AdamW優(yōu)化器而言,μP只需要調(diào)整隱藏層權(quán)重的學(xué)習(xí)率,和輸出層權(quán)重的系數(shù)以及初始化。μP在實(shí)際中被廣泛發(fā)現(xiàn)能夠?qū)崿F(xiàn)不同大小的標(biāo)準(zhǔn)Transformer共享最優(yōu)的超參,使得小模型上搜到的超參可以直接遷移到大模型,極大地減小了超參搜索的耗費(fèi)。由于μP帶來(lái)了穩(wěn)定的超參遷移性質(zhì),它近年來(lái)已經(jīng)被成功使用在大語(yǔ)言模型(標(biāo)準(zhǔn)Transformer)的預(yù)訓(xùn)練中。
然而,diffusion Transformers和標(biāo)準(zhǔn)Transformer存在較大的差異。從架構(gòu)上來(lái)看,diffusion Transformers引入了額外的模塊來(lái)處理并整合文本信息,如DiT中的adaLN block。從任務(wù)目標(biāo)上來(lái)看,diffusion Transformers處理的是視覺(jué)的擴(kuò)散學(xué)習(xí)任務(wù),而標(biāo)準(zhǔn)Transformer主要處理的是語(yǔ)言的自回歸學(xué)習(xí)任務(wù)。這兩點(diǎn)差異意味著已有的μP形式及其超參遷移律在視覺(jué)diffusion Transformers中不一定成立。針對(duì)這一問(wèn)題,團(tuán)隊(duì)從理論和實(shí)踐上進(jìn)行了系統(tǒng)的研究。
Diffusion Transformers的μP形式
團(tuán)隊(duì)首先從理論上研究了主流diffusion Transformers的μP形式,包括DiT,U-ViT,PixArt-α和MMDiT。Tensor Program理論系列中的結(jié)果表明,如果網(wǎng)絡(luò)架構(gòu)能夠被Tensor Program中定義的算子表示,那么現(xiàn)有的μP形式就能成立。基于這個(gè)理論技術(shù),我們證明了:即使主流diffusion Transformers的結(jié)構(gòu)不同于標(biāo)準(zhǔn)Transformer,它們也能夠被Tensor Program表示,因此現(xiàn)有的μP理論和相關(guān)實(shí)踐可以被無(wú)痛遷移到這些主流diffusion Transformers上。我們的證明技術(shù)也可以被遷移到其它的diffusion Transformers做類似的分析。
總之,diffusion Transformers的μP方法論可以由下圖總結(jié)。我們首先基于μP理論,調(diào)節(jié)不同權(quán)重的系數(shù)、初始化和學(xué)習(xí)率。然后,我們?cè)谝幌盗行∧P蜕纤阉鞯玫阶顑?yōu)的超參。最后,我們將最優(yōu)的超參直接遷移到大模型的訓(xùn)練。
基于μP擴(kuò)展Diffusion Transformers:初探
首先,我們使用DiT網(wǎng)絡(luò)在ImageNet數(shù)據(jù)集上系統(tǒng)地驗(yàn)證了:當(dāng)網(wǎng)絡(luò)寬度,數(shù)據(jù)批量大小和訓(xùn)練步數(shù)足夠大時(shí)(如寬度達(dá)到144,批量大小達(dá)到256),超參便可以較為穩(wěn)定地沿著不同的網(wǎng)絡(luò)寬度,數(shù)據(jù)批量大小和訓(xùn)練步數(shù)進(jìn)行遷移。這意味著我們能在網(wǎng)絡(luò)寬度,數(shù)據(jù)批量大小和訓(xùn)練步數(shù)都更小的代理任務(wù)上搜索超參,然后遷移到最終大網(wǎng)絡(luò)大數(shù)據(jù)的訓(xùn)練。
然后,為了驗(yàn)證μP超參遷移的有效性,我們將最優(yōu)的超參(學(xué)習(xí)率2^-10)直接遷移到DiT-XL-2的訓(xùn)練中,我們發(fā)現(xiàn),當(dāng)模型訓(xùn)練到2.4M步時(shí),F(xiàn)ID-50K就已經(jīng)超過(guò)了原論文7M步最終的FID-50K結(jié)果,DiT-XL-2-μP的收斂速度是原論文的2.9倍。這向我們展現(xiàn)了利用μP遷移超參做擴(kuò)展的良好前景。
基于μP擴(kuò)展Diffusion Transformers:大規(guī)模驗(yàn)證
我們進(jìn)一步在大規(guī)模的文生圖任務(wù)上驗(yàn)證了μP擴(kuò)展diffusion Transformers的有效性。我們首先考慮了流行的開(kāi)源文生圖模型PixArt-α,我們?cè)?.04B的代理模型上搜索學(xué)習(xí)率,并遷移到最終0.61B大小PixArt-α的訓(xùn)練。其中,小模型搜索超參的計(jì)算量總和(FLOPs)僅為一次訓(xùn)練的5.5%。利用搜索得到的學(xué)習(xí)率,PixArt-α-μP在訓(xùn)練的過(guò)程中穩(wěn)定地取得了比基線更好的效果。
最后,我們考慮了SD3的基座模型MMDiT,并將驗(yàn)證的規(guī)模提高到了18B的量級(jí)。為了能夠給社區(qū)帶來(lái)更多的可信的實(shí)踐經(jīng)驗(yàn),我們?cè)?4個(gè)超參(學(xué)習(xí)率,梯度裁剪值,REPA loss的權(quán)重以及warmup的步數(shù))上進(jìn)行了多達(dá)80次的隨機(jī)搜索,總搜索計(jì)算量(FLOPs)約是人工手調(diào)的3%。在0.18B模型上的超參搜索結(jié)果表明,我們學(xué)習(xí)率,梯度裁剪值,REPA loss都對(duì)結(jié)果有影響,其中學(xué)習(xí)率的影響仍是最為關(guān)鍵的。而warmup的步數(shù)則對(duì)結(jié)果影響不大。
我們將0.18B模型上搜索的超參應(yīng)用在了18B模型的訓(xùn)練上,不論從訓(xùn)練loss的變化還是從人工評(píng)測(cè)的結(jié)果,MMDiT-μP都穩(wěn)定地超過(guò)了人工專家手調(diào)的基線,而μP的超參搜索FLOPs僅是人工手調(diào)的3%!
經(jīng)過(guò)這一系列系統(tǒng)的實(shí)驗(yàn)探索,我們證明了μP是科學(xué)擴(kuò)展diffusion Transformers的有效手段,我們也相信μP會(huì)是未來(lái)基礎(chǔ)模型擴(kuò)展的必備利器。通過(guò)本工作的大量努力,我們希望讓社區(qū)了解μP理論,擁抱μP實(shí)踐,思考理論上最優(yōu)的智能擴(kuò)展范式(模型大小,數(shù)據(jù)量,推理時(shí)間)。我們也相信,放眼人工智能的長(zhǎng)遠(yuǎn)未來(lái),類似μP的底層理論的發(fā)展仍然是必不可少的,也必將會(huì)在未來(lái)的大規(guī)模實(shí)踐中有著不可或缺的一席之地。