OpenAI:訓(xùn)練大型神經(jīng)網(wǎng)絡(luò)的四種基本方法
本文轉(zhuǎn)自雷鋒網(wǎng),如需轉(zhuǎn)載請(qǐng)至雷鋒網(wǎng)官網(wǎng)申請(qǐng)授權(quán)。
大型神經(jīng)網(wǎng)絡(luò)是當(dāng)前人工智能領(lǐng)域的熱門(mén)話題之一,那么,如何訓(xùn)練大模型?
最近,曾推出大規(guī)模預(yù)訓(xùn)練模型 GPT-3 的 OpenAI 發(fā)表了一篇博文,介紹了基于 GPU 的四種節(jié)省內(nèi)存的并行訓(xùn)練方法,分別是:
- 數(shù)據(jù)并行——在不同的 GPU 上運(yùn)行同一批次的不同子集;
- 流水線并行——在不同的 GPU 上運(yùn)行模型的不同層;
- 張量并行——分解單個(gè)運(yùn)算的數(shù)學(xué)運(yùn)算,例如將矩陣乘法拆分到 GPU 上;
- 專(zhuān)家混合(MOE)——僅通過(guò)每層的一小部分處理每個(gè)示例。
圖注:三層模型上各種并行策略,每種顏色代表一層,虛線分隔不同的 GPU。
1 數(shù)據(jù)并行
「數(shù)據(jù)并行訓(xùn)練」意味著將相同的參數(shù)復(fù)制到多個(gè) GPU(通常稱(chēng)為“workers”),并為每個(gè) GPU 分配不同的示例以同時(shí)處理。
單單的數(shù)據(jù)并行要求模型匹配單個(gè) GPU 內(nèi)存,但當(dāng)你利用多個(gè) GPU 計(jì)算時(shí),代價(jià)是存儲(chǔ)參數(shù)的多個(gè)副本。不過(guò),話雖如此,有一些策略可以增加 GPU 可用的有效 RAM,例如,在兩次使用之間,可將參數(shù)暫時(shí)卸載到 CPU 內(nèi)存。
隨著每次數(shù)據(jù)并行 worker 更新其參數(shù)副本,它們需要相互協(xié)調(diào),以確保每個(gè) worker 都繼續(xù)具有相似的參數(shù)。最簡(jiǎn)單的方法是在 worker 之間引入「阻塞通信」:
步驟 1:獨(dú)立計(jì)算每個(gè)worker上的梯度;
步驟 2:將不同 worker 的梯度平均;
步驟 3:在每個(gè) worker 上獨(dú)立計(jì)算相同的新參數(shù)。
步驟 2 是一個(gè)阻塞平均值,它需要傳輸大量數(shù)據(jù)(與 worker 數(shù)量乘以參數(shù)大小成正比),這可能會(huì)損害訓(xùn)練的吞吐量。有各種異步同步方案可以消除這種損耗,但會(huì)損害學(xué)習(xí)效率;因此在實(shí)踐中,人們普遍堅(jiān)持同步方法。
2 流水線并行
在流水線并行訓(xùn)練中,研究者會(huì)將模型的順序塊劃分到 GPU 上,每個(gè) GPU 只保存一小部分參數(shù),因此,相同模型的每個(gè) GPU 消耗的內(nèi)存按比例減少。
將大型模型拆分為連續(xù)層的塊很簡(jiǎn)單,但由于層的輸入和輸出之間存在順序依賴(lài)關(guān)系,因此,在 worker 等待前一臺(tái)機(jī)器的輸出用作其輸入時(shí),一個(gè)幼稚的執(zhí)行可能會(huì)導(dǎo)致出現(xiàn)大量空閑時(shí)間。這些等待時(shí)間塊被稱(chēng)為「泡沫」(bubbles),即浪費(fèi)了本可以由空閑機(jī)器來(lái)完成的計(jì)算。
圖注:一個(gè)簡(jiǎn)單的流水線并行設(shè)置插圖,其中,模型被垂直分成 4 個(gè)分區(qū)。worker 1 主持第一層的模型參數(shù)(最接近輸入),而 worker 4 主持第 4 層(最接近輸出)?!癋”、“B”和“U”分別代表前向、后向和更新操作。下標(biāo)會(huì)指示在哪個(gè) worker 上運(yùn)行操作。由于順序依賴(lài)性,數(shù)據(jù)一次由一個(gè) worker 處理,導(dǎo)致產(chǎn)生了大量的空閑時(shí)間“泡沫”。
我們可以重用數(shù)據(jù)并行的想法,通過(guò)讓每個(gè) worker 一次只處理數(shù)據(jù)元素的一個(gè)子集,來(lái)降低產(chǎn)生時(shí)間泡沫的成本,從而使我們能巧妙地將新計(jì)算與等待時(shí)間重疊。核心思想是,將一個(gè)批次拆分為多個(gè)微批次,每個(gè)微批次的處理速度都應(yīng)該成比例地加快,并且每個(gè) worker 在下一個(gè)微批次可用時(shí)立即開(kāi)始工作,從而加快管道執(zhí)行。有了足夠的微批次, worker 可以在大部分時(shí)間被利用,并且在步驟開(kāi)始和結(jié)束時(shí)「泡沫」最小。梯度在微批次之間進(jìn)行平均,并且只有在所有微批次完成后才會(huì)更新參數(shù)。
模型拆分的 worker 數(shù)量通常稱(chēng)為「管道深度」(pipeline depth)。
在前向傳遞期間,worker 只需將其層塊的輸出(稱(chēng)為「激活」)發(fā)送給下一個(gè) worker;在反向傳遞期間,它僅將這些激活的梯度發(fā)送給前一個(gè)工作人員。如何安排這些通道以及如何跨微批次聚合梯度有很大的設(shè)計(jì)空間。例如,方法 GPipe 是讓每個(gè)工作進(jìn)程連續(xù)向前和向后傳遞,然后在最后同步聚合來(lái)自多個(gè)微批次的梯度;而 PipeDream 會(huì)安排每個(gè) worker 交替處理的前向和后向通道。
圖注:GPipe 和 PipeDream 流水線方案的比較,每批使用 4 個(gè)微批次。微批次 1-8 對(duì)應(yīng)于兩個(gè)連續(xù)的數(shù)據(jù)批次。圖中“number”表示在哪個(gè)微批次上操作,下標(biāo)標(biāo)記 worker ID。注意,PipeDream 通過(guò)使用陳舊參數(shù)執(zhí)行一些計(jì)算來(lái)獲得更高的效率。
3 張量并行
管道并行性將模型逐層“垂直”拆分,也可以在一個(gè)層內(nèi)“水平”拆分某些操作,這通常稱(chēng)為張量訓(xùn)練。
對(duì)于許多現(xiàn)代模型(例如Transformer),計(jì)算瓶頸是將激活批處理矩陣與大權(quán)重矩陣相乘。矩陣乘法可以認(rèn)為是成對(duì)的行和列之間的點(diǎn)積;可以在不同的 GPU 上計(jì)算獨(dú)立的點(diǎn)積,或者在不同的 GPU 上計(jì)算每個(gè)點(diǎn)積的部分并總結(jié)結(jié)果。無(wú)論采用哪種策略,我們都可以將權(quán)重矩陣分割成大小均勻的“碎片”,將每個(gè)碎片托管在不同的 GPU 上,并使用該碎片計(jì)算整個(gè)矩陣乘積的相關(guān)部分,然后再進(jìn)行通信以組合結(jié)果。
一個(gè)例子是Megatron-LM,它在 Transformer 的自注意力和 MLP 層內(nèi)并行化矩陣乘法。PTD-P使用張量、數(shù)據(jù)和流水線并行,其流水線調(diào)度為每個(gè)設(shè)備分配了多個(gè)不連續(xù)的層,以增加網(wǎng)絡(luò)通信為代價(jià)來(lái)減少泡沫損耗。
有時(shí),網(wǎng)絡(luò)輸入可以跨維度并行化,相對(duì)于交叉通信具有高度的并行計(jì)算。序列并行就是這樣一種想法,其中輸入序列在時(shí)間上被分成多個(gè)子示例,通過(guò)允許計(jì)算繼續(xù)進(jìn)行更細(xì)粒度的示例,來(lái)按比例減少峰值內(nèi)存消耗。
4 專(zhuān)家混合 (MoE)
使用專(zhuān)家混合(MoE)方法,只有小部分網(wǎng)絡(luò)用于計(jì)算任何一個(gè)輸入的輸出。
一個(gè)示例方法是擁有多組權(quán)重,并且網(wǎng)絡(luò)可在推理時(shí)通過(guò)門(mén)控機(jī)制選擇要使用的權(quán)重組,這能在不增加計(jì)算成本的情況下啟用更多參數(shù)。每組權(quán)重都被稱(chēng)為“專(zhuān)家”,且希望網(wǎng)絡(luò)能學(xué)會(huì)為每個(gè)專(zhuān)家分配專(zhuān)門(mén)的計(jì)算和技能。不同的專(zhuān)家可以主持不同的 GPU ,從而提供了一種明確的方式來(lái)擴(kuò)大用于模型的 GPU 數(shù)量。
圖注:門(mén)控網(wǎng)絡(luò)只選擇了n個(gè)專(zhuān)家中的2個(gè)。
GShard 將 MoE Transformer 的參數(shù)擴(kuò)展到 6000 億個(gè)參數(shù),其中僅將 MoE 層拆分到多個(gè) TPU 設(shè)備上,其他層則完全復(fù)制。Switch Transformer 通過(guò)將一個(gè)輸入路由到單個(gè)專(zhuān)家,將模型大小擴(kuò)展到數(shù)萬(wàn)億個(gè)參數(shù),具有更高的稀疏性。
5 其他節(jié)省內(nèi)存的設(shè)計(jì)
還有許多其他的計(jì)算策略,可以使訓(xùn)練越來(lái)越大的神經(jīng)網(wǎng)絡(luò)更容易處理。例如:
要計(jì)算梯度,需要保存原始激活,這會(huì)消耗大量設(shè)備 RAM。檢查點(diǎn)(也稱(chēng)為激活重新計(jì)算)存儲(chǔ)激活的任何子集,并在反向傳遞期間,及時(shí)重新計(jì)算中間的激活,以最多一個(gè)額外完整前向傳遞的計(jì)算成本,節(jié)省了大量?jī)?nèi)存。人們還可以通過(guò)選擇性激活重新計(jì)算,來(lái)不斷權(quán)衡計(jì)算和內(nèi)存成本,這是對(duì)激活的子集進(jìn)行檢查,其存儲(chǔ)成本相對(duì)較高,但計(jì)算成本較低。
混合精度訓(xùn)練是使用較低精度的數(shù)字(最常見(jiàn)的是FP16)來(lái)訓(xùn)練模型?,F(xiàn)代加速器可以使用較低精度的數(shù)字達(dá)到更高的 FLOP 計(jì)數(shù),并且還能節(jié)省設(shè)備 RAM。在適當(dāng)?shù)恼疹櫹拢a(chǎn)生的模型幾乎可以不損失任何精度。
卸載是將未使用的數(shù)據(jù)臨時(shí)卸載到 CPU 或不同設(shè)備之間,在需要時(shí)將其讀回。幼稚的執(zhí)行會(huì)大大減慢訓(xùn)練速度,但復(fù)雜的實(shí)現(xiàn)方式會(huì)預(yù)先獲取數(shù)據(jù),使設(shè)備永遠(yuǎn)不需要等待。這個(gè)想法的一個(gè)實(shí)現(xiàn)是ZeRO,它可將參數(shù)、梯度和優(yōu)化器狀態(tài)分割到所有可用的硬件上,并根據(jù)需要將它們具體化。
Memory Efficient Optimizers已經(jīng)提出了內(nèi)存效率優(yōu)化器,以減少優(yōu)化器所維護(hù)的運(yùn)行狀態(tài)的內(nèi)存占用,例如Adafactor。
壓縮也可用于存儲(chǔ)網(wǎng)絡(luò)中的中間結(jié)果。例如,Gist壓縮為后向傳遞而保存的激活;DALL-E在同步梯度之前壓縮梯度。