變分掩碼擴散模型:解決并發(fā)標記預(yù)測中的依賴關(guān)系問題
1. 研究背景與問題定義
基于擴散的大型語言模型(DLLMs)作為自回歸模型(ARMs)的重要擴展,正在成為生成式AI領(lǐng)域的重要創(chuàng)新方向。與傳統(tǒng)ARMs按預(yù)定義順序順序生成標記的方式不同,DLLMs提供了并發(fā)標記生成、更高輸出多樣性、增強全局一致性以及更好的生成文本可控性等優(yōu)勢。近期的突破性模型如LLaDA、Mercury和Gemini Diffusion都凸顯了DLLMs的潛力。
然而,當前的掩碼擴散模型(MDM)存在一個關(guān)鍵限制:無法有效捕獲并發(fā)預(yù)測的標記之間的依賴關(guān)系。這導(dǎo)致在標記間依賴性較強的推理任務(wù)中性能下降。例如,在預(yù)測"A poker hand that consists of two English words is: a a"的后續(xù)兩個詞時,適合的預(yù)測應(yīng)為"high card"、"two pair"、"full house"或"straight flush"。這些詞對之間存在強依賴關(guān)系,但MDM在并發(fā)預(yù)測時會獨立采樣,無法考慮這種依賴性,從而可能產(chǎn)生不合理的組合。
2. 變分掩碼擴散(VMD)模型
為解決上述問題,研究者提出了變分掩碼擴散(Variational Masked Diffusion, VMD)框架,通過引入潛在變量來建模并發(fā)預(yù)測期間的聯(lián)合標記分布。VMD的核心思想是:通過潛在變量模型捕獲任意聯(lián)合分布,而不僅僅是可因式分解的分布。
圖片
2.1 基本變分公式
VMD的基本公式為:
pθ(x0i|xt) = ∫pθ(x0i|xt,z)p(z)dz其中z是全局潛在變量,不依賴于標記位置i。這使得模型能夠在標記之間建立聯(lián)合分布。條件于潛在變量z,標記可以獨立采樣,但邊緣化潛在變量后,能夠從正確的聯(lián)合分布中獲得樣本。
訓(xùn)練目標函數(shù)(NELBO)為:
-log pθ(x0) ≤ ∫(0→1) (1/t) ??qt|0(xt|x0)[??q?[∑i:xti=[MASK] -log pθ(x0i|xt,z)] + DKL(q?(·|x0,xt)||p(·))]dt其中p(·)是潛在空間的先驗分布,q?(·|x0,xt)是由可訓(xùn)練參數(shù)φ參數(shù)化的近似后驗分布。
2.2 塊擴散公式
為了解決大規(guī)模標記序列中的依賴建模問題,VMD進一步引入了塊擴散公式。將標記x分為B個連續(xù)塊,每個塊長度為r,總序列長度L=Br。每個塊使用一個潛在變量zb。數(shù)據(jù)對數(shù)似然可以因式分解為:
log pθ(x0|xt) = ∑(b=1→B) log pθ(x0b|xtb,x0<b) = ∑(b=1→B) log ∫pθ(x0b|xtb,x0<b,z≤b)p(z≤b)dz≤b這種方法結(jié)合了自回歸和擴散模型的優(yōu)勢,在塊內(nèi)使用變分擴散,塊間使用自回歸處理。通過調(diào)整塊大小B(或等效的r),可以在自回歸變分模型和變分掩碼擴散模型之間進行插值。
2.3 重新掩碼策略
VMD考慮了兩種重新掩碼策略:
- 基于概率的置信度:cprob.i = pθ(x0i=v1|xt,z),其中v1是詞匯表中最可能的值
- 基于邊際的置信度:cmarg.i = |pθ(x0i=v1|xt,z) - pθ(x0i=v2|xt,z)|,其中v1和v2是兩個最可能的值
這些策略通過潛在變量提供了更全局的上下文,相比傳統(tǒng)MDM的局部置信度指標更為有效。
3. 實驗結(jié)果與分析
3.1 控制合成數(shù)據(jù)實驗
3.1.1 兩標記序列實驗
研究者首先在兩標記序列上進行了實驗,包括確定性依賴、非均勻分布和可變依賴強度三種設(shè)置:
- 確定性依賴:數(shù)據(jù)為{(k,k+1mod10)}k=0→9,第二個標記完全由第一個決定。標準MDM在并發(fā)解碼時退化為隨機猜測(約10%準確率),而VMD能可靠地生成正確對。
- 非均勻分布:分布為P((k,k+1mod10))=(k+1)/55,k∈{0,...,9}。VMD更準確地從目標分布中采樣。
- 可變依賴強度:通過參數(shù)p控制依賴強度,VMD在整個依賴強度范圍內(nèi)都能準確建模真實數(shù)據(jù)分布。
3.1.2 四標記序列實驗
在四標記序列實驗中,研究者構(gòu)建了兩個數(shù)據(jù)集:
- D1:包含10個唯一序列{(k,k+1,k+2,k+3mod10)}k=0→9,具有強標記依賴性
- D2:{(k,k+1,l,l+1mod10)}k,l=0→9,第一個長度為2的塊與第二個塊獨立
結(jié)果顯示,在并行解碼所有四個標記時(B=1, NFE=1),標準MDM表現(xiàn)類似隨機猜測(D1為0.1%,D2為1.1%),而VMD達到了顯著更高的準確率(D1為81.5%,D2為64.4%)。
3.2 數(shù)獨數(shù)據(jù)實驗
數(shù)獨是一個9×9網(wǎng)格的邏輯謎題,需要填充空白單元格,使每行、每列和每個3×3子網(wǎng)格都包含1到9的所有數(shù)字。這種全局和局部依賴性使其成為評估生成模型捕獲標記依賴能力的良好基準。
實驗結(jié)果表明,VMD在各種采樣方法和NFE值下都優(yōu)于基線模型。特別是在較低NFE值時,VMD的優(yōu)勢更為明顯,表明它能更高效地生成有效解決方案。
3.3 文本數(shù)據(jù)實驗
在text8數(shù)據(jù)集上的實驗(包含維基百科的前1億個字符),VMD達到了比標準擴散基線(SEDD, MDLM)更低的困惑度,并在兩種塊大小上都比BD3-LM有輕微但一致的改進。特別是在塊大小為4時,VMD達到了2.858的困惑度,這是基于擴散模型中的最佳結(jié)果。
圖片
4. 技術(shù)細節(jié)與實現(xiàn)
4.1 模型架構(gòu)
VMD的編碼器和解碼器都采用了類似于BD3-LM的架構(gòu),但進行了修改以包含潛在信息。為確保公平比較,解碼器主干與基線保持相同,而編碼器的層數(shù)進行了調(diào)整,使其架構(gòu)鏡像解碼器但參數(shù)數(shù)量減半。
在數(shù)獨實驗中,模型使用了DiT架構(gòu),編碼器和解碼器分別有4層和6層Transformer層。每個塊分配了一個128維的潛在向量,通過共享的自適應(yīng)層歸一化模塊注入到解碼器的每個DiT塊中。
4.2 訓(xùn)練細節(jié)
- 合成數(shù)據(jù):批量大小10,000,訓(xùn)練2,000步,使用Adam優(yōu)化器,固定學習率1e-3
- 數(shù)獨:批量大小1,024,學習率1e-3,余弦學習率調(diào)度器,300個周期
- 文本:批量大小512,使用AdamW優(yōu)化器,學習率3e-4,恒定學習率調(diào)度,2,500步預(yù)熱
4.3 推理過程
推理過程遵循塊擴散方法,應(yīng)用KV緩存以提高采樣效率。每個生成的塊x0b存儲在緩存中,這意味著訓(xùn)練期間使用的前綴上下文x0<b對應(yīng)于累積的鍵和值(K1:b-1,V1:b-1)。解碼器對塊b的預(yù)測因此僅依賴于當前噪聲輸入xtb、潛在變量z≤b和來自早期生成塊的緩存上下文。
5. 與相關(guān)工作的比較
VMD與其他離散擴散模型和非自回歸模型相比具有以下優(yōu)勢:
- 相比標準掩碼擴散模型(如LLaDA、LLaDA-V),VMD能更好地捕獲并發(fā)采樣標記之間的依賴關(guān)系
- 相比BD3-LM,VMD在保持塊自回歸結(jié)構(gòu)的同時,進一步提高了并行采樣的效率和準確性
- 相比其他非自回歸模型(如BERT),VMD通過潛在變量提供了更豐富的文本表示
6. 未來展望
VMD為掩碼擴散模型開辟了新的研究方向,未來可能的發(fā)展包括:
- 多層次潛在變量:引入層次化潛在變量結(jié)構(gòu),以更好地捕獲不同抽象級別的依賴關(guān)系。例如,可以設(shè)計一個包含句子級、段落級和文檔級潛在變量的模型,每個級別負責不同范圍的依賴建模。
- 條件生成增強:將VMD擴展到更復(fù)雜的條件生成任務(wù),如文本到圖像、跨模態(tài)翻譯等。潛在變量可以作為不同模態(tài)之間的橋梁,促進更一致的跨模態(tài)生成。
- 可解釋性研究:探索潛在空間的語義意義,開發(fā)可視化和分析工具,幫助理解模型如何表示和利用標記間依賴關(guān)系。這可以通過潛在空間的降維分析和特定語言現(xiàn)象的干預(yù)實驗來實現(xiàn)。
- 計算效率優(yōu)化:開發(fā)更高效的推理算法,減少VMD的計算開銷。可能的方向包括自適應(yīng)采樣策略、稀疏注意力機制和模型量化技術(shù)。
- 應(yīng)用擴展:將VMD應(yīng)用于更多需要強依賴建模的任務(wù),如程序合成、數(shù)學推理和科學發(fā)現(xiàn)。特別是在需要保持全局一致性的復(fù)雜推理任務(wù)中,VMD的優(yōu)勢可能更為明顯。
這些方向不僅具有理論創(chuàng)新性,也有實際應(yīng)用價值,可以進一步提升生成模型在復(fù)雜依賴建模方面的能力,推動生成AI向更智能、更一致的方向發(fā)展。
7. 結(jié)論
變分掩碼擴散(VMD)模型通過引入潛在變量,成功解決了標準掩碼擴散在并發(fā)標記預(yù)測中無法有效捕獲依賴關(guān)系的問題。在合成數(shù)據(jù)、數(shù)獨謎題和文本數(shù)據(jù)上的實驗都證明了VMD的有效性,特別是在標記間依賴關(guān)系重要的場景中。VMD不僅提高了生成質(zhì)量,也增強了對依賴關(guān)系的感知,凸顯了將變分推理集成到掩碼擴散中的價值。
參考資源
- 論文鏈接:https://arxiv.org/abs/2510.23606
- 代碼實現(xiàn):https://riccizz.github.io/VMD

































