謝賽寧新作:表征學(xué)習(xí)有多重要?一個操作刷新SOTA,DiT訓(xùn)練速度暴漲18倍
擴散模型如何突破瓶頸?成本高又難訓(xùn)練的DiT/SiT模型如何提升效率?
對于這個問題,紐約大學(xué)謝賽寧團隊最近發(fā)表的一篇論文找到了一個全新的切入點:提升表征(representation)的質(zhì)量。
論文的核心或許就可以用一句話概括:「表征很重要!」
用謝賽寧的話來說,即使只是想讓生成模型重建出好看的圖像,仍然需要先學(xué)習(xí)強大的表征,然后再去渲染高頻的、使圖像看起來更美觀的細(xì)節(jié)。
這個觀點,Yann LeCun之前也多次強調(diào)過。
有網(wǎng)友還在線幫謝賽寧想標(biāo)題:你這篇論文不如就叫「Representation is all you need」(手動狗頭)
由于觀點一致,這篇研究也獲得了同在紐約大學(xué)的Yann LeCun的轉(zhuǎn)發(fā)。
當(dāng)使用自監(jiān)督學(xué)習(xí)訓(xùn)練視覺編碼器時,我們知道一個事實,使用具有重建損失(reconstruction loss)的解碼器的效果遠(yuǎn)遠(yuǎn)不如具有特征預(yù)測損失(feature prediction loss)和崩潰預(yù)防機制的聯(lián)合嵌入架構(gòu)。
這篇來自紐約大學(xué)謝賽寧團隊的論文表明,即使只對生成像素感興趣(例如,使用擴散Transformer生成漂亮的圖片),包含特征預(yù)測損失也是值得的,以便解碼器的內(nèi)部表示可以基于預(yù)訓(xùn)練的視覺編碼器(例如 DINOv2)進行特征預(yù)測。
REPA的核心思想非常簡單,就是讓擴散模型中的表征與外部更強大的視覺表征進行對齊,但提升效果非常顯著,頗有「他山之石,可以攻玉」的意味。
僅僅是在損失函數(shù)添加一項相似度最大化,就能將SiT/DiT的訓(xùn)練速度提升將近18倍,還刷新了模型的SOTA性能,在ImageNet 256x256上實現(xiàn)了最先進的FID=1.42。
謝賽寧表示,剛看到實驗結(jié)果時,他自己也被震驚到了,因為感覺并沒有發(fā)明什么全新的東西,而只是意識到了,我們幾乎完全不理解擴散模型和SSL方法學(xué)習(xí)到的表示。
論文簡介
論文地址:https://arxiv.org/abs/2410.06940
項目地址:https://sihyun.me/REPA/
在生成高維的視覺數(shù)據(jù)方面,基于去噪方法(如擴散模型)或基于流的生成模型,已經(jīng)成為了一種可擴展的途徑,并在有挑戰(zhàn)性的的零樣本文生圖/文生視頻任務(wù)上取得了非常成功的結(jié)果。
最近的研究表明,生成擴散模型中的去噪過程可以在模型內(nèi)部的隱藏狀態(tài)中引入有意義的表示,但這些表示的質(zhì)量目前仍落后于自監(jiān)督學(xué)習(xí)方法,例如DINOv2。
作者認(rèn)為,訓(xùn)練大規(guī)模擴散模型的一個主要瓶頸,就在于無法有效學(xué)習(xí)到高質(zhì)量的內(nèi)部表示。
如果能夠結(jié)合高質(zhì)量的外部視覺表示,而不是僅僅依靠擴散模型來獨立學(xué)習(xí),就可以使訓(xùn)練過程變得更容易。
為了實現(xiàn)這一點,論文基于經(jīng)典的擴散Transformer架構(gòu),引入了一種簡單的正則化方法REPA(REPresentation Alignment)。
簡單來說,就是將去噪網(wǎng)絡(luò)中從噪聲輸入 得到的隱藏狀態(tài)??的投影,與外部自監(jiān)督預(yù)訓(xùn)練的視覺編碼器從干凈圖像??獲得的視覺表示??*進行對齊。
這樣一個非常直給的策略,卻獲得了驚人的結(jié)果:應(yīng)用于流行的SiT或DiT時,模型的訓(xùn)練效率和生成質(zhì)量都得到了顯著提高。
具體來說,REPA可以將SiT的訓(xùn)練速度加快17.5×以上,以不到40萬步的訓(xùn)練量匹配有700萬步訓(xùn)練的SiT-XL模型的性能,同時實現(xiàn)了FID=1.42的SOTA結(jié)果。
REPA:使用表征對齊的正則化
統(tǒng)一視角的擴散模型+流模型
由于論文希望同時優(yōu)化基于流的模型SiT和基于去噪的擴散模型DiT,因此首先從統(tǒng)一的隨機插值視角,對這兩種模型進行簡要的回顧。
考慮在t∈[0,T]的連續(xù)時間步中,對數(shù)據(jù)??*~p(??)使用高斯分布ε~??(0,??)添加隨機噪音:
其中,αt和σt分別表示t的遞減和遞增函數(shù)。在公式(1)給定的過程中,存在一個帶有速度場(velocity field)的概率流常微分方程:
其中t步時的分布就等于邊際概率pt(??)。
速度??(??,t)可以表示為如下兩個條件期望之和:
這個值可以通過最小化如下訓(xùn)練目標(biāo)得到近似值??θ(??,t):
同時,還存在一個反向的隨機微分方程(SDE),帶有擴散系數(shù)wt,其中的邊際概率pt(??)與公式(2)相符:
其中,??(??t,t)是一個條件期望值,定義為:
對任意t>0,都可以通過速度??(??,t)計算出??(??,t)的值:
這表明,數(shù)據(jù)??t也可以通過求解公式(5)的SDE來以另一種方式生成。
以上定義對類似的擴散模型變體,例如DDPM,同樣適用,只是需要將連續(xù)的時間步離散化。
方法概述
令p(??)為數(shù)據(jù)??∈??的未知目標(biāo)分布,我們的訓(xùn)練目標(biāo)就是通過模型對數(shù)據(jù)的學(xué)習(xí)得到p(??)的近似。
為了降低計算成本,最近流行的「潛在擴散」方法(latent diffusion)提出學(xué)習(xí)潛在變量??=E(??)的分布p(??),其中E表示來自預(yù)訓(xùn)練自編碼器(例如KL-VAE)中的編碼部分。
要學(xué)習(xí)到分布p(??),就需要訓(xùn)練擴散模型??θ(??t,t),訓(xùn)練目標(biāo)是進行速度預(yù)測,具體方法如上一節(jié)所述。
放在自監(jiān)督表示學(xué)習(xí)的背景中,可以將擴散模型看成編碼器fθ:?????和解碼器gθ:?????的組合,其中編碼器負(fù)責(zé)隱式地學(xué)習(xí)到表示??t以重建目標(biāo)??t。
然而,作者提出,用于生成的大型擴散模型并不擅長表征學(xué)習(xí),因此REPA引入了外部的語義豐富的表示,從而顯著提升生成性能。
REPA方法概述
模型觀察
擴散模型是否真的不擅長表征學(xué)習(xí)?這需要更進一步地觀察模型才能確定,為此,研究人員測量并比對了diffusion transformer和當(dāng)前的SOTA自監(jiān)督模型DINOv2之間的表征差距,包括語義差距和特征對齊兩種角度。
語義差距
從圖2a可知,預(yù)訓(xùn)練SiT的隱藏層表示在第20層達到最佳狀態(tài),這與之前的研究結(jié)果相符,但仍遠(yuǎn)遠(yuǎn)落后于DINOv2。
特征對齊
如圖2b和2c所示,使用CKNNA值測量SiT和DINOv2之間的表征對齊程度后發(fā)現(xiàn),SiT的對齊效果會隨著模型增大和訓(xùn)練迭代步數(shù)增加而逐漸改善,但即使增加到7M次迭代,和DINOv2之間的對齊程度仍然不足。
事實上,這種差距不僅在SiT中存在,根據(jù)附錄C.2的實驗結(jié)果,DiT等其他基于去噪的生成式Transformer模型也存在類似的問題。
縮小表征差距
那么,REPA方法究竟如何縮小這種表征差距,讓diffusion transformer在噪聲輸入中也能學(xué)到有用的語義特征?
定義N,D分別表示patch數(shù)量預(yù)訓(xùn)練編碼器f的嵌入維度,編碼器輸入為無噪聲的圖像??*,輸出為??*=f(??*)∈?N×D。
Diffusion transformer將編碼器輸出??t=fθ(??t)通過一個可訓(xùn)練的投影頭hφ(MLP)投影為hφ(??t)∈?N×D。
之后,REPA負(fù)責(zé)將hφ(??t)與??*進行對齊,通過最大化兩者間的patch間相似度:
在實際實現(xiàn)中,將這一項添加到公式(4)定義的基于擴散的訓(xùn)練目標(biāo)中,就得到總體的訓(xùn)練目標(biāo):
其中超參數(shù)λ>0用于控制模型在去噪目標(biāo)和表征對齊間的權(quán)衡。
從圖3結(jié)果可知,REPA減少了表示中的語義差距。
有趣的是,使用REPA后,僅對齊前幾個Transformer塊就能實現(xiàn)足夠程度的表示對齊,從而讓diffusion transformer的靠后層專注于捕獲高頻細(xì)節(jié),從而進一步提高生成性能。
實驗結(jié)果
為了驗證REPA方法的有效性,實驗在兩種流行的擴散模型訓(xùn)練目標(biāo)(即??velocity)上進行了實驗,包括DiT中改進后的DDPM和SiT中的線性隨機插值,但實際中也同樣可以考慮其他的訓(xùn)練目標(biāo)。
所用模型默認(rèn)嚴(yán)格遵循SiT和DiT的原始結(jié)構(gòu)(除非有特別說明),包括B/2、L/2、XL/2三種參數(shù)設(shè)置,如表1所示。
以下實驗旨在回答3個問題:
- REPA能否顯著提升diffusion transformer的訓(xùn)練?
- REPA在模型規(guī)模和表征質(zhì)量方面是否具有可擴展性?
- 擴散模型的表征能否和多種視覺表征進行對齊?
REPA提升視覺縮放
首先比較兩個SiT-XL/2模型在前400K次迭代期間生成的圖像,它們共享相同的噪聲、采樣器和采樣步數(shù),但其中使用REPA訓(xùn)練的模型顯示出更好的進展。
REPA在各個方面都展現(xiàn)出了強大的可擴展性
研究人員還改變了預(yù)訓(xùn)練編碼器和Diffusion Transformer的模型大小來檢驗REPA的可擴展性。
圖5a結(jié)果表明,與更好的視覺表示相結(jié)合可以改善生成效果和線性探測的結(jié)果。
此外,如圖5b和c所示,增加模型大小可以在生成和線性評估方面帶來更快的收益,也就是說,模型規(guī)模越大,REPA的加速效果越明顯,表現(xiàn)出了強大的可擴展性。
REPA顯著提高訓(xùn)練效率和生成質(zhì)量
最后,論文比較了普通DiT或SiT模型在訓(xùn)練中使用REPA前后的FID值。
在沒有指導(dǎo)的情況下,REPA在400K次迭代時實現(xiàn)了FID=7.9,優(yōu)于普通模型在7M次迭代后的性能。
此外,使用無分類器引導(dǎo)時,帶有REPA的SiT-XL/2的性能優(yōu)于SOTA性能(FID=1.42),同時迭代次數(shù)減少了7倍。
作者介紹
Sihyun Yu
本文一作Sihyun Yu是KAIST(韓國科學(xué)技術(shù)院)人工智能專業(yè)最后一年的博士生,此前他同樣在KAIST獲得了數(shù)學(xué)和計算機科學(xué)的雙專業(yè)學(xué)士學(xué)位。
他的研究主要集中在減少大型生成模型訓(xùn)練(和采樣)的內(nèi)存和計算負(fù)擔(dān),其中,對大規(guī)模且高效的視頻生成特別感興趣;博士期間,他還曾在英偉達和谷歌研究院擔(dān)任實習(xí)生。