偷偷摘套内射激情视频,久久精品99国产国产精,中文字幕无线乱码人妻,中文在线中文a,性爽19p

MHA -> GQA:提升 LLM 推理效率

發(fā)布于 2025-1-13 11:35
瀏覽
0收藏

一、背景

我們?cè)谥暗奈恼轮性敿?xì)分析過(guò) GQA 相比 MHA 的推理優(yōu)勢(shì)(省顯存、計(jì)算強(qiáng)度高),不過(guò) GQA 有可能導(dǎo)致精度的損失,因此早期的一些不太大的 LLM 會(huì)使用 MHA。針對(duì)這個(gè)問(wèn)題有兩種優(yōu)化思路:

  • 將 MHA 轉(zhuǎn)換為 GQA,長(zhǎng)短序列都適用。
  • 在長(zhǎng)序列場(chǎng)景使用 Token 稀疏化方案或者結(jié)合投機(jī)采樣策略。?

本文中我們介紹一個(gè)將 MHA 轉(zhuǎn)換為 GQA 的工作,不過(guò)論文的實(shí)驗(yàn)還偏少,效果也不是非常好;此外,最新的模型基本都在預(yù)訓(xùn)練階段默認(rèn)采用 GQA(LLaMA3 8B、LLaMA3.2 3B 以及 Microsoft 的 Phi 系列模型等),降低了本文工作的應(yīng)用場(chǎng)景。

對(duì)應(yīng)的論文:[2412.20677] Align Attention Heads Before Merging Them: An Effective Way for Converting MHA to GQA [1]

相關(guān)工作也可以參考我們以前的文章:

二、摘要

LLM 在多種自然語(yǔ)言處理任務(wù)中展現(xiàn)出卓越性能。然而,隨著模型規(guī)模與輸入序列長(zhǎng)度的增長(zhǎng),KV Cache 的急劇膨脹顯著拖慢了推理速度。鑒于此,作為 MHA 的替代方案,GQA 已被廣泛引入 LLM。本研究提出了一種低成本方法,可將 MHA 模型按任意 KV Head 壓縮比修剪為 GQA 模型。

該方法基于 L0 掩碼逐步剔除冗余參數(shù)。此外,在不改變模型的前提下,對(duì)注意力頭施加正交變換,以在修剪訓(xùn)練前提升 Attention Head 間的相似度,從而進(jìn)一步優(yōu)化模型性能。本方法兼容RoPE,意味著訓(xùn)練后的模型能完全適配主流標(biāo)準(zhǔn) GQA 框架。實(shí)驗(yàn)表明,僅通過(guò)監(jiān)督微調(diào),提出的策略即可將 LLaMA2-7B 模型的 KV Head 壓縮高達(dá) 87.5%,且性能損失極小。

三、引言

如下 3.1 和 3.2 部分在我們之前的文章中有相吸介紹:???LLM 推理的 Attention 計(jì)算和 KV Cache 優(yōu)化:PagedAttention、vAttention 等??。

3.1 MHA Attention 計(jì)算

如下圖所示為標(biāo)準(zhǔn)的 LLM Decoding 階段的 Multi-Head Attention(MHA)計(jì)算,其中的 D 表示 hidden size,H 表示 Head 個(gè)數(shù),L 表示當(dāng)前是在序列的第 L 個(gè) Token??梢钥闯觯?/p>

  • 當(dāng)Batch Size 為 1時(shí),圖中紅色、綠色、藍(lán)色處的矩陣乘法全部為矩陣乘向量,是明顯的 Memory Bound,算術(shù)強(qiáng)度不到 1。
  • 當(dāng)Batch Size 大于 1時(shí)(比如 Continuous Batching):
  • 紅色藍(lán)色部分:因?yàn)槭?Weight 乘以 Activation,所以不同的 Request 之間可以共享 Weight。這里變成矩陣乘矩陣,并且 Batch Size 越大,算術(shù)強(qiáng)度越大,也就越趨近于 Compute Bound(FFN 層也類(lèi)似)。
  • 綠色部分:這里 Q、K 和 V 的 Attention 計(jì)算,是 Activation 乘以 Activation,所以不同的 Request 之間沒(méi)有任何相關(guān)性。即使 Batching,這里也是Batched 矩陣乘向量,并且因?yàn)樾蛄虚L(zhǎng)度可能不同,這里不同 Request 的矩陣乘向量是不規(guī)則的。也就是說(shuō),這里算術(shù)強(qiáng)度始終不到 1,是明顯的 Memory Bound。

MHA -> GQA:提升 LLM 推理效率-AI.x社區(qū)

