蘋果讓大模型學(xué)會偷懶:更快吐出第一個token,準(zhǔn)確度還保住了
Llama 3.1 剛剛發(fā)布,你是否已經(jīng)嘗試了呢?就算你的個人計算機(jī)是最近的頂尖配置,運(yùn)行其中最小的 8B 版本可能也依然會有明顯延遲。為了提升模型的推理效率,研究者想出了多種多樣的方法,但其中很多都會讓模型犧牲一些準(zhǔn)確度。
近日,蘋果和 Meta AI 的一個研究團(tuán)隊提出了一種新方法,可在保證準(zhǔn)確度不明顯下降的同時,將 Llama 2 預(yù)填充階段的推理速度提升到原來的 2 倍以上,這或許能為 Llama 3.1 的加速提供一些啟發(fā)。他們把這種方法稱為 LazyLLM,即懶惰大型語言模型。
- 論文標(biāo)題:LazyLLM: Dynamic Token Pruning for Efficient Long Context LLM Inference
- 論文地址:https://arxiv.org/abs/2407.14057
那么他們是怎么讓 LLM 偷懶的呢?要理解他們的方法,我們首先需要知道標(biāo)準(zhǔn)的基于 prompt 的 LLM 推理過程是怎樣的。簡單來說,該過程分為兩個階段:預(yù)填充和解碼,如圖 1 所示。
在預(yù)填充階段,模型計算和保存 prompt 中每個 token 的 KV 緩存,并預(yù)測首個 token。我們將預(yù)填充階段所耗費的時間稱為「首個 token 時間(TTFT)」。
預(yù)填充階段之后是解碼階段。在這個階段,模型再次使用緩存的 KV 來迭代式地解碼下一個 token,直到滿足停止標(biāo)準(zhǔn)。
在預(yù)填充階段,所有 Transformer 層都會使用 prompt 中的所有 token。當(dāng) prompt 較長時,TTFT 可能很慢,因為當(dāng)前最佳的基于 Transformer 的 LLM 既深又寬,并且計算注意力的成本會隨 prompt 中 token 數(shù)量而呈二次增長。舉個例子,Llama 2(7B 版本)堆疊了 32 層 Transformer,模型維度為 4096。在這種情況下,TTFT 需要的 walltime 是每個后續(xù)解碼步驟的 21 倍,在 LongBench 基準(zhǔn)上這些時間大約占用了總生成時間的 23%。
因此,要讓 LLM 推理高效進(jìn)行,優(yōu)化 TTFT 是非常關(guān)鍵的步驟。
盡管 LLM 推理優(yōu)化方面是一個活躍的研究領(lǐng)域,但很多方法關(guān)注的重心都是提升解碼階段的推理速度。研究者很少關(guān)注 TTFT 的改進(jìn)。一些基于壓縮的研究成果可通過減少 LLM 的大小隱式地提升 TTFT。
另一個研究方向是在靜態(tài)的 Transformer 架構(gòu)下實現(xiàn)對 TTFT 的改進(jìn)。對于這個研究方向,很自然會引出一個問題:在生成首個 token 時,所有 prompt token 都必不可少嗎?
圖 2 給出了在 LongBench 基準(zhǔn)上的 LLM 分析結(jié)果。
可以看到,對于首個生成的 token,輸入 token 的注意力分?jǐn)?shù)非常稀疏,這說明輸入 prompt 中的許多 token 是多余的,就算移除也不會影響到下一 token 預(yù)測。這一觀察正是該團(tuán)隊提出 LazyLLM 的基礎(chǔ)。
LazyLLM 的優(yōu)勢包括適用范圍廣、無需訓(xùn)練、效果好。圖 3 對比了標(biāo)準(zhǔn) LLM 與 LazyLLM。
LazyLLM
圖 4 展示了 LazyLLM 的整體框架。
從完整上下文開始,LazyLLM 會逐漸對 token 進(jìn)行剪枝,從而逐漸減少得到最終模型所使用的計算數(shù)量。請注意,LazyLLM 允許模型在不同的生成步驟選取不同的 token 子集,即便它們中的一些可能在之前的步驟中被剪枝了。相比于靜態(tài)剪枝(一次性對所有 token 進(jìn)行剪枝),動態(tài)剪枝會在每個生成步驟對下一 token 預(yù)測進(jìn)行優(yōu)化,這有助于維持模型的性能表現(xiàn)。
漸進(jìn)式 token 剪枝
之前也有一些研究成功使用過 token 剪枝來優(yōu)化 LLM 推理。但是,這些方法需要積累預(yù)測前幾個 token 的完整注意力圖,以便在剪枝開始之前分析 prompt token 的重要性。也因此,它們不適合用于降低 TTFT,因為它們在預(yù)填充階段仍需要計算所有 KV 緩存。
相較之下,LazyLLM 「很懶」,會從推理的第一輪迭代(預(yù)填充步驟)開始,只計算對預(yù)測下一 token 重要的 token。
在第一輪迭代中,一大關(guān)鍵難題是確定各個 token 的重要性。受之前已有研究(其中表明 token 隱藏狀態(tài)會在穿過 Transformer 層時發(fā)生演進(jìn))的啟發(fā),該團(tuán)隊的解決方案是在每個生成步驟使用逐層 token 剪枝。具體來說,他們是使用各層的注意力圖來確定輸入 token 對將要預(yù)測的 token 的重要性。
在計算了 token 的置信度分?jǐn)?shù)之后,另一個難題是確定剪枝 token 的閾值。
具體來說,對于不同的層和不同的任務(wù),該閾值可能會隨注意力分?jǐn)?shù)的變化而改變。該團(tuán)隊的解決思路是使用 top-k 百分位數(shù)選取策略。具體來說,如果一個 token 的置信度分?jǐn)?shù)小于輸入 token 中的第 k 個百分位數(shù),便將其剪枝掉。一旦 token 被剪枝去掉了,它就不再參與所有后續(xù)層的計算。
也就是說,后續(xù)層使用的 token 是之前層所使用 token 的子集。
后面的實驗表明,剪枝層的位置和剪枝的 token 數(shù)量不同時,也會導(dǎo)致性能發(fā)生變化。具體來說,對于同一 Transformer 層,隨著被剪枝去掉的 token 越來越多,模型的性能也會逐漸下降。
他們還發(fā)現(xiàn),相比于早期層的剪枝,在后期層執(zhí)行剪枝時會得到更好的性能,這說明后期層對 token 剪枝的敏感度更低。為了更好地平衡速度與準(zhǔn)確度,該團(tuán)隊使用了如圖 4 所示的漸進(jìn)式剪枝法,從而在早期層保留更多 token,然后在 token 流向后期層的過程中逐漸減少 token 的數(shù)量。
Aux Cache(輔助緩存)
預(yù)填充階段沒有 KV 緩存,每個 token 都表示成隱藏狀態(tài)。因此,可通過移除已被剪枝 token 的隱藏狀態(tài)來實現(xiàn)漸進(jìn)式 token 剪枝。但是,要將漸進(jìn)式 token 剪枝擴(kuò)展到后續(xù)的解碼步驟,卻并不簡單。原因是每個解碼步驟都會使用預(yù)填充階段計算的 KV 緩存來計算注意力。由于 LazyLLM 是在預(yù)填充階段執(zhí)行漸進(jìn)式 token 剪枝,因此在某一層被剪枝的 token 的 KV 不會出現(xiàn)在下一層的 KV 緩存中。
這里提醒一下,LazyLLM 框架允許在每一步讓每個生成步驟從完整的輸入 token 序列中挑選一個不同的 token 子集,無論它們是否已在之前的步驟中被剪枝。舉個例子,在接下來的解碼步驟中,那些在 KV 緩存中不存在的已被剪枝的 token 可能會被重新選取出來用于計算注意力。在這種情況下,模型無法檢索到這些 token 的 KV 緩存。
對此,一個基于直覺的解決方案是再讓這些 token 通過該 Transformer 的起點。但是,這會導(dǎo)致對同一 token 的重復(fù)計算,并最終減慢整體的生成速度。
為解決這個難題,該團(tuán)隊在原有的 KV 緩存之外引入了另一種緩存:Aux Cache(輔助緩存)。
如果已被剪枝 token(如圖 4 中 T4 和 T7)的 KV 并未出現(xiàn)在后續(xù)層的 KV 緩存中,則會由 Aux Cache 保存它們的隱藏狀態(tài)以供后續(xù)迭代檢索。
如圖 4 所示,在每個解碼步驟,每個 Transformer 層首先會檢索過去 token 的 KV 緩存(如果存在的話)。對于那些不在 KV 緩存中的 token,則直接從其前一層的 Aux Cache 中檢索它們的隱藏狀態(tài),而不必再次經(jīng)過之前的層。Aux Cache 可確保每個 token 在每個 Transformer 層中最多被計算一次,還能確保 LazyLLM 最慢時也比標(biāo)準(zhǔn) LLM 快。
實驗
該團(tuán)隊在兩個大型語言模型上檢驗了這種「懶惰」新方法:Llama 2 7B 和 XGen 7B。作為對比的標(biāo)準(zhǔn) LLM 是同樣的公開發(fā)布的預(yù)訓(xùn)練檢查點模型,同時不進(jìn)行任何附加訓(xùn)練。
實驗基準(zhǔn)是 LongBench,這是一個針對長內(nèi)容理解的多任務(wù)基準(zhǔn)。LongBench 基準(zhǔn)包含 16 個數(shù)據(jù)集,涉及 6 個任務(wù),包括單文檔問答、多文檔問答、總結(jié)、少樣本學(xué)習(xí)、合成任務(wù)和代碼補(bǔ)全。
評估指標(biāo)是每種方法在 TTFT 加速與準(zhǔn)確度權(quán)衡方面的效果和效率。
結(jié)果
表 1 給出了 LazyLLM、標(biāo)準(zhǔn) LLM 和其它基線方法的 TTFT 加速和準(zhǔn)確度結(jié)果。
在此表中,baseline 是指標(biāo)準(zhǔn) LLM 推理。random token drop 是指對 token 執(zhí)行隨機(jī)剪枝。static token pruning 是指在預(yù)填充階段基于前面幾個 Transformer 層的注意力方法來對輸入 token 執(zhí)行一次性剪枝。Prompt Compression 就是 prompt 壓縮方法,也就是使用 LLM 去除輸入上下文中的冗余。
從表 1 可以看到,LazyLLM 在 TTFT 加速方面全面優(yōu)勝,同時準(zhǔn)確度方面的下降基本可以忽略不計。需要指出,使用 LLM 來壓縮 prompt 需要大量計算。因此,即使 Prompt Compression 能讓推理速度更快,但其實際的 TTFT 卻比標(biāo)準(zhǔn) LLM 還長。
對總體生成速度的影響
為了評估新方法對總體生成速度的影響,該團(tuán)隊分析了計算使用的 prompt token 百分比和生成加速情況,見表 2。
可以看到,LazyLLM 計算使用的 token 的占比總是低于 100%,這說明 LazyLLM 在生成結(jié)束時也沒有用完 prompt 中的所有 token,但理論上講該模型可以使用所有 token。這能為不同任務(wù)的整體生成過程提供額外的加速。
不同層的丟棄率
該團(tuán)隊也分析了剪枝層的位置和被剪枝 token 的數(shù)量的影響。結(jié)果見圖 6。
可以看到,當(dāng)在同一 Transformer 層進(jìn)行剪枝時,留下的 token 越少,模型的性能越差。這也符合我們的直觀認(rèn)知。此外,相比于在更前期 Transformer 層執(zhí)行剪枝,在后期層進(jìn)行剪枝會得到更好的性能,這說明后期層對 token 剪枝的敏感度更低。
基于這些觀察,可以說漸進(jìn)式 token 剪枝的效果得到了證明。
漸進(jìn)式 KV 增長
最后,該團(tuán)隊也嘗試了理解使用 token 剪枝邏輯的模型的內(nèi)部情況。具體來說,他們想要了解 prompt token 中的累積使用比例以及相應(yīng)的不被使用的比例。這種「累積 token 使用量」可以等價地定義成每一步的 KV 緩存 大小。圖 7 給出了 LazyLLM 的每個階段這些累積的 prompt token 使用量。
該結(jié)果支持這一假設(shè):許多 token 永遠(yuǎn)不會被模型選擇(即便理論上講模型可以使用 prompt 中的所有 token。
考慮到模型依然能維持執(zhí)行任務(wù)的準(zhǔn)確度,因此可以得出結(jié)論:模型可以有效地丟棄不影響輸出質(zhì)量的 token。