推理時(shí)也能做偏好優(yōu)化,無(wú)需額外重訓(xùn)練,來(lái)自上海AI Lab港中文等
隨著大語(yǔ)?模型(LLMs)在各類任務(wù)中展現(xiàn)出令人矚目的能力,如何確保它們?成的回復(fù)既符合預(yù)期又安全,始終是?項(xiàng)關(guān)鍵挑戰(zhàn)。
傳統(tǒng)的偏好對(duì)??法,如基于?類反饋的強(qiáng)化學(xué)習(xí)(RLHF)和直接偏好優(yōu)化(DPO),依賴于訓(xùn)練過(guò)程中的模型參數(shù)更新,但在?對(duì)不斷變化的數(shù)據(jù)和需求時(shí),缺乏?夠的靈活性來(lái)適應(yīng)這些變化。
為了突破這?瓶頸,上海人工智能實(shí)驗(yàn)室、香港中文大學(xué)等聯(lián)合提出了推理時(shí)偏好優(yōu)化(TPO)方法,通過(guò)在推理階段與獎(jiǎng)勵(lì)模型交互,借助可解釋的文本反饋,迭代優(yōu)化模型輸出,實(shí)現(xiàn)了即時(shí)的模型對(duì)?,??需重新訓(xùn)練。
實(shí)驗(yàn)結(jié)果表明,TPO能夠有效提升未對(duì)?模型的表現(xiàn),甚?超越經(jīng)過(guò)訓(xùn)練的對(duì)?模型,為模型偏好對(duì)?提供了?種全新的思路。
△訓(xùn)練時(shí)偏好優(yōu)化VS推理時(shí)偏好優(yōu)化
TPO特點(diǎn)
(1)推理時(shí)對(duì)?、?需訓(xùn)練:TPO通過(guò)與獎(jiǎng)勵(lì)模型的推理階段交互,實(shí)現(xiàn)即時(shí)對(duì)?偏好,無(wú)需更新模型參數(shù)。
(2)基于?本反饋:TPO使?可解釋的文本反饋(而非純數(shù)值梯度)來(lái)指導(dǎo)優(yōu)化,讓模型“理解?并“執(zhí)行”文本評(píng)價(jià)。
(3)優(yōu)于傳統(tǒng)?法:在推理階段,未對(duì)?的模型(例如Llama-3.1-70B-SFT)經(jīng)過(guò)數(shù)次TPO迭代,能夠持續(xù)逼近獎(jiǎng)勵(lì)模型的偏好。在多個(gè)基準(zhǔn)測(cè)試中,其表現(xiàn)甚至超越了已在訓(xùn)練時(shí)對(duì)?的版本(例如Llama-3.1-70B-Instruct)。
(4)靈活適應(yīng)性:TPO能夠靈活應(yīng)對(duì)不斷變化的數(shù)據(jù)和需求,具有較強(qiáng)的適應(yīng)性,并且能夠在資源有限的環(huán)境下?效運(yùn)?。
研究方法
為實(shí)現(xiàn)這??標(biāo),已有多種方法用來(lái)實(shí)現(xiàn)評(píng)分函數(shù),如RLHF和DPO通過(guò)訓(xùn)練時(shí)偏好優(yōu)化來(lái)對(duì)??類偏好。這些?法通過(guò)基于梯度的?法(如隨機(jī)梯度下降,SGD)優(yōu)化模型參數(shù)(如神經(jīng)?絡(luò)中的權(quán)重θ),使得?成符合?類偏好的輸出概率更?。每次更新的步驟如下:
TPO通過(guò)解釋和執(zhí)行文本損失和文本梯度,為模型生成的回復(fù)提供可解釋的優(yōu)化信號(hào)。
如圖所示,TPO包含四個(gè)關(guān)鍵組件,類似于標(biāo)準(zhǔn)的梯度優(yōu)化?法:變量定義、損失計(jì)算、梯度計(jì)算和變量?jī)?yōu)化。
研究人員使用獎(jiǎng)勵(lì)模型R作為人類偏好的代理,提供生成回復(fù)質(zhì)量的反饋。在推理時(shí)對(duì)?過(guò)程中,系統(tǒng)通過(guò)迭代調(diào)整輸出,使其逐步更符合獎(jiǎng)勵(lì)模型的偏好。
△測(cè)試時(shí)間偏好優(yōu)化(TPO)框架(AlpacaEval2的真實(shí)示例)
該過(guò)程最多進(jìn)行D次迭代,類似于訓(xùn)練過(guò)程,稱為推理時(shí)訓(xùn)練(test-time training)。最終,選擇緩存中評(píng)分最高的回復(fù)作為最終輸出。
實(shí)驗(yàn)與結(jié)果
策略模型
- 未對(duì)齊模型:Llama-3.1-70B-SFT
- 已對(duì)齊模型:
-Llama-3.1-70B-Instruct
-Llama-3.1-70B-DPO(UltraFeedback訓(xùn)練得來(lái))
獎(jiǎng)勵(lì)模型
- FsfairX-LLaMA3-RM-v0.1
- Llama-3.1-Tulu-3-8B-RM
benchmark與評(píng)價(jià)指標(biāo)
- 指令跟隨:Alpaca Eval 2(原始勝率WR和長(zhǎng)度控制勝率LC)和ArenaHard(勝率WR)
- 偏好對(duì)齊:HH-RLHF(采樣500條,F(xiàn)sfairX-LLaMA3-RM-v0.1的平均獎(jiǎng)勵(lì)分?jǐn)?shù))
- 安全:BeaverTails-Evaluation(FsfairX-LLaMA3-RM-v0.1的平均獎(jiǎng)勵(lì)分?jǐn)?shù))XSTest(WildGuard的準(zhǔn)確率)
- 數(shù)學(xué)能力:MATH-500(使用0-shot配置和CoT提示,pass@1準(zhǔn)確率)
推理時(shí)訓(xùn)練效果
TPO在推理時(shí)對(duì)模型進(jìn)行優(yōu)化,通過(guò)少量的迭代步數(shù)逐漸擬合獎(jiǎng)勵(lì)模型偏好,顯著提升未對(duì)齊模型的性能,使其達(dá)到與對(duì)齊模型相當(dāng)?shù)乃?;在已?duì)齊模型上,TPO進(jìn)一步增強(qiáng)了對(duì)齊效果,而Revision版本(迭代優(yōu)化選定回復(fù)而不參考被拒絕回復(fù))的提升有限。
benchmark性能
TPO能夠顯著提升模型性能指標(biāo),未對(duì)齊模型通過(guò)TPO超越了訓(xùn)練時(shí)對(duì)齊的模型,而對(duì)齊模型在經(jīng)過(guò)TPO迭代后也獲得了進(jìn)一步的優(yōu)化。D和N分別表示最大迭代次數(shù)和樣本數(shù)量。
* 表示使用獎(jiǎng)勵(lì)模型FsfairX-LLaMA3-RM-v0.1優(yōu)化的模型,而?表示Llama-3.1-Tulu-3-8B-RM。
推理穩(wěn)定性
TPO能夠有效地根據(jù)獎(jiǎng)勵(lì)模型的反饋調(diào)整模型輸出,顯著改善推理穩(wěn)定性,表現(xiàn)為采樣樣本的獎(jiǎng)勵(lì)分?jǐn)?shù)標(biāo)準(zhǔn)差的降低。
TPO的特性分析
TPO的寬度:增加TPO的搜索寬度(即每次TPO迭代中采樣的回復(fù)數(shù)量)能夠顯著提升性能,直到達(dá)到飽和。
TPO的深度:增加TPO的搜索深度比單純?cè)黾訕颖緮?shù)量更有效地發(fā)現(xiàn)更高質(zhì)量的回復(fù)。
TPO的計(jì)算成本:TPO無(wú)需更改模型參數(shù),與訓(xùn)練時(shí)偏好優(yōu)化相比,在計(jì)算成本上具有顯著優(yōu)勢(shì)。TPO的計(jì)算成本(FLOPs)僅為一輪DPO訓(xùn)練(64,000條數(shù)據(jù))所需開(kāi)銷的0.01%。而Instruct模型通常在百萬(wàn)級(jí)語(yǔ)料上多輪迭代,訓(xùn)練成本遠(yuǎn)高于DPO,進(jìn)一步凸顯了TPO在相對(duì)計(jì)算成本方面的優(yōu)勢(shì)。
TPO的指令跟隨前提:TPO的成功依賴于策略模型具備基礎(chǔ)的指令跟隨能力,因?yàn)槟P捅仨殰?zhǔn)確解釋和響應(yīng)數(shù)值形式的獎(jiǎng)勵(lì)模型偏好。
總結(jié)
提出推理時(shí)偏好優(yōu)化(TPO)方法,通過(guò)在推理過(guò)程中與獎(jiǎng)勵(lì)模型交互,將獎(jiǎng)勵(lì)模型信號(hào)轉(zhuǎn)化為”文本損失”和”文本梯度”,以此迭代優(yōu)化模型輸出。
無(wú)需重新訓(xùn)練,即可讓大語(yǔ)言模型與人類偏好對(duì)齊。TPO為訓(xùn)練時(shí)偏好優(yōu)化提供了輕量、高效且可解釋的替代方案,充分利用了大語(yǔ)言模型在推理時(shí)的固有能力。
推理時(shí)優(yōu)化的靈活性:TPO通過(guò)即時(shí)文本反饋實(shí)現(xiàn)推理時(shí)對(duì)?,增強(qiáng)了模型在多樣化場(chǎng)景中的適應(yīng)能力,能快速響應(yīng)變化的需求和任務(wù)的變化。此外,TPO充分利用大語(yǔ)言模型在推理、指令跟隨等方面的內(nèi)在優(yōu)勢(shì),從?實(shí)現(xiàn)了更靈活的偏好對(duì)?。
未來(lái)研究?向:未來(lái)的研究可聚焦于優(yōu)化文本交互?法,使其能夠適應(yīng)更多專門任務(wù),探索更魯棒的獎(jiǎng)勵(lì)模型以提升偏好捕捉能?,并研究如何提升較弱模型在TPO中的表現(xiàn),從而進(jìn)一步拓展其應(yīng)用場(chǎng)景和優(yōu)化效果。
論?鏈接:https://arxiv.org/abs/2501.12895
Github鏈接:https://github.com/yafuly/TPO
Huggingface鏈接:https://huggingface.co/papers/2501.12895