突破性的百萬級(jí)視頻和語(yǔ)言世界模型:Large World Model
本文經(jīng)自動(dòng)駕駛之心公眾號(hào)授權(quán)轉(zhuǎn)載,轉(zhuǎn)載請(qǐng)聯(lián)系出處。
在探索如何讓AI更好地理解世界方面,最近的一項(xiàng)突破性研究引起了廣泛關(guān)注。來自加州大學(xué)伯克利分校的研究團(tuán)隊(duì)發(fā)布了“Large World Model, LWM”,能夠同時(shí)處理百萬級(jí)長(zhǎng)度的視頻和語(yǔ)言序列,實(shí)現(xiàn)了對(duì)復(fù)雜場(chǎng)景的深入理解。這一研究無疑為未來AI的發(fā)展開啟了新的篇章。
論文地址:World Model on Million-Length Video And Language With RingAttention

博客地址:Large World Models

huggingface: LargeWorldModel (Large World Model)
在傳統(tǒng)方法中,AI模型往往只能處理較短的文本或視頻片段,缺乏對(duì)長(zhǎng)時(shí)間復(fù)雜場(chǎng)景的理解能力。然而,現(xiàn)實(shí)世界中的許多場(chǎng)景,如長(zhǎng)篇書籍、電影或電視劇,都包含了豐富的信息,需要更長(zhǎng)的上下文來進(jìn)行深入理解。為了應(yīng)對(duì)這一挑戰(zhàn),LWM團(tuán)隊(duì)采用了環(huán)形注意力(RingAttention)技術(shù),成功擴(kuò)展了模型的上下文窗口,使其能夠處理長(zhǎng)達(dá)100萬個(gè)令牌(1M tokens)的序列。例如實(shí)現(xiàn)超過 1 小時(shí)的問答視頻:

圖1.長(zhǎng)視頻理解。LWM 可以回答有關(guān)超過 1 小時(shí)的 YouTube 視頻的問題。
超過 1M 上下文的事實(shí)檢索:

圖 2. 針檢索任務(wù)。LWM 在 1M 上下文窗口內(nèi)實(shí)現(xiàn)了高精度,并且性能優(yōu)于 GPT-4V 和 Gemini Pro。

圖 3. 針檢索任務(wù)。LWM 對(duì)于上下文窗口中不同的上下文大小和位置實(shí)現(xiàn)了高精度。
技術(shù)實(shí)現(xiàn)
為了訓(xùn)練和評(píng)估LWM,研究人員首先收集了一個(gè)包含各種視頻和書籍的大型數(shù)據(jù)集。然后,他們逐步增加了訓(xùn)練的上下文長(zhǎng)度,從4K tokens開始,逐步擴(kuò)展到1M tokens。這一過程不僅有效降低了訓(xùn)練成本,還使模型能夠逐步適應(yīng)更長(zhǎng)序列的學(xué)習(xí)。在訓(xùn)練過程中,研究人員還發(fā)現(xiàn),混合不同長(zhǎng)度的圖像、視頻和文本數(shù)據(jù)對(duì)于模型的多模態(tài)理解至關(guān)重要。具體包括:
模型訓(xùn)練分兩個(gè)階段:首先通過訓(xùn)練大型語(yǔ)言模型擴(kuò)展上下文大小。然后進(jìn)行視頻和語(yǔ)言的聯(lián)合訓(xùn)練。
Stage I: Learning Long-Context Language Models
擴(kuò)展上下文:利用RingAttention技術(shù),可以無近似地?cái)U(kuò)展上下文長(zhǎng)度到數(shù)百萬個(gè)token。同時(shí),通過逐步增加訓(xùn)練序列長(zhǎng)度,從32K tokens開始,逐步增加到1M tokens,以減少計(jì)算成本。此外,為了擴(kuò)展位置編碼以適應(yīng)更長(zhǎng)的序列,采用了簡(jiǎn)單的方法,即隨上下文窗口大小增加而增加RoPE中的θ。

上下文擴(kuò)展和視覺語(yǔ)言訓(xùn)練。使用 RingAttention 將書籍上的上下文大小從 4K 擴(kuò)展到 1M,然后對(duì)長(zhǎng)度為 32K 到 1M 的多種形式的視覺內(nèi)容進(jìn)行視覺語(yǔ)言訓(xùn)練。下面板顯示了理解和響應(yīng)有關(guān)復(fù)雜多模式世界的查詢的交互功能。
訓(xùn)練步驟:首先從LLaMA-2 7B模型初始化,然后在5個(gè)階段逐步增加上下文長(zhǎng)度,分別是32K、128K、256K、512K和1M tokens。每個(gè)階段都使用不同過濾版本的Books3數(shù)據(jù)集進(jìn)行訓(xùn)練。隨著上下文長(zhǎng)度的增加,模型能夠處理更多tokens。

任意對(duì)任意長(zhǎng)序列預(yù)測(cè)。RingAttention 能夠使用非常大的上下文窗口進(jìn)行跨視頻-文本、文本-視頻、圖像-文本、文本-圖像、純視頻、純圖像和純文本等多種格式的訓(xùn)練。請(qǐng)參閱LWM 論文了解關(guān)鍵功能,包括屏蔽序列打包和損失加權(quán),它們可以實(shí)現(xiàn)有效的視頻語(yǔ)言訓(xùn)練。
對(duì)話微調(diào):為了學(xué)習(xí)長(zhǎng)上下文的對(duì)話能力,構(gòu)建了一個(gè)簡(jiǎn)單的問答數(shù)據(jù)集,將Books3數(shù)據(jù)集的文檔分割成1000 token的塊,然后利用短上下文語(yǔ)言模型為每個(gè)塊生成一個(gè)問答對(duì),最后將相鄰的塊連接起來構(gòu)造一個(gè)長(zhǎng)上下文的問答示例。在微調(diào)階段,模型在UltraChat和自定義問答數(shù)據(jù)集上進(jìn)行訓(xùn)練,比例為7:3。