從上可以看出,通過(guò) Continuous Batching 可以很好的將 Memory Bound 問(wèn)題轉(zhuǎn)變?yōu)?Compute Bound,但 Q、K 和 V 的 Attention 計(jì)算的算術(shù)強(qiáng)度卻始終小于 1。根據(jù) Amdahl 法則,如果系統(tǒng)中有一部分無(wú)法優(yōu)化,即使把其他部分優(yōu)化到可以忽略,不可優(yōu)化的部分也會(huì)決定整個(gè)系統(tǒng)的性能上限。不幸的是,Sequence Length 越長(zhǎng),這里的計(jì)算量就越不可忽略。

根據(jù)模型配置信息可以估算出模型中 Q、K 和 V 的 Attention 計(jì)算與其他矩陣計(jì)算的比例大約為 (L+D)/(12*D)(PS:準(zhǔn)確值需要根據(jù)具體的模型參數(shù)計(jì)算)。也就是說(shuō),當(dāng)序列長(zhǎng)度 L 等于 12 倍的 hidden size 時(shí),兩部分的計(jì)算量相當(dāng),即使其他矩陣計(jì)算優(yōu)化到 0,加速比也只有 2x。比如 LLaMA 2 7B 的 hidden size 為 4K,當(dāng)序列長(zhǎng)度達(dá)到 44K 時(shí),兩部分的計(jì)算量相當(dāng),要優(yōu)化的重點(diǎn)也會(huì)很不一樣,這也是很多長(zhǎng)序列相關(guān)工作會(huì)在 Attention 部分采用稀疏 Attention 的一個(gè)重要原因。

3.2 GQA Attention 計(jì)算

早期通常只有比較大的模型才會(huì)采用 GQA([2305.13245] GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints),比如 LLaMA -2 70B,而 LLaMA-2 7B/13B 都沒(méi)有采用 GQA。然而,LLaMA-3 8B 中也用上了 GQA,甚至其他更小的模型也在將 MHA 替換為 GQA。

  • 使用 GQA 有個(gè)非常大的好處:在推理階段可以顯著降低 KV Cache 的大小,比如,相比 32 個(gè) KV Head 的 MHA,32 個(gè) Query Head,8 個(gè) KV Head 的 GQA 的 KV Cache 大小可以降低到 MHA 的 8/32=1/4,這也為更大的 Batch Size 提供了空間,可以進(jìn)一步提升吞吐。
  • 除此之外,還有一個(gè)比較大的好處:可以明顯提升 Q、K 和 V 的 Attention 計(jì)算的算術(shù)強(qiáng)度。此時(shí)雖然不同的 Request 之間同樣不能共享,但是同一個(gè) Request 中的不同 Head 可以共享,比如 4 個(gè) Query Head 共享 1 個(gè) KV Head,則算術(shù)強(qiáng)度就會(huì)接近于 4,也可以更充分發(fā)揮 Tensor Core 的算力。

MHA -> GQA:提升 LLM 推理效率-AI.x社區(qū)

使用 MHA 時(shí),Q、K 和 V 的 Attention 計(jì)算可以使用 CUDA Core 也可以使用 Tensor Core。由于 Tensor Core 要求矩陣的 Shape 是 8 的整數(shù)倍,如果不滿(mǎn)足就只能 Padding:

  • 對(duì)于MHA而言,其是矩陣乘向量,則有7/8 的計(jì)算是冗余的。
  • 對(duì)于GQA而言,如果 4 個(gè) Query Head 共享 1 個(gè) KV Head,則 Attention 計(jì)算有 4/8 的計(jì)算是冗余的,如果8 個(gè) Query Head 共享 1 個(gè) KV Head,則沒(méi)有計(jì)算的冗余。很多框架已經(jīng)做了相關(guān)優(yōu)化,比如 LMDeploy,TRT-LLM 的 XQA 等。
  • 此外,PagedAttention 的 KV Cache 是非連續(xù)存儲(chǔ)的,導(dǎo)致即使使用 GQA 也無(wú)法利用 Tensor Core。

PS:對(duì)于 GQA 而言,理論上也可以期望 GPU 的 L2 Cache 能夠緩存到共享的 Key 和 Value Cache,從而緩解 IO Bound 問(wèn)題,然而實(shí)際上無(wú)法人為控制,不一定能達(dá)到理想的效果。

3.3 動(dòng)機(jī)

作者從 C4 訓(xùn)練集采樣了 128 個(gè) Sequence,共 128*2048=262144 個(gè) Token,評(píng)估了 LLaMA2-7B 模型中每個(gè) Transformer Block 中 Attention Head 的 KV Cache 的相似性。

