性能大漲20%!中科大「狀態(tài)序列頻域預(yù)測」方法:表征學(xué)習(xí)樣本效率max
強化學(xué)習(xí)算法(Reinforcement Learning, RL)的訓(xùn)練過程往往需要大量與環(huán)境交互的樣本數(shù)據(jù)作為支撐。然而,現(xiàn)實世界中收集大量的交互樣本通常成本高昂或者難以保證樣本采集過程的安全性,例如無人機空戰(zhàn)訓(xùn)練和自動駕駛訓(xùn)練。
為了提升強化學(xué)習(xí)算法在訓(xùn)練過程中的樣本效率,一些研究者們借助于表征學(xué)習(xí)(representation learning),設(shè)計了預(yù)測未來狀態(tài)信號的輔助任務(wù),使得表征能從原始的環(huán)境狀態(tài)中編碼出與未來決策相關(guān)的特征。
基于這個思路,該工作設(shè)計了一種預(yù)測未來多步的狀態(tài)序列頻域分布的輔助任務(wù),以捕獲更長遠的未來決策特征,進而提升算法的樣本效率。
該工作標(biāo)題為State Sequences Prediction via Fourier Transform for Representation Learning,發(fā)表于NeurIPS 2023,并被接收為Spotlight。
作者列表:葉鳴軒,匡宇飛,王杰*,楊睿,周文罡,李厚強,吳楓
論文鏈接:https://openreview.net/forum?id=MvoMDD6emT
代碼鏈接:https://github.com/MIRALab-USTC/RL-SPF/
研究背景與動機
深度強化學(xué)習(xí)算法在機器人控制[1]、游戲智能[2]、組合優(yōu)化[3]等領(lǐng)域取得了巨大的成功。但是,當(dāng)前的強化學(xué)習(xí)算法仍存在「樣本效率低下」的問題,即機器人需要大量與環(huán)境交互的數(shù)據(jù)才能訓(xùn)得性能優(yōu)異的策略。
為了提升樣本效率,研究者們將目光投向于表征學(xué)習(xí),希望訓(xùn)得的表征能從環(huán)境的原始狀態(tài)中提取出充足且有價值的特征信息,從而提升機器人對狀態(tài)空間的探索效率。
基于表征學(xué)習(xí)的強化學(xué)習(xí)算法框架
在序列決策任務(wù)中,「長期的序列信號」相對于單步信號包含更多有利于長期決策的未來信息。啟發(fā)于這一觀點,一些研究者提出通過預(yù)測未來多步的狀態(tài)序列信號來輔助表征學(xué)習(xí)[4,5]。然而,直接預(yù)測狀態(tài)序列來輔助表征學(xué)習(xí)是非常困難的。
現(xiàn)有的兩類方法中,一類方法通過學(xué)習(xí)單步概率轉(zhuǎn)移模型來逐步地產(chǎn)生單個時刻的未來狀態(tài),以間接預(yù)測多步的狀態(tài)序列[6,7]。但是,這類方法對所訓(xùn)得的概率轉(zhuǎn)移模型的精度要求很高,因為每步的預(yù)測誤差會隨預(yù)測序列長度的增加而積累。
另一類方法通過直接預(yù)測未來多步的狀態(tài)序列來輔助表征學(xué)習(xí)[8],但這類方法需要存儲多步的真實狀態(tài)序列作為預(yù)測任務(wù)的標(biāo)簽,所耗存儲量大。因此,如何有效從環(huán)境的狀態(tài)序列中提取有利于長期決策的未來信息,進而提升連續(xù)控制機器人訓(xùn)練時的樣本效率是需要解決的問題。
為了解決上述問題,我們提出了一種基于狀態(tài)序列頻域預(yù)測的表征學(xué)習(xí)方法(State Sequences Prediction via Fourier Transform, SPF),其思想是利用「狀態(tài)序列的頻域分布」來顯式提取狀態(tài)序列數(shù)據(jù)中的趨勢性和規(guī)律性信息,從而輔助表征高效地提取到長期未來信息。
狀態(tài)序列中的結(jié)構(gòu)性信息分析
我們從理論上證明了狀態(tài)序列存在「兩種結(jié)構(gòu)性信息」,一是與策略性能相關(guān)的趨勢性信息,二是與狀態(tài)周期性相關(guān)的規(guī)律性信息。
馬爾科夫決策過程
在具體分析兩種結(jié)構(gòu)性信息之前,我們先介紹產(chǎn)生狀態(tài)序列的馬爾科夫決策過程(Markov Decision Processes,MDP)的相關(guān)定義。
我們考慮連續(xù)控制問題中的經(jīng)典馬爾可夫決策過程,該過程可用五元組 表示。其中, 為相應(yīng)的狀態(tài)、動作空間, 為獎勵函數(shù), 為環(huán)境的狀態(tài)轉(zhuǎn)移函數(shù), 為狀態(tài)的初始分布, 為折扣因子。此外,我們用 表示策略在狀態(tài) 下的動作分布。
我們將 時刻下智能體所處的狀態(tài)記為 ,所選擇的動作記為 .智能體做出動作后,環(huán)境轉(zhuǎn)移到下一時刻狀態(tài) 并反饋給智能體獎勵 。我們將智能體與環(huán)境交互過程中所得到狀態(tài)、動作對應(yīng)的軌跡記為 ,軌跡服從分布 。
強化學(xué)習(xí)算法的目標(biāo)是最大化未來預(yù)期的累積回報,我們用 表示當(dāng)前策略 和 環(huán)境模型 下的平均累積回報,并簡寫為 ,定義如下:
顯示了當(dāng)前策略 的性能表現(xiàn)。
趨勢性信息
下面我們介紹狀態(tài)序列的「第一種結(jié)構(gòu)性特征」,其涉及狀態(tài)序列和對應(yīng)獎勵序列之間的依賴關(guān)系,能顯示出當(dāng)前策略的性能趨勢。
在強化學(xué)習(xí)任務(wù)中,未來的狀態(tài)序列很大程度上決定了智能體未來采取的動作序列,并進一步?jīng)Q定了相應(yīng)的獎勵序列。因此,未來的狀態(tài)序列不僅包含環(huán)境固有的概率轉(zhuǎn)移函數(shù)的信息,也能輔助表征捕獲反映當(dāng)前策略的走向趨勢。
啟發(fā)于上述結(jié)構(gòu),我們證明了以下定理,進一步論證了這一結(jié)構(gòu)性依賴關(guān)系的存在:
定理一:若獎勵函數(shù)只與狀態(tài)有關(guān),那么對于任意兩個策略 和 ,他們的性能差異可以被這兩個策略所產(chǎn)生的狀態(tài)序列分布差異所控制:
上述公式中, 表示在指定策略和轉(zhuǎn)移概率函數(shù)條件下狀態(tài)序列的概率分布, 表示 范數(shù)。
上述定理表明,兩個策略的性能差異越大,其對應(yīng)的兩個狀態(tài)序列的分布差異也越大。這意味著好策略和壞策略會產(chǎn)生出兩個差異較大的狀態(tài)序列,這進一步說明狀態(tài)序列所包含的長期結(jié)構(gòu)性信息能潛在影響搜索性能優(yōu)異的策略的效率。
另一方面,在一定條件下,狀態(tài)序列的頻域分布差異也能為對應(yīng)的策略性能差異提供上界,具體如以下定理所示:
定理二:若狀態(tài)空間有限維且獎勵函數(shù)是與狀態(tài)有關(guān)的n次多項式,那么對于任意兩個策略 和 ,他們的性能差異可以被這兩個策略所產(chǎn)生的狀態(tài)序列的頻域分布差異所控制:
上述公式中, 表示由策略 所產(chǎn)生的狀態(tài)序列的 次方序列的傅里葉函數(shù), 表示傅里葉函數(shù)的第 個分量。
這一定理表明狀態(tài)序列的頻域分布仍包含與當(dāng)前策略性能相關(guān)的特征。
規(guī)律性信息
下面我們介紹狀態(tài)序列中存在的「第二種結(jié)構(gòu)性特征」,其涉及到狀態(tài)信號之間的時間依賴性,即一段較長時期內(nèi)狀態(tài)序列所表現(xiàn)出的規(guī)律性模式。
在許多的真實場景任務(wù)中,智能體也會表現(xiàn)出周期性行為,因為其環(huán)境的狀態(tài)轉(zhuǎn)移函數(shù)本身就是具有周期性的。以工業(yè)裝配機器人為例,該機器人的訓(xùn)練目標(biāo)是將零件組裝在一起以創(chuàng)造最終產(chǎn)品,當(dāng)策略訓(xùn)練達到穩(wěn)定時,它就會執(zhí)行一個周期性的動作序列,使其能夠有效地將零件組裝在一起。
啟發(fā)于上面的例子,我們提供了一些理論分析,證明了有限狀態(tài)空間中,當(dāng)轉(zhuǎn)移概率矩陣滿足某些假設(shè),對應(yīng)的狀態(tài)序列在智能體達到穩(wěn)定策略時可能表現(xiàn)出「漸近周期性」,具體定理如下:
定理三:對于狀態(tài)轉(zhuǎn)移矩陣為 的有限維狀態(tài)空間 ,假設(shè) 有 個循環(huán)類,對應(yīng)的狀態(tài)轉(zhuǎn)移子矩陣為 。設(shè)這 個矩陣模為1的特征值個數(shù)為 ,則對于任意狀態(tài)的初始分布 ,狀態(tài)分布 呈現(xiàn)出周期為 的漸進周期性。
在MuJoCo任務(wù)中,策略訓(xùn)練達到穩(wěn)定時,智能體也會表現(xiàn)出周期性的運動。下圖中給出了MuJoCo任務(wù)中HalfCheetah智能體在一段時間內(nèi)的狀態(tài)序列示例,可以觀察到明顯的周期性。(更多MuJoCo任務(wù)中帶周期性的狀態(tài)序列示例可參考本論文附錄第E節(jié))
MuJoCo任務(wù)中HalfCheetah智能體在一段時間內(nèi)狀態(tài)所表現(xiàn)出的周期性
時間序列在時域中呈現(xiàn)的信息相對分散,但在頻域中,序列中的規(guī)律性信息以更加集中的形式呈現(xiàn)。通過分析頻域中的頻率分量,我們能顯式地捕獲到狀態(tài)序列中存在的周期性特征。
方法介紹
上一部分中,我們從理論上證明狀態(tài)序列的頻域分布能反映策略性能的好壞,并且通過在頻域上分析頻率分量我們能顯式捕獲到狀態(tài)序列中的周期性特征。
啟發(fā)于上述分析,我們設(shè)計了「預(yù)測無窮步未來狀態(tài)序列傅里葉變換」的輔助任務(wù)來鼓勵表征提取狀態(tài)序列中的結(jié)構(gòu)性信息。
SPF方法損失函數(shù)
下面介紹我們關(guān)于該輔助任務(wù)的建模。給定當(dāng)前狀態(tài) 和動作 ,我們定義未來的狀態(tài)序列期望如下:
我們的輔助任務(wù)訓(xùn)練表征去預(yù)測上述狀態(tài)序列期望的離散時間傅里葉變換(discrete-time Fourier transform, DTFT),即
上述傅里葉變換公式可改寫為如下的遞歸形式:
其中,
其中, 為狀態(tài)空間的維度, 為所預(yù)測的狀態(tài)序列傅里葉函數(shù)的離散化點的個數(shù)。
啟發(fā)于Q-learning中優(yōu)化Q值網(wǎng)絡(luò)的TD-error損失函數(shù)[9],我們設(shè)計了如下的損失函數(shù):
其中, 和 分別為損失函數(shù)要優(yōu)化的表征編碼器(encoder)和傅里葉函數(shù)預(yù)測器(predictor)的神經(jīng)網(wǎng)絡(luò)參數(shù), 為存儲樣本數(shù)據(jù)的經(jīng)驗池。
進一步地,我們可以證明上述的遞歸公式可以表示為一個壓縮映射:
定理四:令 表示函數(shù)族 ,并定義 上的范數(shù)為:
其中 表示矩陣 的第 行向量。我們定義映射 為
則可以證明 為一個壓縮映射。
根據(jù)壓縮映射原理,我們可以迭代地使用算子 ,使得 逼近真實狀態(tài)序列的頻域分布,且在表格型情況(tabular setting)下有收斂性保證。
此外,我們所設(shè)計的損失函數(shù)只依賴于當(dāng)前時刻與下一時刻的狀態(tài),所以無需存儲未來多步的狀態(tài)數(shù)據(jù)作為預(yù)測標(biāo)簽,具有「實施簡單且存儲量低」的優(yōu)點。
SPF方法算法框架
下面我們介紹本論文方法(SPF)的算法框架。
基于狀態(tài)序列頻域預(yù)測的表征學(xué)習(xí)方法(SPF)的算法框架圖
我們將當(dāng)前時刻和下一時刻的狀態(tài)-動作數(shù)據(jù)分別輸入到在線(online)和目標(biāo)(target)表征編碼器(encoder)中,得到狀態(tài)-動作表征數(shù)據(jù),然后將該表征數(shù)據(jù)輸入到傅里葉函數(shù)預(yù)測器(predictor)得到當(dāng)前時刻和下一時刻下的兩組狀態(tài)序列傅里葉函數(shù)預(yù)測值。通過代入這兩組傅里葉函數(shù)預(yù)測值,我們能計算出損失函數(shù)值。
我們通過最小化損失函數(shù)來優(yōu)化更新表征編碼器 和傅里葉函數(shù)預(yù)測器 ,使預(yù)測器的輸出能逼近真實狀態(tài)序列的傅里葉變換,從而鼓勵表征編碼器提取出包含未來長期狀態(tài)序列的結(jié)構(gòu)性信息的特征。
我們將原始狀態(tài)和動作輸入到表征編碼器中,將得到的特征作為強化學(xué)習(xí)算法中actor網(wǎng)絡(luò)和critic網(wǎng)絡(luò)的輸入,并用經(jīng)典強化學(xué)習(xí)算法優(yōu)化actor網(wǎng)絡(luò)和critic網(wǎng)絡(luò)。
實驗結(jié)果
(注:本節(jié)僅選取部分實驗結(jié)果,更詳細的結(jié)果請參考論文原文第6節(jié)及附錄。)
算法性能比較
我們將 SPF 方法在 MuJoCo 仿真機器人控制環(huán)境上測試,對如下 6 種方法進行對比:
- SAC:基于Q值學(xué)習(xí)的soft actor-critic算法[10],一種傳統(tǒng)的RL算法;
- PPO:基于策略優(yōu)化的proximal policy optimization算法[11],一種傳統(tǒng)RL算法;
- SAC-OFE:利用預(yù)測單步未來狀態(tài)的輔助任務(wù)進行表征學(xué)習(xí),以優(yōu)化SAC算法;
- PPO-OFE:利用預(yù)測單步未來狀態(tài)的輔助任務(wù)進行表征學(xué)習(xí),以優(yōu)化PPO算法;
- SAC-SPF:利用預(yù)測無窮步狀態(tài)序列的頻域函數(shù)的輔助任務(wù)進行表征學(xué)習(xí)(我們的方法),以優(yōu)化SAC算法;
- PPO-SPF:利用預(yù)測無窮步狀態(tài)序列的頻域函數(shù)的輔助任務(wù)進行表征學(xué)習(xí)(我們的方法),以優(yōu)化PPO算法;
基于6種MuJoCo任務(wù)的對比實驗結(jié)果
上圖顯示了在 6 種 MuJoCo 任務(wù)中,我們所提出的SPF方法(紅線及橙線)與其他對比方法的性能曲線。結(jié)果顯示,我們所提出的方法相比于其他方法能獲得19.5%的性能提升。
消融實驗
我們對 SPF 方法的各個模塊進行了消融實驗,將本方法與不使用投影器模塊(noproj)、不使用目標(biāo)網(wǎng)絡(luò)模塊(notarg)、改變預(yù)測損失(nofreqloss)、改變特征編碼器網(wǎng)絡(luò)結(jié)構(gòu)(mlp,mlp_cat)時的性能表現(xiàn)做比較。
SPF方法應(yīng)用于SAC算法的消融實驗結(jié)果圖,測試于HalfCheetah任務(wù)
可視化實驗
我們使用 SPF 方法所訓(xùn)練好的預(yù)測器輸出狀態(tài)序列的傅里葉函數(shù),并通過逆傅里葉變換恢復(fù)出的200步狀態(tài)序列,與真實的200步狀態(tài)序列進行對比。
基于傅里葉函數(shù)預(yù)測值恢復(fù)出的狀態(tài)序列示意圖,測試于Walker2d任務(wù)。其中,藍線為真實的狀態(tài)序列示意圖,5條紅線為恢復(fù)出的狀態(tài)序列示意圖,越下方的、顏色越淺的紅線表示利用越久遠的歷史狀態(tài)所恢復(fù)出的狀態(tài)序列。
結(jié)果顯示,即使用更久遠的狀態(tài)作為輸入,恢復(fù)出的狀態(tài)序列也和真實的狀態(tài)序列非常相似,這說明 SPF 方法所學(xué)習(xí)出的表征能有效編碼出狀態(tài)序列中包含的結(jié)構(gòu)性信息。