世界模型也擴(kuò)散!訓(xùn)練出的智能體竟然不錯(cuò)
世界模型提供了一種以安全且樣本高效的方式訓(xùn)練強(qiáng)化學(xué)習(xí)智能體的方法。近期,世界模型主要對(duì)離散潛在變量序列進(jìn)行操作來(lái)模擬環(huán)境動(dòng)態(tài)。
然而,這種壓縮為緊湊離散表征的方式可能會(huì)忽略對(duì)強(qiáng)化學(xué)習(xí)很重要的視覺(jué)細(xì)節(jié)。另一方面,擴(kuò)散模型已成為圖像生成的主要方法,對(duì)離散潛在模型提出了挑戰(zhàn)。
受這種范式轉(zhuǎn)變的推動(dòng),來(lái)自日內(nèi)瓦大學(xué)、愛(ài)丁堡大學(xué)、微軟研究院的研究者聯(lián)合提出一種在擴(kuò)散世界模型中訓(xùn)練的強(qiáng)化學(xué)習(xí)智能體 —— DIAMOND(DIffusion As a Model Of eNvironment Dreams)。
- 論文地址:https://arxiv.org/abs/2405.12399
- 項(xiàng)目地址:https://github.com/eloialonso/diamond
- 論文標(biāo)題:Diffusion for World Modeling: Visual Details Matter in Atari
DIAMOND 在 Atari 100k 基準(zhǔn)測(cè)試中獲得了 1.46 的平均人類(lèi)歸一化得分 (HNS),可以媲美完全在世界模型中訓(xùn)練的智能體的 SOTA 水平。該研究提供了定性分析來(lái)說(shuō)明,DIAMOND 的設(shè)計(jì)選擇對(duì)于確保擴(kuò)散世界模型的長(zhǎng)期高效穩(wěn)定是必要的。
此外,在圖像空間中操作的好處是使擴(kuò)散世界模型能夠成為環(huán)境的直接替代品,從而提供對(duì)世界模型和智能體行為更深入的了解。特別地,該研究發(fā)現(xiàn)某些游戲中性能的提高源于對(duì)關(guān)鍵視覺(jué)細(xì)節(jié)的更好建模。
方法介紹
接下來(lái),本文介紹了 DIAMOND, 這是一種在擴(kuò)散世界模型中訓(xùn)練的強(qiáng)化學(xué)習(xí)智能體。具體來(lái)說(shuō),研究者基于 2.2 節(jié)引入的漂移和擴(kuò)散系數(shù) f 和 g,這兩個(gè)系數(shù)對(duì)應(yīng)于一種特定的擴(kuò)散范式選擇。此外,該研究還選擇了基于 Karras 等人提出的 EDM 公式。
首先定義一個(gè)擾動(dòng)核,,其中,
是一個(gè)與擴(kuò)散時(shí)間相關(guān)的實(shí)值函數(shù),稱(chēng)為噪聲時(shí)間表。這對(duì)應(yīng)于將漂移和擴(kuò)散系數(shù)設(shè)為
和
。
接著使用 Karras 等人(2022)引入的網(wǎng)絡(luò)預(yù)處理,同時(shí)參數(shù)化公式(5)中的,作為噪聲觀(guān)測(cè)值和神經(jīng)網(wǎng)絡(luò)
預(yù)測(cè)值的加權(quán)和:
得到公式(6)
其中為了簡(jiǎn)潔定義,包含所有條件變量。
預(yù)處理器的選擇。選擇預(yù)處理器和
,以保持網(wǎng)絡(luò)輸入和輸出在任何噪聲水平
下的單位方差。
是噪聲水平的經(jīng)驗(yàn)轉(zhuǎn)換,
由
和數(shù)據(jù)分布的標(biāo)準(zhǔn)差
給出,公式為
結(jié)合公式 5 和 6,得到訓(xùn)練目標(biāo):
該研究使用標(biāo)準(zhǔn)的 U-Net 2D 來(lái)構(gòu)建向量場(chǎng),并保留一個(gè)包含過(guò)去 L 個(gè)觀(guān)測(cè)和動(dòng)作的緩沖區(qū),以此來(lái)對(duì)模型進(jìn)行條件化。接下來(lái)他們將這些過(guò)去的觀(guān)測(cè)按通道方式與下一個(gè)帶噪觀(guān)測(cè)拼接,并通過(guò)自適應(yīng)組歸一化層將動(dòng)作輸入到 U-Net 的殘差塊中。正如在第 2.3 節(jié)和附錄 A 中討論的,有許多可能的采樣方法可以從訓(xùn)練好的擴(kuò)散模型中生成下一個(gè)觀(guān)測(cè)。雖然該研究發(fā)布的代碼庫(kù)支持多種采樣方案,但該研究發(fā)現(xiàn)歐拉方法在不需要額外的 NFE(函數(shù)評(píng)估次數(shù))以及避免了高階采樣器或隨機(jī)采樣的不必要復(fù)雜性的情況下是有效的。
實(shí)驗(yàn)
為了全面評(píng)估 DIAMOND,該研究使用了公認(rèn)的 Atari 100k 基準(zhǔn)測(cè)試,該基準(zhǔn)測(cè)試包括 26 個(gè)游戲,用于測(cè)試智能體的廣泛能力。對(duì)于每個(gè)游戲,智能體只允許在環(huán)境中進(jìn)行 100k 次操作,這大約相當(dāng)于人類(lèi) 2 小時(shí)的游戲時(shí)間,以在評(píng)估前學(xué)習(xí)玩游戲。作為參考,沒(méi)有限制的 Atari 智能體通常訓(xùn)練 5000 萬(wàn)步,這相當(dāng)于經(jīng)驗(yàn)的 500 倍增加。研究者從頭開(kāi)始在每個(gè)游戲上用 5 個(gè)隨機(jī)種子訓(xùn)練 DIAMOND。每次運(yùn)行大約使用 12GB 的 VRAM,在單個(gè) Nvidia RTX 4090 上大約需要 2.9 天(總計(jì) 1.03 個(gè) GPU 年)。
表 1 比較了在世界模型中訓(xùn)練智能體的不同得分:
圖 2 中提供了平均值和 IQM( Interquartile Mean )置信區(qū)間:
結(jié)果表明,DIAMOND 在基準(zhǔn)測(cè)試中表現(xiàn)強(qiáng)勁,超過(guò)人類(lèi)玩家在 11 個(gè)游戲中的表現(xiàn),并達(dá)到了 1.46 的 HNS 得分,這是完全在世界模型中訓(xùn)練的智能體的新紀(jì)錄。該研究還發(fā)現(xiàn),DIAMOND 在需要捕捉細(xì)節(jié)的環(huán)境中表現(xiàn)特別出色,例如 Asterix、Breakout 和 Road Runner。
為了研究擴(kuò)散變量的穩(wěn)定性,該研究分析了自回歸生成的想象軌跡(imagined trajectory),如下圖 3 所示:
該研究發(fā)現(xiàn)有些情況需要迭代求解器將采樣過(guò)程驅(qū)動(dòng)到特定模式,如圖 4 所示的拳擊游戲:
如圖 5 所示,與 IRIS 想象的軌跡相比,DIAMOND 想象的軌跡通常具有更高的視覺(jué)質(zhì)量,并且更符合真實(shí)環(huán)境。
感興趣的讀者可以閱讀論文原文,了解更多研究?jī)?nèi)容。