語(yǔ)言評(píng)估結(jié)果:在單針檢索任務(wù)中,1M上下文的模型可以在整個(gè)上下文中近乎完美地檢索出隨機(jī)分配給隨機(jī)城市的數(shù)字。在多針檢索任務(wù)中,模型在檢索一個(gè)針時(shí)表現(xiàn)良好,在檢索多個(gè)針時(shí)性能略有下降。在短上下文語(yǔ)言任務(wù)評(píng)估中,擴(kuò)大上下文長(zhǎng)度并沒有降低性能。在對(duì)話評(píng)估中,增加對(duì)話交互能力可能會(huì)降低系統(tǒng)檢索具體信息或“針”的精度。




Stage II: Learning Long-Context Vision-Language Models
架構(gòu)修改:在第一階段的基礎(chǔ)上,對(duì)LWM和LWM-Chat進(jìn)行修改,使其能夠接受視覺輸入。具體來說,使用預(yù)訓(xùn)練的VQGAN將256x256的輸入圖像轉(zhuǎn)換為16x16的離散token,對(duì)視頻進(jìn)行逐幀的VQGAN編碼并將編碼連接起來。此外,引入了特殊的標(biāo)記符號(hào)和來區(qū)分文本和視覺token,以及和來標(biāo)記圖像和視頻幀的結(jié)束。


訓(xùn)練步驟:從LWM-Text-1M模型初始化,采用與第一階段類似的逐步增加序列長(zhǎng)度的訓(xùn)練方法,首先在1K tokens上訓(xùn)練,然后是8K tokens,最后是32K、128K和1M tokens。訓(xùn)練數(shù)據(jù)包括文本-圖像對(duì)、文本-視頻對(duì)以及下游任務(wù)的聊天數(shù)據(jù),如文本-圖像生成、圖像理解、文本-視頻生成和視頻理解。在訓(xùn)練過程中,逐步增加下游任務(wù)的混合比例。

視覺-語(yǔ)言評(píng)估結(jié)果:在長(zhǎng)視頻理解方面,模型能夠處理長(zhǎng)達(dá)1小時(shí)的YouTube視頻并準(zhǔn)確回答問題,相較于現(xiàn)有模型具有明顯優(yōu)勢(shì)。在圖像理解和短視頻理解方面,模型表現(xiàn)一般,但通過更嚴(yán)格的訓(xùn)練和更好的分詞器,有潛力改進(jìn)。在圖像和視頻生成方面,模型可以從文本生成圖像和視頻。Ablation研究表明,屏蔽序列填充對(duì)于圖像理解等下游任務(wù)至關(guān)重要。



文本到圖像。LWM 根據(jù)文本提示以自回歸方式生成圖像。

文本到視頻。LWM 根據(jù)文本提示以自回歸方式生成視頻。
第二階段通過逐步增加序列長(zhǎng)度并在大量文本-圖像和文本-視頻數(shù)據(jù)上訓(xùn)練,成功擴(kuò)展了第一階段的語(yǔ)言模型,使其具備視覺理解能力。這一階段的模型可以處理長(zhǎng)達(dá)1M tokens的多模態(tài)序列,并在長(zhǎng)視頻理解、圖像理解和生成等方面展現(xiàn)出強(qiáng)大的能力。
技術(shù)細(xì)節(jié)(Further Details)
訓(xùn)練計(jì)算資源:模型使用TPUv4-1024進(jìn)行訓(xùn)練,相當(dāng)于450個(gè)A100 GPU,使用FSDP進(jìn)行數(shù)據(jù)并行,并通過RingAttention支持大上下文。

訓(xùn)練損失曲線:圖10和圖11展示了第一階段語(yǔ)言模型和第二階段視覺-語(yǔ)言模型的訓(xùn)練損失曲線??梢钥闯?,隨著訓(xùn)練進(jìn)行,損失持續(xù)下降。

訓(xùn)練超參數(shù):附錄F提供了詳細(xì)的訓(xùn)練超參數(shù),包括參數(shù)量、初始化模型、序列長(zhǎng)度、RoPE參數(shù)、每批tokens數(shù)、總tokens數(shù)、訓(xùn)練步驟數(shù)、學(xué)習(xí)率計(jì)劃、學(xué)習(xí)率預(yù)熱步數(shù)、最大學(xué)習(xí)率和最小學(xué)習(xí)率、計(jì)算資源等。


推斷擴(kuò)展:實(shí)現(xiàn)了RingAttention用于解碼,支持對(duì)長(zhǎng)達(dá)數(shù)百萬tokens的序列進(jìn)行推斷,需使用至少v4-128 TPU,并進(jìn)行32路tensor并行和4路序列并行。

量化:文檔指出,模型使用單精度進(jìn)行推斷,通過量化等技術(shù)可以進(jìn)一步提高擴(kuò)展性。
一些例子
基于圖像的對(duì)話。

圖 6. 圖像理解。LWM 可以回答有關(guān)圖像的問題。
超過 1 小時(shí)的 YouTube 視頻視頻聊天。


圖 7. 長(zhǎng)視頻聊天。
即使最先進(jìn)的商業(yè)模型 GPT-4V 和 Gemini Pro 都失敗了,LWM 仍能回答有關(guān) 1 小時(shí)長(zhǎng)的 YouTube 視頻的問題。每個(gè)示例的相關(guān)剪輯位于時(shí)間戳 9:56(頂部)和 6:49(底部)。















 
 
 










 
 
 
 