如下圖 Figure 2 所示,分析發(fā)現(xiàn),大多數(shù) Head 之間的 KV Cache 幾乎是正交的,僅有少數(shù) Head 共享較高的相似度。這表明直接對(duì)投影矩陣進(jìn)行均值化會(huì)導(dǎo)致性能顯著下降,說(shuō)明 Attention Head 之間存在重要的獨(dú)特性。

MHA -> GQA:提升 LLM 推理效率-AI.x社區(qū)

根據(jù)之前 [2406.07056] Effectively Compress KV Heads for LLM [2] 的研究,KV Cache 的低秩性為優(yōu)化提供了新思路:

  • 可通過(guò)正交變換對(duì)齊 Key 和 Value 的投影矩陣。
  • 這種方法降低了優(yōu)化的難度,并為 MHA 轉(zhuǎn)換為 GQA 提供了理論支持。

四、方案

4.1 網(wǎng)絡(luò)轉(zhuǎn)換

主要目的是:在剪枝訓(xùn)練之前,對(duì)模型進(jìn)行轉(zhuǎn)換,以增加同一組內(nèi)不同 Attention Head 之間的相似性,從而提高模型優(yōu)化的效率。具體的過(guò)程大概為:

  • 根據(jù)前述的方案,使用部分 C4 的訓(xùn)練集來(lái)收集相應(yīng)的 KV Cache。
  • 基于余弦相似性或者歐氏距離,計(jì)算最優(yōu)的正交矩陣。
  • 將計(jì)算得到的正交矩陣融合到對(duì)應(yīng)的 Q、K、V 投影矩陣中,保證計(jì)算不變性。對(duì)于 Q 和 K 的投影矩陣,要考慮 RoPE 的場(chǎng)景,在子空間應(yīng)用正交變換。

通過(guò)正交變換,可以使得同一組內(nèi)不同 Attention Head 在特征空間中更加接近,從而在后續(xù)的剪枝訓(xùn)練過(guò)程中更容易找到合適的參數(shù)共享方式,提高模型的壓縮效果和性能。

如下圖 Figure 3 所示,作者展示了不同的 Block 中轉(zhuǎn)換前和轉(zhuǎn)換后的 KV Cache 相似性,可以看出,轉(zhuǎn)換后相似性明顯增加:

MHA -> GQA:提升 LLM 推理效率-AI.x社區(qū)

4.2 找到更好的分組方法

在獲取了每對(duì) Attention Head 之間的相似度評(píng)分后,可依據(jù)這些評(píng)分對(duì) Attention Head 進(jìn)行重新分組。將一個(gè)組的相似度評(píng)分定義為該組內(nèi)每對(duì) Attention Head 之間相似度評(píng)分的總和,而每種分組結(jié)果的總相似度評(píng)分則是所有組相似度評(píng)分的累加。

合理的分組方式可以使得同一組內(nèi)的 Attention Head 在特征空間中更加相似,從而在剪枝時(shí)更容易找到合適的參數(shù)共享方式,提高模型的壓縮效果和性能。

4.3 剪枝訓(xùn)練

主要目的是:通過(guò)剪枝訓(xùn)練,逐步將原始的 KV Head 轉(zhuǎn)移到新的 KV Head 上,同時(shí)保持模型性能。如下圖 Figure 1 所示,具體過(guò)程包括:

  • 添加新的投影矩陣:在每組內(nèi)使用 Mean Pooling  初始化新的投影矩陣。
  • 應(yīng)用 L0 掩碼:引入 L0 掩碼來(lái)控制原始 KV Head 和新 KV Head 之間的轉(zhuǎn)換。初始時(shí),掩碼值為 1,表示使用原始 KV Head;在剪枝過(guò)程中,逐步將掩碼值約束為 0,表示使用新的 KV Head。
  • 知識(shí)蒸餾:使用 KL 損失和 BiLD 損失,鼓勵(lì)學(xué)生模型與教師模型的輸出對(duì)齊,從而保持模型性能。

MHA -> GQA:提升 LLM 推理效率-AI.x社區(qū)

五、實(shí)驗(yàn)評(píng)估

如下圖所示,作者在多個(gè)任務(wù)上進(jìn)行評(píng)估,GQA-16(32 個(gè) KV Head 變?yōu)?16 個(gè)) 時(shí)平均精度甚至有所提升。但是 GQA-8(壓縮 4x)和 GQA-4(壓縮 8x)時(shí)損失就比較大:

MHA -> GQA:提升 LLM 推理效率-AI.x社區(qū)

六、參考鏈接

  1. ??https://arxiv.org/abs/2412.20677??
  2. ??https://arxiv.org/abs/2406.07056???

本文轉(zhuǎn)載自 ??AI閑談??,作者: AI閑談

標(biāo)簽
已于2025-1-13 11:42:18修改
收藏
回復(fù)
舉報(bào)
回復(fù)
相關(guān)推薦