車輛意圖預(yù)測(cè)中一種基于因果時(shí)間序列的域泛化方法
arXiv上2021年12月3日上傳的論文“Causal-based Time Series Domain Generalization for Vehicle Intention Prediction”,作者來(lái)自伯克利分校。
準(zhǔn)確預(yù)測(cè)交通參與者的行為是自動(dòng)駕駛車輛的基本能力。由于在動(dòng)態(tài)變化的環(huán)境中導(dǎo)航,無(wú)論在何處以及遇到何種駕駛環(huán)境,都需要進(jìn)行準(zhǔn)確的預(yù)測(cè)。當(dāng)自動(dòng)駕駛車輛部署在現(xiàn)實(shí)世界中時(shí),對(duì)未知域的泛化能力對(duì)預(yù)測(cè)模型至關(guān)重要。本文提出一種基于因果的時(shí)間序列域泛化(causal-based time series domain generalization,CTSDG)模型。構(gòu)建一個(gè)車輛意圖預(yù)測(cè)任務(wù)的結(jié)構(gòu)化因果模型,學(xué)習(xí)用于域泛化de輸入駕駛數(shù)據(jù)不變表征。進(jìn)一步將遞歸潛變量 (latent variable) 模型集成到結(jié)構(gòu)化因果模型中,更好地從時(shí)間序列輸入數(shù)據(jù)中捕獲時(shí)間潛在依賴性。
該文章被Neurips 2021的workshop on Distribution Shifts接收。
因果關(guān)系側(cè)重于表示數(shù)據(jù)生成過(guò)程的結(jié)構(gòu)知識(shí),允許干預(yù)和更改,有助于理解和解決當(dāng)前機(jī)器學(xué)習(xí)方法的一些局限性。事實(shí)上,結(jié)合或?qū)W習(xí)環(huán)境結(jié)構(gòu)化知識(shí)的機(jī)器學(xué)習(xí)模型已被證明更有效,泛化效果更好。對(duì)于駕駛員來(lái)說(shuō),在不同域的交互方式也應(yīng)該處理一些不變的結(jié)構(gòu)化關(guān)系,因?yàn)檫@樣可以快速調(diào)整駕駛技能以適應(yīng)新的場(chǎng)景。因此,作者構(gòu)建了一個(gè)用于車輛意圖預(yù)測(cè)任務(wù)的結(jié)構(gòu)化因果模型(SCM),以學(xué)習(xí)輸入駕駛數(shù)據(jù)的不變表征,用于域泛化。
如圖是CTSDG的框架圖:其中觀測(cè)量顯示為陰影節(jié)點(diǎn);潛量節(jié)點(diǎn)是透明的。有向黑邊表示因果關(guān)系;虛線雙向邊表示相關(guān)性;空心虛線箭頭表示包含關(guān)系。藍(lán)線表示推理過(guò)程,綠線表示生成過(guò)程,紅線表示遞推過(guò)程。
域(D)包含地圖屬性的不同組合,如道路拓?fù)?、速度限制和交通?guī)則;事件(E)表示兩個(gè)車輛交互事件相關(guān)的可觀察變量,包括初始交互狀態(tài)和交互長(zhǎng)度等信息;駕駛員(O)表示每個(gè)駕駛員的駕駛偏好或駕駛風(fēng)格相關(guān)的不可觀察變量,例如攻擊性級(jí)別和他/她是否遵守交通規(guī)則;因果特征(XC)表示駕駛員的個(gè)人行為O及其交互狀態(tài)E中衍生出的高級(jí)時(shí)間相關(guān)因果特征,使用這些特征來(lái)標(biāo)記意圖;非因果特征(XNC)表示域相關(guān)特征,不僅包含域特定信息D,還包含事件E本身。例如交互事件信息,兩輛車開始相互注意時(shí)或道路協(xié)商開始時(shí),與道路幾何和交通規(guī)則相關(guān);輸入數(shù)據(jù)(X)是包含序貫多變量數(shù)據(jù)的車輛交互軌跡,由因果特征XC和非因果XNC特征混合構(gòu)成;潛變量(Z)表示從時(shí)間序列車輛交互數(shù)據(jù)中提取出與時(shí)間相關(guān)的潛表示;標(biāo)簽(Y)是車輛意圖標(biāo)簽。
采用一個(gè)表征函數(shù) q去映射輸入空間到一個(gè)潛空間,一個(gè)假設(shè)函數(shù) 映射潛空間到XC(因果)。合在一起,得到期望的意圖預(yù)測(cè)器f,而對(duì)應(yīng)的預(yù)測(cè)損失寫成
XC(因果)需要滿足一個(gè)不變性條件,考慮到駕駛員意圖的不可觀測(cè)性,得到一個(gè)預(yù)測(cè)損失函數(shù)如下:
上面匹配函數(shù)需要優(yōu)化,假設(shè)XC不同域的相同類輸入之間距離是受限的,可定義一個(gè)對(duì)比表征學(xué)習(xí)損失函數(shù),最小化不同域的同類輸入距離,即
作者把Variational Recurrent Neural Networks (VRNN)集成入CTSDG 模型,這樣目標(biāo)函數(shù)變成
最后的總目標(biāo)函數(shù)為
如下是CTSDG的偽代碼算法:
實(shí)驗(yàn)比較結(jié)果如下:
其中基準(zhǔn)方法包括
傳統(tǒng)域泛化任務(wù)的方法7個(gè)
- ERM:Oracle Inequalities in Empirical Risk Minimization and Sparse Recover Problems,2011
- IRM:Invariant risk minimization, 2019
- CCSA:Unified deep supervised domain adaptation and generalization,2017
- Mixup:mixup: Beyond empirical risk minimization,2017
- DANN:Domain-adversarial training of neural networks. 2016
- C-DANN:Deep domain generalization via conditional invariant adversarial networks,2018
- VRADA:Variational recurrent adversarial deep domain adaptation,2017
行為預(yù)測(cè)任務(wù)的方法3個(gè)
- CVAE:Desire: Distant future prediction in dynamic scenes with interacting agents,2017
- GAN:Social gan Socially acceptable trajectories with generative adversarial networks,2018
- GNN:Neural relational inference for interacting systems,2018
如下表是其中實(shí)驗(yàn)定義的域信息:
具體圖示如下:
本文雖然展示了提出的用于車輛意圖預(yù)測(cè)任務(wù)的域泛化方法性能,但CTSDG方法也可以直接或間接地用于軌跡預(yù)測(cè)任務(wù)。最直接的方法是將軌跡預(yù)測(cè)問(wèn)題作為一組不同軌跡的分類,能夠a)確保狀態(tài)空間的預(yù)期覆蓋級(jí)別,b)消除動(dòng)態(tài)不可行的軌跡,以及c)避免模式崩潰問(wèn)題。此外,意圖分類的準(zhǔn)確性在回歸任務(wù)中起著重要作用。與無(wú)條件的預(yù)測(cè)相比,目標(biāo)/意圖條件軌跡預(yù)測(cè)可以改進(jìn)智體聯(lián)合和單個(gè)智體的預(yù)測(cè)。條件預(yù)測(cè)也可用于規(guī)劃任務(wù)。因此,擁有一個(gè)準(zhǔn)確域泛化的意圖/目標(biāo)預(yù)測(cè)器是開發(fā)最優(yōu)軌跡預(yù)測(cè)器和運(yùn)動(dòng)規(guī)劃器的前提。