大模型推理速度飆升3.6倍,「美杜莎」論文來(lái)了,賈揚(yáng)清:最優(yōu)雅加速推理方案之一
如你我所知,在大型語(yǔ)言模型(LLM)的運(yùn)行邏輯中,隨著規(guī)模大小的增加,語(yǔ)言生成的質(zhì)量會(huì)隨著提高。不過(guò),這也導(dǎo)致了推理延遲的增加,從而對(duì)實(shí)際應(yīng)用構(gòu)成了重大挑戰(zhàn)。
從系統(tǒng)角度來(lái)看,LLM 推理主要受內(nèi)存限制,主要延遲瓶頸源于加速器的內(nèi)存帶寬而非算術(shù)計(jì)算。這一瓶頸是自回歸解碼的順序性所固有的,其中每次前向傳遞都需要將完整的模型參數(shù)從高帶寬內(nèi)存?zhèn)鬏數(shù)郊铀倨骶彺?。該過(guò)程僅生成了單個(gè)的 token,沒有充分利用現(xiàn)代加速器的算術(shù)計(jì)算潛力,導(dǎo)致了效率低下。
為了解決這一問題,加速 LLM 推理的方法被提出,既可以增加解碼過(guò)程的算術(shù)強(qiáng)度(FLOPs 與總數(shù)據(jù)移動(dòng)的比率),也能減少解碼步驟數(shù)量。這類方法以推測(cè)解碼(speculative decoding)為代表,使用較小的草稿(draft) 模型在每一步生成 token 序列,然后通過(guò)較大的原始模型進(jìn)行細(xì)化以獲得可接受的延續(xù)。不過(guò)獲得合適的草稿模型仍然具有挑戰(zhàn)性,并且將草稿模型集成到分布式系統(tǒng)中更加困難。
在本文中,來(lái)自普林斯頓大學(xué)、Together.AI、伊利諾伊大學(xué)厄巴納 - 香檳分校等機(jī)構(gòu)的研究者沒有使用單獨(dú)的草稿模型來(lái)順序生成候選輸出,而是重新審視并完善了在主干模型之上使用多個(gè)解碼頭加速推理的概念。他們發(fā)現(xiàn),如果該技術(shù)得到有效應(yīng)用,可以克服推測(cè)解碼的挑戰(zhàn),從而無(wú)縫地集成到現(xiàn)有 LLM 系統(tǒng)中。
具體來(lái)講, 研究者提出了 MEDUSA,一種通過(guò)集成額外解碼頭(能夠同時(shí)預(yù)測(cè)多個(gè) tokens)來(lái)增強(qiáng) LLM 推理的方法。這些頭以參數(shù)高效的方式進(jìn)行微調(diào),并可以添加到任何現(xiàn)有模型中。至此,不需要任何新模型,MEDUSA 就可以輕松地集成地當(dāng)前的 LLM 系統(tǒng)中(包括分布式環(huán)境),以確保友好用戶體驗(yàn)。
值得關(guān)注的是,該論文作者之一 Tri Dao 是近來(lái)非常火爆的 Transformer 替代架構(gòu) Mamba 的兩位作者之一。他是 Together.AI 首席科學(xué)家,并即將成為普林斯頓大學(xué)計(jì)算機(jī)科學(xué)助理教授。
- 論文地址:https://arxiv.org/pdf/2401.10774.pdf
- GitHub 地址:https://arxiv.org/pdf/2401.10774.pdf
在具體實(shí)現(xiàn)中,研究者通過(guò)兩個(gè)關(guān)鍵見解進(jìn)一步增強(qiáng)了 MEDUSA。首先,當(dāng)前在每個(gè)解碼步驟生成單個(gè)候選延續(xù)的方法導(dǎo)致了可接受長(zhǎng)度受限和計(jì)算資源的低效使用。為了解決這個(gè)問題,他們建議使用 MEDUSA 頭來(lái)生成多個(gè)候選延續(xù),并通過(guò)對(duì)注意力掩碼的簡(jiǎn)單調(diào)整來(lái)進(jìn)行驗(yàn)證。其次可以使用類似于推測(cè)解碼中的拒絕采樣方案來(lái)生成與原始模型具有相同分布的響應(yīng),但對(duì)于很多 LLM 應(yīng)用來(lái)說(shuō)通常不必要。
因此,研究者考慮或許可以引入一種典型的可接受方案,即從 MEDUSA 輸出中選擇合理的候選者。他們使用溫度作為閾值來(lái)管理原始模型預(yù)測(cè)的偏差,為拒絕采樣提供了一種有效的替代方案。這種方法有效地解決了拒絕采樣的局限性,比如在較高溫度下速度降低。
此外,為了給 LLM 配備預(yù)測(cè)性的 MEDUSA 頭,研究者提出了兩種針對(duì)不同場(chǎng)景量身定制的微調(diào)程序。對(duì)于計(jì)算資源有限或者目標(biāo)是將 MEDUSA 納入現(xiàn)有模型而不影響其性能的情況,他們建議使用 MEDUSA-1。該方法需要的內(nèi)存最少,并且可以使用類似于 QLoRA 中的量化技術(shù)來(lái)進(jìn)一步優(yōu)化,而不會(huì)因固定主干模型影響生成質(zhì)量。
不過(guò),對(duì)于 MEDUSA-1,主干模型的全部潛力無(wú)法得到充分利用。因此可以進(jìn)一步進(jìn)行微調(diào),以提高 MEDUSA 頭的預(yù)測(cè)精度,并直接帶來(lái)更大加速。因此研究者提出了 MEDUSA - 2,它適用于計(jì)算資源充足或從基礎(chǔ)模型進(jìn)行直接監(jiān)督微調(diào)的場(chǎng)景。MEDUSA-2 的關(guān)鍵是一個(gè)訓(xùn)練協(xié)議,它能夠?qū)?MEDUSA 頭和主干模型進(jìn)行聯(lián)合訓(xùn)練,而不會(huì)影響模型下一個(gè) token 的預(yù)測(cè)能力和輸出質(zhì)量。
在實(shí)驗(yàn)部分,研究者主要關(guān)注批大小為 1 的場(chǎng)景,這代表了 LLM 本地托管以供個(gè)人使用的用例。他們?cè)诓煌笮『陀?xùn)練設(shè)置下測(cè)試了 MEDUSA,包括 Vicuna-7B 和 13B(使用公共數(shù)據(jù)集訓(xùn)練)、Vicuna -33B(使用私有數(shù)據(jù)集訓(xùn)練)、Zephyr-7B(使用監(jiān)督微調(diào)和對(duì)齊訓(xùn)練)。
結(jié)果表明,MEDUSA 在不影響生成質(zhì)量的情況下,可以在不同的 promt 類型中實(shí)現(xiàn) 2.3 至 3.6 的推理加速。如下動(dòng)圖為 Vicuna-7b 上有無(wú) Medusa-1 時(shí)推理速度比較。
論文共同一作 Tianle Cai 表示,自 Medusa 項(xiàng)目推出以來(lái),它在 TensorRT、TGI 以及眾多開源項(xiàng)目和公司中得到采用。在新的技術(shù)論文中,我們推出了用于全模型調(diào)優(yōu)的 Medusa-2 方案、用于將 Medusa 集成到任何微調(diào) LLM 的自蒸餾以及其他更多加速技術(shù)。
對(duì)于這項(xiàng)研究,Lepton AI 創(chuàng)始人賈揚(yáng)清表示,Medusa 可能是他們見過(guò)的最優(yōu)雅的加速推理解決方案之一,能夠與 int8/fp8、編譯等互補(bǔ),在實(shí)踐中實(shí)現(xiàn) 2 倍性能增益。
并且,他們已將 Medusa 與很多現(xiàn)有優(yōu)化方法、混合加速方案進(jìn)行集成,結(jié)果在合理的并發(fā)下,加速保持正值,并在 A100 和 H100 等卡中尤其有效。此外,他們還已經(jīng)為 Llama 模型訓(xùn)練了通用 Medusa 頭。
方法概覽
MEDUSA 遵循推測(cè)解碼框架,其中每個(gè)解碼步驟主要由三個(gè)子步驟組成:(1) 生成候選者,(2) 處理候選者, (3) 接受候選者。對(duì)于 MEDUSA,(1) 是通過(guò) MEDUSA 頭(head)實(shí)現(xiàn)的,(2) 是通過(guò)樹注意力(tree attention)實(shí)現(xiàn)的,并且由于 MEDUSA 頭位于原始主干模型之上,因此 (2) 中計(jì)算的 logits 可以用于子步驟 (1) 的下一個(gè)解碼步驟。最后一步 (3) 可以通過(guò)拒絕采樣(rejection sampling)或典型接受(typical acceptance)來(lái)實(shí)現(xiàn)。MEDUSA 的整體流程如下圖 1 所示。
關(guān)鍵組件
MEDUSA 的關(guān)鍵組件主要包括 MEDUSA 頭和樹注意力。
首先,MEDUSA 頭與原始主干模型一起進(jìn)行訓(xùn)練。其中,原始主干模型可以在訓(xùn)練期間保持凍結(jié)狀態(tài) (MEDUSA-1) 或一起訓(xùn)練 (MEDUSA-2)。這種方法甚至可以在單個(gè) GPU 上微調(diào)大模型,利用強(qiáng)大的基礎(chǔ)模型學(xué)得的表征。
此外,MEDUSA 頭的分布確保與原始模型的分布一致,從而緩解了分布偏移問題,并且 MEDUSA 不會(huì)增加服務(wù)系統(tǒng)設(shè)計(jì)的復(fù)雜性,對(duì)分布式設(shè)置很友好。
由于候選者增加會(huì)提高計(jì)算需求,該研究采用樹狀結(jié)構(gòu)的注意力機(jī)制來(lái)同時(shí)處理多個(gè)候選者。這種注意力機(jī)制不同于傳統(tǒng)的因果注意力范式。在其框架內(nèi),只有來(lái)自同一 continuation 的 token 才被視為歷史數(shù)據(jù)。受圖神經(jīng)網(wǎng)絡(luò)領(lǐng)域提出的將圖結(jié)構(gòu)嵌入注意力的啟發(fā),研究團(tuán)隊(duì)還將樹結(jié)構(gòu)合并到注意力掩碼中,如下圖 2 所示。
訓(xùn)練策略
凍結(jié)主干模型來(lái)訓(xùn)練 MEDUSA 頭的方法很簡(jiǎn)單,并且需要的計(jì)算資源很少,但是將主干網(wǎng)絡(luò)與 MEDUSA 頭結(jié)合訓(xùn)練可以顯著提高 MEDUSA 頭的準(zhǔn)確性。因此,根據(jù)計(jì)算資源和用例的具體要求,研究團(tuán)隊(duì)為 MEDUSA 頭提出了兩個(gè)級(jí)別的訓(xùn)練策略,即 MEDUSA-1:凍結(jié)主干網(wǎng)絡(luò),MEDUSA-2:聯(lián)合訓(xùn)練。
最后,該研究提出了 MEDUSA 的兩個(gè)擴(kuò)展,包括自蒸餾(self-distillation)和典型接受(typical acceptance),分別用于處理 MEDUSA 沒有可用訓(xùn)練數(shù)據(jù)的情況和提高解碼過(guò)程的效率。
實(shí)驗(yàn)
為了證明 MEDUSA 在不同設(shè)置下的有效性,該研究進(jìn)行了兩組實(shí)驗(yàn):首先,在 Vicuna-7B/13B 模型上評(píng)估 MEDUSA,以展示 MEDUSA-1 和 MEDUSA-2 的性能;其次,在 Vicuna-33B 和 Zephyr-7B 模型上評(píng)估 MEDUSA,以研究自蒸餾的有效性,因?yàn)?Vicuna-33B 模型的訓(xùn)練數(shù)據(jù)集不公開,而 Zephyr-7B 模型使用 RLHF 進(jìn)行訓(xùn)練。
用例研究 1:在 Vicuna-7B/13B 模型上評(píng)估 MEDUSA
在 Vicuna-7B/13B 模型上評(píng)估 MEDUSA-1、MEDUSA-2 的結(jié)果如下圖 4 所示。
用例研究 2:在 Vicuna-33B 和 Zephyr-7B 使用自蒸餾訓(xùn)練
研究者關(guān)注了需要自蒸餾的情況,使用 Vicuna-33B 和 Zephyr-7B 作為示例。他們首先使用一些種子 prompt 來(lái)生成數(shù)據(jù)集,然后將 ShareGPT 和 UltraChat 作為種子數(shù)據(jù)集,并為以上兩個(gè)示例收集了包含大約 100k 樣本的數(shù)據(jù)集。
下表 1 展示了不同 MEDUSA-2 模型在 MT-Bench 基準(zhǔn)下的加速比、開銷和質(zhì)量。
下圖 5 為使用 MEDUSA-2 時(shí)不同模型的加速情況。
消融實(shí)驗(yàn)
下圖 6a 比較了隨機(jī)采樣密集樹設(shè)置(藍(lán)點(diǎn))和優(yōu)化稀疏樹設(shè)置(紅星)的加速率。6b 比較了密集和稀疏樹設(shè)置的速度。
下圖 7 展示了不同采樣設(shè)置下,模型性能的比較分析。
兩階段微調(diào)的有效性。研究者針對(duì) Vicuna-7B 模型,評(píng)估了兩種微調(diào)策略下的性能差異。