Transformers學(xué)習(xí)上下文強化學(xué)習(xí)的時間差分方法 原創(chuàng)
上下文學(xué)習(xí)指的是模型在推斷時學(xué)習(xí)能力,而不需要調(diào)整其參數(shù)。模型(例如transformers)的輸入包括上下文(即實例-標(biāo)簽對)和查詢實例(即提示)。然后,模型能夠根據(jù)上下文在推斷期間為查詢實例輸出一個標(biāo)簽。上下文學(xué)習(xí)的一個可能解釋是,(線性)transformers的前向傳播在上下文中實現(xiàn)了對實例-標(biāo)簽對的梯度下降迭代。在本文中,研究人員通過構(gòu)造證明了transformers在前向傳播中也能實現(xiàn)時間差異(TD)學(xué)習(xí),并將這一現(xiàn)象稱為上下文TD。在訓(xùn)練transformers使用多任務(wù)TD算法后展示了上下文TD的出現(xiàn),并進行了理論分析。此外,研究人員證明了transformers具有足夠的表達(dá)能力,可以在前向傳播中實現(xiàn)許多其他策略評估算法,包括殘差梯度、帶有資格跟蹤的TD和平均獎勵TD。
上下文學(xué)習(xí)已經(jīng)成為大型語言模型最顯著的能力之一。在上下文學(xué)習(xí)中,模型的輸入(即提示)包括上下文(即實例-標(biāo)簽對)和一個查詢實例。然后,模型在推斷期間(即前向傳播)為查詢實例輸出一個標(biāo)簽。模型輸入和輸出的一個示例可以是:
其中,“5 → number; a → letter”是包含兩個實例-標(biāo)簽對的上下文,“6”是查詢實例。根據(jù)上下文,模型推斷查詢“6”的標(biāo)簽為“number”。值得注意的是,整個過程在模型的推斷時間內(nèi)完成,而不需要調(diào)整模型的參數(shù)。
在(1)中的示例說明了一個監(jiān)督學(xué)習(xí)問題。在經(jīng)典的機器學(xué)習(xí)框架中,這個監(jiān)督學(xué)習(xí)問題通常通過首先基于上下文中的實例-標(biāo)簽對訓(xùn)練一個分類器來解決,使用諸如梯度下降之類的方法,然后要求分類器預(yù)測查詢實例的標(biāo)簽。值得注意的是,研究表明,transformers能夠在前向傳播中實現(xiàn)這個梯度下降訓(xùn)練過程,而不需要調(diào)整任何參數(shù),為上下文學(xué)習(xí)提供了一個可能的解釋。
超越監(jiān)督學(xué)習(xí),智能涉及到順序決策,其中強化學(xué)習(xí)已經(jīng)成為一個成功的范式。transformers在推斷期間能否執(zhí)行上下文RL,以及如何執(zhí)行?為了解決這些問題,研究人員從馬爾可夫獎勵過程MRP中的一個簡單評估問題開始。在MRP中,代理程序在每個時間步中從一個狀態(tài)轉(zhuǎn)換到另一個狀態(tài)。用(S0,S1,S2,...)表示代理訪問的狀態(tài)序列。在每個狀態(tài)下,代理程序會接收到一個獎勵。用(r(S0),r(S1),r(S2),...)表示代理程序在路途中接收到的獎勵序列。評估問題是估計值函數(shù)v,該函數(shù)計算每個狀態(tài)未來代理程序?qū)⑹盏降钠谕偅ㄕ劭郏┆剟睢K璧妮斎胼敵龅囊粋€示例可以是:
引人注目的是,上述任務(wù)與監(jiān)督學(xué)習(xí)根本不同,因為目標(biāo)是預(yù)測值v(s),而不是即時獎勵r(s)。此外,查詢狀態(tài)s是任意的,不必是S3。時間差分學(xué)習(xí)TD是解決這類評估問題(2)的最常用的RL算法。而且眾所周知,TD不是梯度下降。
在這項工作中,研究人員做出了三個主要貢獻。首先,通過構(gòu)造證明transformers具有足夠的表達(dá)能力來在前向傳播中實現(xiàn)TD,這一現(xiàn)象我們稱為上下文TD。換句話說,transformers能夠通過上下文TD在推斷時間內(nèi)解決問題(2)。超越最直接的TD,transformers還可以實現(xiàn)許多其他策略評估算法,包括殘差梯度(Baird,1995)、帶有資格跟蹤的TD(Sutton,1988)和平均獎勵TD(Tsitsiklis和Roy,1999)。特別地,為了實現(xiàn)平均獎勵TD,transformers需要使用多頭注意力和過度參數(shù)化的提示,例如,
這里,“□”充當(dāng)一個虛擬占位符,在推斷期間transformers將使用它作為“記憶”。第二,通過在多個隨機生成的評估問題上訓(xùn)練transformers與TD,實證地證明了在推斷中出現(xiàn)了上下文TD。換句話說,學(xué)習(xí)的transformer參數(shù)與我們在證明中的構(gòu)造非常相符。將這種訓(xùn)練方案稱為多任務(wù)TD。第三,通過展示對于單層transformer,證明了實現(xiàn)上下文TD所需的transformer參數(shù)在多任務(wù)TD訓(xùn)練算法的不變集合的子集中,來彌合理論和實證結(jié)果之間的差距。
論文:https://arxiv.org/pdf/2405.13861
本文轉(zhuǎn)載自公眾號AIGC最前線
原文鏈接:??https://mp.weixin.qq.com/s/voNZDTww7E5ec1hUwulztw??
