7B羊駝戰(zhàn)勝540B“谷歌版GPT”,MIT用博弈論調(diào)教大模型,無需訓練就能完成
基于博弈論,MIT提出了一種新的大模型優(yōu)化策略。
在其加持之下,7B參數(shù)的Llama在多個數(shù)據(jù)集上超越了540B的“谷歌版GPT”PaLM。
而且整個過程無需對模型進行額外訓練,消耗的算力資源更低。
圖片
這種基于博弈論制定的優(yōu)化策略被稱為均衡排名(Equilibrium Ranking)。
研究團隊將大模型語言解碼過程轉(zhuǎn)化為正則化不完全信息博弈。
這個詞可以拆解成“正則化”和“不完全信息博弈”兩部分,我們將在原理詳解部分展開介紹。
在博弈過程中,模型不斷對生產(chǎn)的答案進行優(yōu)化,讓生成結(jié)果更加符合事實。
實驗結(jié)果表明,在多個測試數(shù)據(jù)集上,均衡排名優(yōu)化方式的效果顯著優(yōu)于其他方式,甚至其他模型。
那么,均衡排序方法具體是如何將博弈論應用到大模型當中的呢?
讓大模型“自我博弈”
前面提到,研究人員將大模型進行語言解碼的過程直接變成了“正則化不完全信息博弈”過程。
不完全信息博弈是整個方法的核心,正則化則是一種避免出錯的機制,我們先來看這種博弈。
具體而言,他們設計了生成器(G)和判別器(D)兩個模塊,它們掌握著不同的信息,扮演不同角色。
生成器根據(jù)環(huán)境(N)隨機給出的“正確性參數(shù)”生成答案;判別器則只負責判斷生成器的答案是否正確,而不看環(huán)境參數(shù)。
如果判別器的判斷與環(huán)境參數(shù)一致,兩者都得到1分獎勵,否則都不得分。
圖片
在執(zhí)行重復的生成和判別當中,模型的目標是達到納什均衡。
在納什均衡策略組合下單方面改變自己的策略,而其他玩家策略不變,都不會提高自身的收益。
舉個例子,張三和李四一起決定晚餐吃什么,選項有火鍋和燒烤,其他已知條件如下:
- 張三對火鍋的滿意度是2分(很喜歡),對燒烤的滿意度為1分(還可以)
 - 李四對燒烤的滿意度是2分,對火鍋的滿意度為1分
 - 兩個人都不想自己單獨吃飯,因此單獨吃飯時滿意度均為0分
 
此時,兩人的選擇共有四種方式,對應的滿意度得分如下表:
圖片
這一情境下,兩人選擇相同時即為最佳策略,此時只要任何一個人單方面改變策略,兩人的滿意度將同時變?yōu)?。
回到均衡排名優(yōu)化法當中,生成器和判別器會先初始化策略,二者的依據(jù)分別基于問題或答案。
這一環(huán)境下的納什均衡如下表所示:
圖片
初始化完成后,生成器和判別器會進行多輪博弈,逐步更新策略,直到迭代終止。
每一次博弈結(jié)束后,分別計算判別器和生成器的得分和最優(yōu)策略得分的差值,稱為“后悔值”。
然后逐步進行迭代,直到后悔值收斂,逼近納什均衡。
達到納什均衡后,生成器和判別器的策略便確定,會分別對候選答案進行打分,然后進行排序選出最佳答案。
在納什均衡條件下,二者的評分應當是一致的,如果不一致,答案便會被剔除。
不過由于給生成器和判斷器打分的標準是與環(huán)境信息的一致性,而不是客觀事實,因此單純追求達到納什均衡,不一定能保證答案合理。
為了避免二者同時出錯的情況出現(xiàn),開發(fā)者還引入了正則化糾錯機制。
首先是向生成器和判別器基于客觀事實的先驗策略,而不是任由其隨機初始化。
這些先驗策略是生成器和判別器生成策略的“金科玉律”,引導了策略的優(yōu)化方向。
圖片
在此還有一種KL懲罰策略,當新的策略出現(xiàn)時,會計算其與初始策略的KL散度(又叫相對熵)。
KL散度描述了二者之間的相關性,數(shù)值越大,相關性越低。
假設P(x)和Q(x)分別是隨機變量X上的兩個概率分布,則在離散和連續(xù)的情形下,KL散度分別為:
圖片
這一結(jié)果會加入到生成新策略的函數(shù)當中,避免了最終生成的結(jié)果偏離客觀事實。
如下式所示,獎勵函數(shù)U中包含了KL散度項,并設置了懲罰系數(shù)λ(>0)。
圖片
當KL散度越大,也就是和客觀事實偏差越大時,模型獲得的獎勵分數(shù)將會降低。
這樣一來,當生成器和判別器結(jié)果一致卻不符合事實時,相關結(jié)果不會獲得高評分,也就不會成為最終答案。
憑借著這樣的策略,研究團隊用更低的消耗讓7B的Llama取得了優(yōu)異的成績。
部分能力超越“谷歌版GPT”
總的來說,均衡排序優(yōu)化后的Llama在常識推理、閱讀理解、數(shù)學和對話任務中的表現(xiàn)都十分出色。
選擇題方面,同樣是Llama,經(jīng)均衡排名方法優(yōu)化之后,模型在MMLU等多個數(shù)據(jù)集上的成績都排在比較靠前的位置。
圖片
問答題方面,均衡排名策略優(yōu)化后的13B Llama在TruthfulQA數(shù)據(jù)集中取得了最佳成績,7B版本也與第一名相差無幾。
圖片
除了文本相關的理解和推理,模型在數(shù)學方面也達到了較高水平。
7B Llama模型的諸多優(yōu)化方式中,均衡排序取得了GSM8K測試的最好成績。
圖片
均衡排序方法不僅是諸多Llama優(yōu)化方式中的佼佼者,優(yōu)化后的Llama成績也超過了其他模型。
在ARC數(shù)據(jù)集的Challenge分集和RACE數(shù)據(jù)集的High分集上,Llama-7B+均衡排序的準確率分別為58.3%和56.4%,顯著超越了PaLM-540B的53.0%和49.1%。
更多具體細節(jié),可以到原論文中一探究竟。















 
 
 



















 
 
 
 