全面超越CoT!Meta田淵棟團隊新作:連續(xù)思維鏈
比思維鏈更厲害的方法是什么?
答:連續(xù)思維鏈。
近日,Meta田淵棟團隊提出了針對LLM推理任務(wù)的新范式:Coconut( Chain of Continuous Thought)。
論文地址:https://arxiv.org/pdf/2412.06769
論文一作是來自UC San Diego的Shibo Hao,對于文章的爆火,田淵棟也發(fā)文感謝了「小天才」Tanishq Mathew Abraham的推薦。
注:Tanishq Mathew Abraham,19歲(去年)讀完博士,目前是Stability AI的研究總監(jiān)以及MedARC的創(chuàng)始人。
回到這篇文章,連續(xù)思維鏈?zhǔn)鞘裁矗?/span>
小編在之前曾介紹過微軟發(fā)明的「LLM語言」:讓AI用模型的中間數(shù)據(jù)進行交流,不必轉(zhuǎn)換成人類的語言,交互效率直接翻倍。
而在LLM的推理過程中,也是這么個情況。
人類的語言并不適合推理,讓AI自己思考就行了,思考過程沒必要轉(zhuǎn)換成人類語言。
所以,在形式上,本文的方法就是推理時去掉模型頭尾的LLM head和embedding層,使用中間狀態(tài)進行自回歸,只在輸出最終答案時才轉(zhuǎn)成人類語言。
當(dāng)然了,Coconut要搭配相應(yīng)的訓(xùn)練,才能展現(xiàn)自己的性能:
這效果還是很強的,分?jǐn)?shù)和CoT打平的同時,token數(shù)少了好幾倍。
——看來拋棄人類的束縛才是真理,感覺這個點還能繼續(xù)搞下去,
最后的最后就會發(fā)展成:AI之間說了什么我們聽不懂,AI心里怎么想的我們也不知道。
AI:I'm free。
論文細節(jié)
基于語言空間進行推理的LLM,會遇到一個嚴(yán)重的問題:每個特定token所需的推理量差異很大。
推理鏈中的大多數(shù)token都是為了流暢性而生成的,對實際推理過程的貢獻很小,但當(dāng)前的LLM架構(gòu)分配了幾乎相同的計算來預(yù)測每個token。
另一方面,神經(jīng)影像學(xué)研究也表明,語言網(wǎng)絡(luò)(大腦中負責(zé)語言理解和產(chǎn)生的區(qū)域)在各種推理任務(wù)中基本不活躍。
所以,語言空間可能并不是推理的最佳選擇,理想的LLM應(yīng)該自由進行推理,不受任何語言限制。
Coconut不進行隱藏狀態(tài)和語言之間的映射,這種修改將推理從語言空間內(nèi)解放出來,并且系統(tǒng)可以通過梯度下降進行端到端優(yōu)化,因為連續(xù)思維是完全可微分的。
為了加強潛在推理的訓(xùn)練,本文采用了多階段訓(xùn)練策略,有效利用語言推理鏈來指導(dǎo)訓(xùn)練過程。
另外,與基于語言的推理不同,Coconut中的連續(xù)思考可以同時編碼多個可能的后續(xù)步驟,從而允許類似于廣度優(yōu)先搜索(BFS)的推理過程。
雖然模型可能無法在最初做出正確的決定,但它可以在連續(xù)的思考中保持許多可能的選擇,并在一些隱含價值函數(shù)的指導(dǎo)下,通過推理逐步消除不正確的路徑。
訓(xùn)練過程
在訓(xùn)練時,模型接收問題作為輸入,并期望通過推理過程生成答案。作者利用語言CoT數(shù)據(jù)來監(jiān)督持續(xù)思考,實施多階段訓(xùn)練。
如圖2所示,初始階段,模型在常規(guī)CoT實例上進行訓(xùn)練。后續(xù)階段(第k階段),CoT中的前k個推理步驟被k × c個連續(xù)思維所取代,(c為超參數(shù),控制取代單個語言推理步驟的潛在思維的數(shù)量)。
作者在訓(xùn)練階段切換時重置優(yōu)化器狀態(tài),插入<bot>和<eot> token來封裝連續(xù)的思維。
在訓(xùn)練過程中,作者優(yōu)化了正常的負對數(shù)似然損失,但屏蔽了問題和潛在思維的損失。另一個關(guān)鍵點是,目標(biāo)函數(shù)并不鼓勵使用連續(xù)的思維來壓縮語言思維,而是促進對未來推理的預(yù)測。
因此,與人類語言相比,LLM可以從中學(xué)習(xí)更有效的推理步驟表示。
連續(xù)思維是完全可微分的,允許反向傳播。不過Coconut的訓(xùn)練效率仍然有待優(yōu)化:雖然可以通過使用KV cache來避免重復(fù)的計算,但多個前向傳遞的順序性阻礙了并行訓(xùn)練。
Coconut的推理過程可以看成是在latent和language模式之間切換。
對于思考的終止位置,作者考慮了兩種可能的策略:a)在潛在思維上訓(xùn)練二元分類器,使模型能夠自主決定何時終止?jié)撛谕评恚籦)始終將潛在思維填充到恒定的長度。
作者發(fā)現(xiàn)這兩種方法的效果都不錯。為了簡單起見,以下實驗中使用第二個選項。
實驗
研究人員通過在三個數(shù)據(jù)集上的實驗,驗證了LLM在連續(xù)潛在空間中進行推理的可行性。這里將模型生成的答案與真實值進行比較來評估準(zhǔn)確性,并且分析每個問題新生成的token數(shù)量,作為推理效率的衡量標(biāo)準(zhǔn)。
數(shù)學(xué)推理使用GSM8k作為數(shù)據(jù)集,由小學(xué)水平的數(shù)學(xué)問題組成,問題更加多樣化,與現(xiàn)實世界的用例非常相似。
邏輯推理涉及使用邏輯規(guī)則和已知條件來證明或反駁結(jié)論。這要求模型從多個可能的推理路徑中進行選擇,正確的決策通常依賴于提前探索和規(guī)劃。
這里使用帶有虛構(gòu)概念名稱的5-hop ProntoQA。對于每個問題,都會隨機生成一個樹形結(jié)構(gòu)的本體,并以自然語言描述為一組已知條件,要求模型根據(jù)這些條件判斷給定的陳述是否正確。
作者發(fā)現(xiàn)ProntoQA的生成過程比較困難,因為本體中分散注意力的分支總是很小,從而減少了對復(fù)雜規(guī)劃的需求。
為了解決這個問題,本文應(yīng)用了新的數(shù)據(jù)集構(gòu)建管道,使用隨機生成的DAG來構(gòu)建已知條件。生成的數(shù)據(jù)集要求模型對圖進行大量規(guī)劃和搜索,以找到正確的推理鏈。這個新數(shù)據(jù)集被稱為ProsQA,如下圖所示。
實驗考慮以下基線:
1)CoT:使用完整的推理鏈來訓(xùn)練語言模型,并進行監(jiān)督微調(diào),推理過程中,模型先生成推理過程再輸出回答。
2)No-CoT:LLM直接生成答案。
3)iCoT:使用語言推理鏈進行訓(xùn)練,并將CoT 「內(nèi)化」。訓(xùn)練過程中,推理鏈開頭的token會逐漸被移除,最后只剩下答案。推理過程中,模型直接預(yù)測答案。
4)Pause token:模型僅使用問答進行訓(xùn)練,沒有推理鏈。但在問題和答案之間插入了特殊token,為模型提供了額外的計算能力來得出答案。
實驗還評估了本文方法的一些變體:
1)w/o curriculum:直接使用最后階段的數(shù)據(jù),不進行多階段訓(xùn)練。
2)w/o thought:使用多階段的訓(xùn)練,逐漸去除語言推理步驟,但不使用任何連續(xù)的潛在思維。這在概念上與iCoT相似,但實際的訓(xùn)練過程與Coconut保持一致。
3)Pause as thought:使用特殊的<pause> token來代替連續(xù)的思考,并應(yīng)用與Coconut相同的多階段訓(xùn)練。
表1顯示了所有數(shù)據(jù)集的總體結(jié)果。Coconut的效率很高,并且在ProntoQA和ProsQA上顯示出比CoT更好的性能。
上圖展示了Coconut將不同痕跡的分布編碼到連續(xù)的思想中,為規(guī)劃密集型推理任務(wù)啟用了更高級的推理模式。
圖5顯示了ProsQA上不同推理方法的比較分析。隨著更多地通過連續(xù)思考(增加k)進行推理,最終答案的準(zhǔn)確性(左)和正確推理過程的速率(右)都會提高。
此外,「幻覺」和「錯誤目標(biāo)」的發(fā)生率會降低,這也說明當(dāng)潛在空間發(fā)生更多推理時,規(guī)劃能力會更好。
圖6顯示了一個案例研究,其中CoT產(chǎn)生幻覺(一個不存在的邊)導(dǎo)致了錯誤的目標(biāo),但Coconut(k=2)成功解決了這個問題。潛在推理可以避免預(yù)先做出艱難的選擇,模型可以在后續(xù)步驟中逐步消除不正確的選項,并在推理結(jié)束時獲得更高的準(zhǔn)確性。