圖靈獎(jiǎng)得主Yoshua Bengio新作:Were RNNs All We Needed?
自從 Transformer 模型問(wèn)世以來(lái),試圖挑戰(zhàn)其在自然語(yǔ)言處理地位的挑戰(zhàn)者層出不窮。
這次登場(chǎng)的選手,不僅要挑戰(zhàn) Transformer 的地位,還致敬了經(jīng)典論文的名字。
再看這篇論文的作者列表,圖靈獎(jiǎng)得主、深度學(xué)習(xí)三巨頭之一的 Yoshua Bengio 赫然在列。
- 論文標(biāo)題:Were RNNs All We Needed?
- 論文地址:https://arxiv.org/pdf/2410.01201v1
最近,大家重新對(duì)用循環(huán)序列模型來(lái)解決 Transformer 長(zhǎng)上下文的問(wèn)題產(chǎn)生了興趣,出現(xiàn)了一大批有關(guān)成果,其中 Mamba 的成功引爆了 AI 圈,更是點(diǎn)燃了大家的研究熱情。
Bengio 和他的研究團(tuán)隊(duì)發(fā)現(xiàn),這些新的序列模型有很多共同點(diǎn),于是他們重新審視了 LSTM 和 GRU 這兩種經(jīng)典 RNN 模型。
結(jié)果發(fā)現(xiàn),精簡(jiǎn)掉其中的隱藏狀態(tài)依賴之后,不再需要基于時(shí)間反向傳播的 LSTM 和 GRU 的表現(xiàn)就能和 Transformer 打個(gè)平手。
LSTM 和 GRU 僅能順序處理信息,并且在訓(xùn)練時(shí)依賴反向傳播,這使得它們?cè)谔幚泶罅繑?shù)據(jù)時(shí)速度緩慢,最終被淘汰。
基于以上發(fā)現(xiàn),他們進(jìn)一步簡(jiǎn)化了 LSTM 和 GRU,去掉了它們對(duì)輸出范圍的限制,并確保它們的輸出在時(shí)間上是獨(dú)立的,進(jìn)而得到了 minLSTM 和 minGRU。
相比傳統(tǒng) RNN,它們不僅訓(xùn)練時(shí)所需的參數(shù)顯著減少,還可以并行訓(xùn)練,比如上下文長(zhǎng)度為 512 時(shí),速度能提升 175 倍。
這其實(shí)也是 Bengio 長(zhǎng)期關(guān)注 RNN 的系列研究成果。在今年五月,Bengio 及其研究團(tuán)隊(duì)和加拿大皇家銀行 AI 研究所 Borealis AI 合作發(fā)布了一篇名為《Attention as an RNN》的論文。
正如論文名字所示,他們將注意力機(jī)制重新詮釋為一種 RNN,引入了一種基于并行前綴掃描(prefix scan)算法的新的注意力公式,該公式能夠高效地計(jì)算注意力的多對(duì)多(many-to-many)RNN 輸出。基于新公式的模塊 Aaren,不僅可以像 Transformer 一樣并行訓(xùn)練,還可以像 RNN 一樣高效更新。
簡(jiǎn)化 LSTM 和 GRU
在這一部分,研究者通過(guò)簡(jiǎn)化和移除各種門中的若干隱藏狀態(tài)依賴關(guān)系,證明 GRU 和 LSTM 可通過(guò)并行掃描進(jìn)行訓(xùn)練。
在此基礎(chǔ)上,研究者進(jìn)一步簡(jiǎn)化了這些 RNN,消除了它們對(duì)輸出范圍的限制(即 tanh),并確保輸出在規(guī)模上與時(shí)間無(wú)關(guān)。
綜合上述步驟,研究者提出了 GRUs 和 LSTMs 的最小版本(minGRUs 和 minLSTMs),它們可通過(guò)并行掃描進(jìn)行訓(xùn)練,且性能可與 Transformers 和最近提出的序列方法相媲美。
minGRU
研究者結(jié)合了兩個(gè)簡(jiǎn)化步驟,得到了一個(gè)極簡(jiǎn)版的 GRU(minGRU)。
由此產(chǎn)生的模型比原始 GRU 效率大大提高,只需要 個(gè)參數(shù),而不是 GRU 的
個(gè)參數(shù)(其中 d_x 和 d_h 分別對(duì)應(yīng)于 x_t 和 h_t 的大小)。在訓(xùn)練方面,minGRU 可以使用并行掃描算法進(jìn)行并行訓(xùn)練,從而大大加快訓(xùn)練速度。
在實(shí)驗(yàn)部分,研究者展示了在 T4 GPU 上,當(dāng)序列長(zhǎng)度為 512 時(shí),訓(xùn)練步驟的速度提高了 175 倍。參數(shù)效率的提高也非常顯著。通常,在 RNN 中會(huì)進(jìn)行狀態(tài)擴(kuò)展(即 ,其中 α ≥ 1),使模型更容易從輸入中學(xué)習(xí)特征。
minLSTM
研究者結(jié)合了三個(gè)簡(jiǎn)化步驟,得到 LSTM 的最小版本(minLSTM):
與 LSTM 的 相比,最小版本(minLSTM)的效率明顯更高,只需要
個(gè)參數(shù)。此外,minLSTM 可以使用并行掃描算法進(jìn)行并行訓(xùn)練,大大加快了訓(xùn)練速度。例如,在 T4 GPU 上,對(duì)于長(zhǎng)度為 512 的序列,minLSTM 比 LSTM 加快了 235 倍。在參數(shù)效率方面,當(dāng) α = 1、2、3 或 4(其中
)時(shí),與 LSTM 相比,minLSTM 僅使用了 38%、25%、19% 或 15% 的參數(shù)。
Were RNNs All We Needed?
在本節(jié)中,研究者將對(duì)最小版本(minLSTMs 和 minGRUs)與傳統(tǒng)版本(LSTMs 和 GRUs)以及現(xiàn)代序列模型進(jìn)行了比較。
Minimal LSTMs 和 GRU 非常高效
在測(cè)試時(shí),循環(huán)序列模型會(huì)按順序推出,從而使其推理更為高效。相反,傳統(tǒng) RNN 的瓶頸在于其訓(xùn)練,需要線性訓(xùn)練時(shí)間(通過(guò)時(shí)間反向傳播),這導(dǎo)致其最終被淘汰。人們對(duì)循環(huán)序列模型重新產(chǎn)生興趣,是因?yàn)樵S多新的架構(gòu)可以高效地進(jìn)行并行訓(xùn)練。
研究者對(duì)比了訓(xùn)練傳統(tǒng) RNN(LSTM 和 GRU)、它們的最小版本(minLSTM 和 minGRU)以及一種最新的序列模型所需的資源,還特別將重點(diǎn)放在與最近大受歡迎的 Mamba 的比較上。實(shí)驗(yàn)考慮了 64 的批大小,并改變了序列長(zhǎng)度。研究者測(cè)量了通過(guò)模型執(zhí)行前向傳遞、計(jì)算損失和通過(guò)后向傳遞計(jì)算梯度的總運(yùn)行時(shí)間和內(nèi)存復(fù)雜度。
運(yùn)行時(shí)間。在運(yùn)行時(shí)間方面(見圖 1(左)),簡(jiǎn)化版 LSTM 和 GRU(minLSTM 和 minGRU)Mamba 的運(yùn)行時(shí)間相近。對(duì) 100 次運(yùn)行進(jìn)行平均,序列長(zhǎng)度為 512 的 minLSTM、minGRU 和 Mamba 的運(yùn)行時(shí)間分別為 2.97、2.72 和 2.71 毫秒。
對(duì)于長(zhǎng)度為 4096 的序列,運(yùn)行時(shí)間分別為 3.41、3.25 和 3.15 毫秒。相比之下,傳統(tǒng)的 RNN 對(duì)應(yīng)程序(LSTM 和 GRU)所需的運(yùn)行時(shí)間與序列長(zhǎng)度成線性關(guān)系。對(duì)于 512 的序列長(zhǎng)度,在 T4 GPU 上,minGRUs 和 minLSTMs 每個(gè)訓(xùn)練步驟的速度分別比 GRUs 和 LSTMs 快 175 倍和 235 倍(見圖 1(中))。隨著序列長(zhǎng)度的增加,minGRUs 和 minLSTMs 的改進(jìn)更為顯著,在序列長(zhǎng)度為 4096 時(shí),minGRUs 和 minLSTMs 的速度分別提高了 1324 倍和 1361 倍。因此,在 minGRU 需要一天才能完成固定數(shù)量的 epoch 訓(xùn)練的情況下,其傳統(tǒng)對(duì)應(yīng)的 GRU 可能需要 3 年多的時(shí)間。
內(nèi)存。通過(guò)利用并行掃描算法高效地并行計(jì)算輸出,minGRU、minLSTM 和 Mamba 創(chuàng)建了一個(gè)更大的計(jì)算圖,因此與傳統(tǒng)的 RNN 相比需要更多內(nèi)存(見圖 1(右))。與傳統(tǒng)的 RNN 相比,最小變體(minGRU 和 minLSTM)多用了 88% 的內(nèi)存。與 minGRU 相比,Mamba 多用了 56% 的內(nèi)存。但實(shí)際上,運(yùn)行時(shí)間是訓(xùn)練 RNN 的瓶頸。
刪除 的效果。最初的 LSTM 和 GRU 使用輸入 x_t 和之前的隱藏狀態(tài)
計(jì)算各種門電路。這些模型利用其與時(shí)間依賴的門來(lái)學(xué)習(xí)復(fù)雜函數(shù)。然而,minLSTM 和 minGRU 的訓(xùn)練效率是通過(guò)放棄門對(duì)之前隱藏狀態(tài)
的依賴性來(lái)實(shí)現(xiàn)的。因此,minLSTM 和 minGRU 的門僅與輸入 x_t 依賴,從而產(chǎn)生了更簡(jiǎn)單的循環(huán)模塊。因此,由單層 minLSTM 或 minGRU 組成的模型的柵極是與時(shí)間無(wú)關(guān)的,因?yàn)槠錀l件是與時(shí)間無(wú)關(guān)的輸入
。
然而,在深度學(xué)習(xí)中,模型是通過(guò)堆疊模塊構(gòu)建的。雖然第一層的輸入 與時(shí)間無(wú)關(guān),但其輸出
與時(shí)間有關(guān),并被用作第二層的輸入,即
。因此,從第二層開始,minLSTM 和 minGRU 的門也將隨時(shí)間變化,從而建立更復(fù)雜的函數(shù)模型。表 1 比較了不同層數(shù)的模型在 Mamba 論文中的選擇性復(fù)制任務(wù)上的表現(xiàn)??梢粤⒓纯闯鰰r(shí)間依賴性的影響:將層數(shù)增加到 2 層或更多,模型的性能就會(huì)大幅提高。
訓(xùn)練穩(wěn)定性。層數(shù)的另一個(gè)影響是穩(wěn)定性增強(qiáng),隨著層數(shù)的增加,準(zhǔn)確率的差異減?。ㄒ姳?1)。此外,雖然 minLSTM 和 minGRU 都能解決選擇性復(fù)制任務(wù),但可以看到 minGRU 是一種經(jīng)驗(yàn)上比 minLSTM 更穩(wěn)定的方法,它能以更高的一致性和更低的方差解決該任務(wù)。在訓(xùn)練過(guò)程中,這兩組參數(shù)的調(diào)整方向不同,使得比率更難控制和優(yōu)化。相比之下,minGRU 的信息丟棄和添加由單組參數(shù)(更新門)控制,因此更容易優(yōu)化。
Minimal LSTMs 和 GRUs 表現(xiàn)良好
上述內(nèi)容展示了簡(jiǎn)化傳統(tǒng) RNN 所帶來(lái)的顯著效率提升。這部分將探討最小版本的 LSTM 和 GRU 與幾種流行的序列模型相比的經(jīng)驗(yàn)性能。
選擇性復(fù)制。此處考慮 Mamba 論文中的長(zhǎng)序列選擇性復(fù)制任務(wù)。與最初的復(fù)制任務(wù)不同,選擇性復(fù)制任務(wù)的輸入元素相對(duì)于輸出元素是隨機(jī)間隔的,這增加了任務(wù)的難度。為了解決這個(gè)任務(wù),模型需要進(jìn)行內(nèi)容感知推理,記憶依賴的 token 并過(guò)濾掉不依賴的 token。
表 2 將簡(jiǎn)化版的 LSTM 和 GRU(minLSTM 和 minGRU)與可以并行訓(xùn)練的著名循環(huán)序列模型進(jìn)行了比較:S4、H3、Hyena 和 Mamba (S6)。這些基線的結(jié)果引自 Mamba 論文。在所有這些基線中,只有 Mamba 論文中的 S6 能夠解決這一任務(wù)。minGRU 和 minLSTM 也能解決選擇性復(fù)制任務(wù),其性能與 S6 相當(dāng),并優(yōu)于所有其他基線。LSTM 和 GRU 利用內(nèi)容感知門控機(jī)制,使得這些最小版本足以解決許多熱門序列模型無(wú)法解決的這一任務(wù)。
強(qiáng)化學(xué)習(xí)。接下來(lái),研究者討論了 D4RL 基準(zhǔn)中的 MuJoCo 運(yùn)動(dòng)任務(wù)。具體來(lái)說(shuō)考慮了三種環(huán)境:HalfCheetah、Hopper 和 Walker。對(duì)于每種環(huán)境,模型都在三種不同數(shù)據(jù)質(zhì)量的數(shù)據(jù)集上進(jìn)行訓(xùn)練:中等數(shù)據(jù)集(M)、中等游戲數(shù)據(jù)集(M-R)和中等專家數(shù)據(jù)集(M-E)。
表 3 將 minLSTM 和 minGRU 與各種 Decision Transformer 變體進(jìn)行了比較,包括原始 Decision Transformer (DT)、Decision S4 (DS4)、Decision Mamba 和(Decision)Aaren。minLSTM 和 minGRU 的性能優(yōu)于 Decision S4,與 Decision Transformer、Aaren 和 Mamba 相比也不遑多讓。與其他循環(huán)方法不同,Decision S4 是一種循環(huán)轉(zhuǎn)換不感知輸入的模型,這影響了其性能。從 3 × 3 = 9 個(gè)數(shù)據(jù)集的平均得分來(lái)看,minLSTM 和 minGRU 優(yōu)于所有基線方法,只有 Decision Mamba 的差距很小。
語(yǔ)言建模。研究者使用 nanoGPT 框架對(duì)莎士比亞作品進(jìn)行字符級(jí) GPT 訓(xùn)練。圖 2 用交叉熵?fù)p失繪制了學(xué)習(xí)曲線,將所提出的最小 LSTM 和 GRU(minLSTM 和 minGRU)與 Mamba 和 Transformers 進(jìn)行了比較。結(jié)果發(fā)現(xiàn),minGRU、minLSTM、Mamba 和 Transformers 的測(cè)試損失相當(dāng),分別為 1.548、1.555、1.575 和 1.547。Mamba 的表現(xiàn)略遜于其他模型,但訓(xùn)練速度更快,尤其是在早期階段,在 400 步時(shí)達(dá)到最佳表現(xiàn),而 minGRU 和 minLSTM 則分別持續(xù)訓(xùn)練到 575 步和 625 步。相比之下,Transformers 的訓(xùn)練速度明顯較慢,需要比 minGRU 多 2000 步(~ 2.5 倍)的訓(xùn)練步驟才能達(dá)到與 minGRU 相當(dāng)?shù)男阅埽@使得它的訓(xùn)練速度明顯更慢,資源消耗也更大(與 minGRU、minLSTM 和 Mamba 的線性復(fù)雜度相比,Transformers 的復(fù)雜度為二次方)。
更多研究細(xì)節(jié),可參考原論文。