ICLR 2025 Spotlight | 慕尼黑工業(yè)大學(xué)&北京大學(xué):邁向無沖突訓(xùn)練的ConFIG方法
本文由慕尼黑工業(yè)大學(xué)與北京大學(xué)聯(lián)合團隊撰寫。第一作者劉強為慕尼黑工業(yè)大學(xué)博士生。第二作者楚夢渝為北京大學(xué)助理教授,專注于物理增強的深度學(xué)習(xí)算法,以提升數(shù)值模擬的靈活性及模型的準(zhǔn)確性和泛化性。通訊作者 Nils Thuerey 教授(慕尼黑工業(yè)大學(xué))長期研究深度學(xué)習(xí)與物理模擬,尤其是流體動力學(xué)模擬的結(jié)合,并曾因高效流動特效模擬技術(shù)獲奧斯卡技術(shù)獎。目前,其團隊重點關(guān)注可微物理模擬及物理應(yīng)用中的先進生成式模型。
在深度學(xué)習(xí)的多個應(yīng)用場景中,聯(lián)合優(yōu)化多個損失項是一個普遍的問題。典型的例子包括物理信息神經(jīng)網(wǎng)絡(luò)(Physics-Informed Neural Networks, PINNs)、多任務(wù)學(xué)習(xí)(Multi-Task Learning, MTL)和連續(xù)學(xué)習(xí)(Continual Learning, CL)。然而,不同損失項的梯度方向往往相互沖突,導(dǎo)致優(yōu)化過程陷入局部最優(yōu)甚至訓(xùn)練失敗。
目前,主流的方法通常通過調(diào)整損失權(quán)重來緩解沖突。例如在物理信息神經(jīng)網(wǎng)絡(luò)中,許多研究從數(shù)值剛度、損失的收斂速度差異和神經(jīng)網(wǎng)絡(luò)的初始化角度提出了許多權(quán)重方法。然而,盡管這些方法聲稱具有更高的解的精度,但目前對于最優(yōu)的加權(quán)策略尚無共識。
針對這一問題,來自慕尼黑工業(yè)大學(xué)和北京大學(xué)的聯(lián)合研究團隊提出了 ConFIG(Conflict-Free Inverse Gradients,無沖突逆梯度)方法,為多損失項優(yōu)化提供了一種穩(wěn)定、高效的優(yōu)化策略。ConFIG 提供了一種優(yōu)化梯度,能夠防止由于沖突導(dǎo)致優(yōu)化陷入某個特定損失項的局部最小值。ConFIG 方法可以在數(shù)學(xué)上證明其收斂特性并具有以下特點:
- 最終更新梯度
與所有損失項的優(yōu)化梯度均不沖突。
在每個特定損失梯度上的投影長度是均勻的,可以確保所有損失項以相同速率進行優(yōu)化。
長度可以根據(jù)損失項之間的沖突程度自適應(yīng)調(diào)整。
此外,ConFIG 方法還引入了一種基于動量的變種。通過計算并緩存每個損失項梯度的動量,可以避免在每次訓(xùn)練迭代中計算所有損失項的梯度。結(jié)果表明,基于動量的 ConFIG 方法在顯著降低訓(xùn)練成本的同時保證了優(yōu)化的精度。
想深入了解 ConFIG 的技術(shù)細(xì)節(jié)?我們已經(jīng)為你準(zhǔn)備好了完整的論文、項目主頁和代碼倉庫!
- 論文地址:https://arxiv.org/abs/2408.11104
- 項目主頁:https://tum-pbs.github.io/ConFIG/
- GitHub: https://github.com/tum-pbs/ConFIG
ConFIG: 無沖突逆梯度方法
目標(biāo):給定個損失函數(shù)
,其對應(yīng)梯度為
。我們希望找到一個優(yōu)化方向
,使其滿足:
。即所有損失項在該方向上都能減少,從而避免梯度沖突。
無沖突優(yōu)化區(qū)間
假設(shè)存在一個無沖突更新梯度,我們可以引入一個新的矢量。由于
是一個無沖突梯度,
應(yīng)為一個正向分量矢量。同樣地,我們也可以預(yù)先定義一個正向分量矢量
,然后直接通過矩陣的逆運算求得無沖突更新梯度
,即
。通過給定不同的正向分量矢量
,我們得到由一系列不同
組成的無沖突優(yōu)化區(qū)間。
確定唯一優(yōu)化梯度
盡管通過簡單求逆可以獲得一個無沖突更新區(qū)間,我們需要進一步確定唯一的無沖突梯度用于優(yōu)化。在 ConFIG 方法中,我們從方向和幅度兩個方面進一步限定了最終用于優(yōu)化更新的梯度:
- 具體優(yōu)化方向:相比于直接求解梯度矩陣的逆,ConFIG 方法求解了歸一化梯度矩陣的逆,即
,其中
表示第
個梯度向量的單位向量??梢宰C明,變換后
矢量的每個分量代表了每個梯度
與最終更新梯度
之間的余弦相似度。因此,通過設(shè)定
分量的不同值可以直接控制最終更新梯度對于每個損失梯度的優(yōu)化速率。在 ConFIG 中,
被設(shè)定為單位矢量以確保每個損失具有相同的優(yōu)化強度從而避免某些損失項的優(yōu)化被忽略。
- 優(yōu)化梯度大小:此外,ConFIG 方法還根據(jù)梯度沖突程度調(diào)整步長。當(dāng)梯度方向較一致時,加快更新;當(dāng)梯度沖突嚴(yán)重時,減小更新幅度:
, 其中
為每個梯度與最終更新方向之間的余弦相似度。
ConFIG 方法獲得最終無沖突優(yōu)化方向的計算過程可以總結(jié)為:
原論文中給出了上述 ConFIG 更新收斂性的嚴(yán)格證明。同時,我們還可以證明只要參數(shù)空間的維度大于損失項的個數(shù),ConFIG 運算中的逆運算總是可行的。
M-ConFIG: 結(jié)合動量加速訓(xùn)練
ConFIG 方法引入了矩陣的逆運算,這將帶來額外的計算成本。然而與計算每個損失的梯度帶來的計算成本,其并不顯著。在包括 ConFIG 在內(nèi)的基于梯度的方法中,總是需要額外的反向傳播步驟獲得每個梯度相對于訓(xùn)練參數(shù)的梯度。這使得基于梯度的方法的計算成本顯著高于標(biāo)準(zhǔn)優(yōu)化過程和基于權(quán)重的方法。為此,我們引入了 M-ConFIG 方法,使用動量加速優(yōu)化:
- 使用梯度的動量(指數(shù)移動平均)代替梯度進行 ConFIG 運算。
- 在每次優(yōu)化迭代中,僅對一個或部分損失進行反向傳播以更新動量。其它損失項的動量采用之前迭代步的歷史值。
在實際應(yīng)用中,M-ConFIG 的計算成本往往低于標(biāo)準(zhǔn)更新過程或基于權(quán)重的方法。這是由于反向傳播一個子損失往往要比反向傳播總損失
更快。這在物理信息神經(jīng)網(wǎng)絡(luò)中尤為明顯,因為邊界上的采樣點通常遠(yuǎn)少于計算域內(nèi)的采樣點。在我們的實際測試中,M-ConFIG 的平均計算成本為基于權(quán)重方法的 0.56 倍。
結(jié)果:更快的收斂,更優(yōu)的預(yù)測
物理信息神經(jīng)網(wǎng)絡(luò)
在物理信息神經(jīng)網(wǎng)絡(luò)中,用神經(jīng)網(wǎng)絡(luò)的自動微分來近似偏微分方程的時空間導(dǎo)數(shù)。偏微分方程的殘差項與邊界條件和初始條件被視作不同的損失項在訓(xùn)練過程中進行聯(lián)合優(yōu)化。我們在多個經(jīng)典的物理神經(jīng)信息網(wǎng)絡(luò)中測試了 ConFIG 方法的表現(xiàn)。
結(jié)果顯示,在相同訓(xùn)練迭代次數(shù)下,ConFIG 方法是唯一一個相比于標(biāo)準(zhǔn) Adam 方法始終獲得正向提升的方法。對每個損失項變化的單獨分析表明,ConFIG 方法在略微提高 PDE 訓(xùn)練殘差的同時大幅降低了邊界和初始條件損失
,實現(xiàn)了 PDE 訓(xùn)練精度的整體提升。
相同迭代步數(shù)下不同方法在 PINNs 測試中相比于 Adam 優(yōu)化器的相對性能提升
不同損失項隨著訓(xùn)練周期的變化情況
在實際應(yīng)用中,相同訓(xùn)練時間下的模型準(zhǔn)確性可能更為重要。M-ConFIG 方法通過使用動量近似梯度帶來的運算速度提升可以使其充分發(fā)揮潛力。在相同訓(xùn)練時間內(nèi),M-ConFIG 方法的測試結(jié)果優(yōu)于其他所有方法,甚至高于常規(guī)的 ConFIG 方法。
此外,我們還在最具有挑戰(zhàn)性的三維 Beltrami 流動中進一步延長訓(xùn)練時間來更加深入地了解 M-ConFIG 方法的性能。結(jié)果表明,M-ConFIG 方法并非僅在優(yōu)化初始階段帶來顯著的性能改善,而是在整個優(yōu)化過程中都持續(xù)改善優(yōu)化的過程。
相同訓(xùn)練時間下不同方法在 PINNs 測試中相比于 Adam 優(yōu)化器的相對性能提升
三維 Beltrami 流動案例中預(yù)測誤差隨著訓(xùn)練時間的變化
多任務(wù)學(xué)習(xí)
我們還測試了 ConFIG 方法在多任務(wù)學(xué)習(xí)(MTL)方面的表現(xiàn)。我們采用經(jīng)典的 CelebA 數(shù)據(jù)集,其包含 20 萬張人臉圖像并標(biāo)注了 40 種不同的面部二元屬性。對每張人像面部屬性的學(xué)習(xí)是一個非常有挑戰(zhàn)的 40 項損失的多任務(wù)學(xué)習(xí)。
實驗結(jié)果表明,ConFIG 方法或 M-ConFIG 方法在平均 F1 分?jǐn)?shù)、平均排名
中均表現(xiàn)最佳。其中,對于 M-ConFIG 方法,我們在一次迭代中更新 30 個動量而不僅更新一個動量。這是因為當(dāng)任務(wù)數(shù)量增加時,單個動量更新時間的間隔較長,歷史動量信息難以準(zhǔn)確捕捉梯度的變化。動量信息的滯后會逐漸抵消 M-ConFIG 方法更高訓(xùn)練效率帶來的性能提升。
在我們的測試中,當(dāng)任務(wù)數(shù)量等于 10 時,M-ConFIG 方法在相同訓(xùn)練時間下的性能就已經(jīng)弱于 ConFIG 方法。增加單次迭代過程中的動量更新次數(shù)可以顯著緩解這種性能下降。在標(biāo)準(zhǔn)的 40 任務(wù) CelebA 訓(xùn)練中將動量更新次數(shù)提升到 20 時,M-ConFIG 方法的性能已經(jīng)接近 ConFIG 方法,而訓(xùn)練時間僅為 ConFIG 方法的 56%。當(dāng)更新步數(shù)達(dá)到 30 時,其性能甚至可以優(yōu)于 ConFIG 方法。
ConFIG 方法在 CelebA 人臉屬性數(shù)據(jù)集中的表現(xiàn)
結(jié)論
在本研究中,我們提出了 ConFIG 方法來解決不同損失項之間的訓(xùn)練沖突。ConFIG 方法通過確保最終更新梯度與每個子梯度之間的正點積來確保無沖突學(xué)習(xí)。此外,我們還發(fā)展了一種基于動量的方法,用交替更新的動量代替梯度,顯著提升了訓(xùn)練效率。ConFIG 方法有望為眾多包含多個損失項的深度學(xué)習(xí)任務(wù)帶來巨大的性能提升。