Make U-Nets Great Again!北大&華為提出擴(kuò)散架構(gòu)U-DiT,六分之一算力即可超越DiT
Sora 的發(fā)布讓廣大研究者及開(kāi)發(fā)者深刻認(rèn)識(shí)到基于 Transformer 架構(gòu)擴(kuò)散模型的巨大潛力。作為這一類的代表性工作,DiT 模型拋棄了傳統(tǒng)的 U-Net 擴(kuò)散架構(gòu),轉(zhuǎn)而使用直筒型去噪模型。鑒于直筒型 DiT 在隱空間生成任務(wù)上效果出眾,后續(xù)的一些工作如 PixArt、SD3 等等也都不約而同地使用了直筒型架構(gòu)。
然而令人感到不解的是,U-Net 結(jié)構(gòu)是之前最常用的擴(kuò)散架構(gòu),在圖像空間和隱空間的生成效果均表現(xiàn)不俗;可以說(shuō) U-Net 的 inductive bias 在擴(kuò)散任務(wù)上已被廣泛證實(shí)是有效的。因此,北大和華為的研究者們產(chǎn)生了一個(gè)疑問(wèn):能否重新拾起 U-Net,將 U-Net 架構(gòu)和 Transformer 有機(jī)結(jié)合,使擴(kuò)散模型效果更上一層樓?帶著這個(gè)問(wèn)題,他們提出了基于 U-Net 的 DiT 架構(gòu) U-DiT。
- 論文標(biāo)題:U-DiTs: Downsample Tokens in U-Shaped Diffusion Transformers
- 論文地址:https://arxiv.org/pdf/2405.02730
- GitHub 地址:https://github.com/YuchuanTian/U-DiT
從一個(gè)小實(shí)驗(yàn)談開(kāi)去
首先,研究者開(kāi)展了一個(gè)小實(shí)驗(yàn),在實(shí)驗(yàn)中嘗試著將 U-Net 和 DiT 模塊簡(jiǎn)單結(jié)合。然而,如表 1 所示,在相似的算力比較下,U-Net 的 DiT(DiT-UNet)僅僅比原始的 DiT 有略微的提升。
在圖 3 中,作者們展示了從原始的直筒 DiT 模型一步步演化到 U-DiT 模型的過(guò)程。
根據(jù)先前的工作,在擴(kuò)散中 U-Net 的主干結(jié)構(gòu)特征圖主要為低頻信號(hào)。由于全局自注意力運(yùn)算機(jī)制需要消耗大量算力,在 U-Net 的主干自注意力架構(gòu)中可能存在冗余。這時(shí)作者注意到,簡(jiǎn)單的下采樣可以自然地濾除噪聲較多的高頻,強(qiáng)調(diào)信息充沛的低頻。既然如此,是否可以通過(guò)下采樣來(lái)消除對(duì)特征圖自注意力中的冗余?
Token 下采樣后的自注意力
由此,作者提出了下采樣自注意力機(jī)制。在自注意力之前,首先需將特征圖進(jìn)行 2 倍下采樣。為避免重要信息的損失,生成了四個(gè)維度完全相同的下采樣圖,以確保下采樣前后的特征總維度相同。隨后,在四個(gè)特征圖上使用共用的 QKV 映射,并分別獨(dú)立進(jìn)行自注意力運(yùn)算。最后,將四個(gè) 2 倍下采樣的特征圖重新融為一個(gè)完整特征圖。和傳統(tǒng)的全局自注意力相比,下采樣自注意力可以使得自注意力所需算力降低 3/4。
令人驚訝的是,盡管加入下采樣操作之后能夠顯著模型降低所需算力,但是卻反而能獲得比原來(lái)更好的效果(表 1)。
U-DiT:全面超越 DiT
根據(jù)此發(fā)現(xiàn),作者提出了基于下采樣自注意力機(jī)制的 U 型擴(kuò)散模型 U-DiT。對(duì)標(biāo) DiT 系列模型的算力,作者提出了三個(gè) U-DiT 模型版本(S/B/L)。在完全相同的訓(xùn)練超參設(shè)定下,U-DiT 在 ImageNet 生成任務(wù)上取得了令人驚訝的生成效果。其中,U-DiT-L 在 400K 訓(xùn)練迭代下的表現(xiàn)比直筒型 DiT-XL 模型高約 10 FID,U-DiT-S/B 模型比同級(jí)直筒型 DiT 模型高約 30 FID;U-DiT-B 模型只需 DiT-XL/2 六分之一的算力便可達(dá)到更好的效果(表 2、圖 1)。
在有條件生成任務(wù)(表 3)和大圖(512*512)生成任務(wù)(表 5)上,U-DiT 模型相比于 DiT 模型的優(yōu)勢(shì)同樣非常明顯。
研究者們還進(jìn)一步延長(zhǎng)了訓(xùn)練的迭代次數(shù),發(fā)現(xiàn) U-DiT-L 在 600K 迭代時(shí)便能優(yōu)于 DiT 在 7M 迭代時(shí)的無(wú)條件生成效果(表 4、圖 2)。
U-DiT 模型的生成效果非常出眾,在 1M 次迭代下的有條件生成效果已經(jīng)非常真實(shí)。
論文已被 NeurIPS 2024 接收,更多內(nèi)容,請(qǐng)參考原論文。