RNN回歸!Bengio新作大道至簡與Transformer一較高下
在Transformer統(tǒng)治的AI時代之下,
散落在世界各地的「RNN神教」信徒,一直相信并期待著RNN回歸的那天:
圖片
畢竟,憑借強大的順序和上下文感知能力,RNN曾在各種任務(wù)中表現(xiàn)驚艷。
直到后來遭遇了反向訓(xùn)練的瓶頸,因Scaling Law而跌落神壇。
然而,人們并沒有忘記RNN。
圖片
RWKV、Mamba、xLSTM等RNN衍生模型接連出現(xiàn),欲挑戰(zhàn)Transformer之霸主地位。
就在近日,又有重量級人物下場——
深度學(xué)習(xí)三巨頭之一的Yoshua Bengio,帶領(lǐng)團隊推出了全新的RNN架構(gòu),以大道至簡的思想與Transformer一較高下。
圖片
論文地址:https://arxiv.org/pdf/2410.01201v1
研究人員對傳統(tǒng)的兩種RNN架構(gòu)LSTM和GRU,進行了大刀闊斧的改造,從中誕生了兩個新模型:minLSTM和minGRU。
這倆極簡主義的版本到底怎么樣?咱們先看療效。
首先是RNN最大的問題:訓(xùn)練速度。
圖片
上圖展示了幾種模型在T4 GPU上訓(xùn)練花費的時間,以及新模型帶來的加速比。橫軸為輸入數(shù)據(jù)的序列長度,批量大小為64。
可以看到,相比于原版的LSTM和GRU,minLSTM、minGRU和Mamba的運行時間不會隨序列長度而增加(后3個模型的線在左圖中重疊了)。
當序列長度為4096時,新架構(gòu)相對于傳統(tǒng)版本達到了1300多倍的加速比!
相當于原版GRU需要3年才能做完的事情,minGRU一天就搞定了。
那么對線Transformer的戰(zhàn)績?nèi)绾危?/span>
圖片
在本文測試的語言建模任務(wù)中,minGRU和minLSTM分別在600步左右達到最佳性能點。
相比之下,Transformer需要比minGRU多花大概2000步,訓(xùn)練速度慢了約2.5倍。
對此,YC上的網(wǎng)友表示:「我非常喜歡這個新架構(gòu)的簡單性」。
圖片
畢竟,俗話說的好,「最好的PR是那些刪除代碼的PR」。
模型架構(gòu)
下面來感受一下極簡模型的誕生過程。
首先,這是傳統(tǒng)的RNN架構(gòu):
圖片
LSTM在RNN的每個cell中加入了比較復(fù)雜的門控:
圖片
三個門控(input gate、output gate、forget gate)和輸入的分量,都通過線性投影和非線性激活函數(shù)來得出,并且依賴于上一個時刻的隱藏狀態(tài)ht-1。
圖片
這些值再經(jīng)過線性和非線性計算,得到本時刻的輸出ct和隱藏狀態(tài)ht。
GRU在LSTM的基礎(chǔ)上做了一些簡化:
圖片
少了顯式計算ct,用于門控的項也縮減到2個,相應(yīng)的參數(shù)量和計算量也減少了。
圖片
那么我們就從相對簡單的GRU入手,開始改造。
改造的目的是使RNN能夠應(yīng)用并行掃描(Parallel Scan)算法,解決自身訓(xùn)練困難的問題。
簡單來說,就是將網(wǎng)絡(luò)中的計算改造成vt = at ⊙ vt?1 + bt的形式。
minGRU
第一步,公式中含有對之前隱藏狀態(tài)ht-1的依賴,沒辦法用并行掃描,所以把ht-1直接刪掉。
圖片
ht-1沒了,負責調(diào)控ht-1的rt也沒用了,刪掉。
第二步,雙曲正切函數(shù)(tanh)負責限制隱藏狀態(tài)的范圍,并減輕因sigmoid(σ)而導(dǎo)致的梯度消失。
但是現(xiàn)在ht-1和rt都沒了,tanh也失去了存在的意義,刪掉。
圖片
那么最終,minGRU就是下面這三個公式:
圖片
相比于原版,參數(shù)量和計算量再次減少,最重要的是能夠使用并行掃描來顯著加快訓(xùn)練速度。
minLSTM
經(jīng)過上面的敘述,minLSTM的由來就很好理解了。
首先還是去除隱藏狀態(tài)的依賴:
圖片
接著是拿掉相關(guān)的tanh:
圖片
最后,為了保證LSTM輸出的尺度與時間無關(guān),以及hidden state在縮放上與時間無關(guān),還需要刪掉output gate。
output gate沒了,ct也就沒必要單獨存在了,刪掉;剩下的兩個門控通過歸一化來調(diào)配hidden state進入的比例。
圖片
——emmm......好像變成GRU了,算了不管了。
最終改造好的minLSTM是下面這個樣子:
圖片
Were RNNs All We Needed?
全新的RNN搞出來了,能打Transformer嗎?
別急,先打內(nèi)戰(zhàn)證明價值。
除了傳統(tǒng)的RNN(LSTM和GRU),這里特別關(guān)注與Mamba的比較。
首先是訓(xùn)練上的提升:
圖片
實驗在批次大小64的情況下改變序列長度,測量了模型執(zhí)行前向傳遞、計算損失和向后傳遞計算梯度的總運行時間以及內(nèi)存占用。
在運行時間方面,minLSTM、minGRU與Mamba實現(xiàn)了類似的效率。
序列長度為512時的運行時間(超過100次的平均值),分別為 2.97、2.72和2.71毫秒;序列長度為4096時,運行時間分別為3.41、3.25和3.15。
相比之下,LSTM和GRU的運行時間隨序列長度線性增加。所以序列長度為512時,minGRU和minLSTM的訓(xùn)練加速了175倍和235倍;序列長度為4096時,加速比達到了1324和1361。
內(nèi)存方面,利用并行掃描算法時會創(chuàng)建更大的計算圖,所以minGRU、minLSTM和Mamba ,比傳統(tǒng)RNN需要更多的內(nèi)存(大概多出88%)。
——但這并不重要,因為對于RNN來說,訓(xùn)練時間才是瓶頸。
去除隱藏狀態(tài)的效果
minLSTM和minGRU的訓(xùn)練效率是通過降低它們的門控對先前隱藏狀態(tài)的依賴來實現(xiàn)的。
盡管單層minLSTM或minGRU的門控只與輸入有關(guān),而與時間無關(guān),但是在深度學(xué)習(xí)中,模型是通過堆疊模塊來構(gòu)建的。
從第二層開始,minLSTM和minGRU的門也將與時間相關(guān),從而對更復(fù)雜的函數(shù)進行建模。
下表比較了不同層數(shù)的模型在選擇性復(fù)制任務(wù)上的性能。我們可以看到時間依賴性的影響:將層數(shù)增加會大大提高模型的性能。
圖片
訓(xùn)練穩(wěn)定性
層數(shù)的另一個影響是穩(wěn)定性,隨著層數(shù)的增加,精度的方差減小。
此外,盡管minLSTM和minGRU都解決了選擇性復(fù)制任務(wù),但我們可以看到minGRU在經(jīng)驗上是一種比minLSTM更穩(wěn)定的方法(更高的一致性和更低的方差)。
minLSTM丟棄舊信息并添加新信息,使用兩組參數(shù)(forget gate 和input gate)控制比率。在訓(xùn)練期間,兩組參數(shù)會向不同的方向進行調(diào)整,使得比率更難控制和優(yōu)化。相比之下,minGRU的丟棄和添加信息由一組參數(shù)控制,更容易優(yōu)化。
選擇性復(fù)制
選擇性復(fù)制任務(wù)的輸入元素相對于其輸出是隨機間隔的,為了解決這項任務(wù),模型需要執(zhí)行內(nèi)容感知推理,記住相關(guān)token并過濾掉不相關(guān)的token。
圖片
上表將minLSTM和minGRU與可以并行訓(xùn)練的知名RNN模型進行了比較(S4,H3,Hyena和Mamba(S6)),基線結(jié)果引自Mamba論文。
在所有這些基線中,只有Mamba的S6,以及本文的minGRU和minLSTM能夠解決此任務(wù),體現(xiàn)了LSTM和GRU的內(nèi)容感知門控機制。
強化學(xué)習(xí)
下面開始對戰(zhàn)Transformer。
考慮D4RL基準中的MuJoCo運動任務(wù),包括三個環(huán)境:HalfCheetah、Hopper和Walker。
對于每個環(huán)境,模型在三個數(shù)據(jù)質(zhì)量不同的數(shù)據(jù)集上進行訓(xùn)練:Medium(M)、Medium-Replay(M-R)和Medium-Expert(M-E)。
圖片
上表將minLSTM和minGRU與各種決策模型進行了比較,包括原始的Decision Transformer(DT)、Decision S4 (DS4) 、Decision Mamba和Aaren。
由結(jié)果可知,minLSTM和minGRU的性能優(yōu)于Decision S4,與Decision Transformer、Aaren和Mamba相媲美(Decision S4的遞歸轉(zhuǎn)換不是輸入感知的,這會影響它的性能)。就平均分數(shù)而言,minLSTM和minGRU的表現(xiàn)優(yōu)于除Decision Mamba之外的所有基線。
語言建模
最后考慮語言建模任務(wù),使用nanoGPT框架在莎士比亞的作品上訓(xùn)練字符級GPT。
圖片
上圖繪制了具有交叉熵損失的學(xué)習(xí)曲線,可以發(fā)現(xiàn)minGRU、 minLSTM、 Mamba和Transformers分別實現(xiàn)了1.548、1.555、1.575和1.547的可比測試損耗。
Mamba的表現(xiàn)略差于其他模型,但訓(xùn)練速度更快(400步),minGRU和minLSTM分別花費575步和625步。而Transformer直接比minGRU多了2000 步,慢了大概2.5倍。















 
 
 





 
 
 
 