大語(yǔ)言模型(LLM)在生成文本時(shí),通常是一個(gè) token 一個(gè) token 地進(jìn)行。每當(dāng)模型生成一個(gè)新的 token,它就會(huì)把這個(gè) token 加入輸入序列,作為下一步預(yù)測(cè)下一個(gè) token 的依據(jù)。這一過(guò)程不斷重復(fù),直到完成整個(gè)輸出。
然而,這種逐詞生成的方式帶來(lái)了一個(gè)問(wèn)題:每一步的輸入幾乎與前一步相同,只是多了一個(gè)新 token。如果不對(duì)計(jì)算過(guò)程進(jìn)行優(yōu)化,模型就不得不反復(fù)重復(fù)大量相同的計(jì)算,造成資源浪費(fèi)和效率低下。
為了解決這個(gè)問(wèn)題,KV-Cache(Key Value-Cache)應(yīng)運(yùn)而生,它是提升 LLM 推理性能的關(guān)鍵技術(shù)之一。
KV-Cache 的核心思想是緩存中間計(jì)算結(jié)果。在 Transformer 架構(gòu)中,每個(gè) token 在生成下一個(gè) token 時(shí)都需要通過(guò)自注意力機(jī)制計(jì)算一系列 Key 和 Value 向量。這些向量描述了當(dāng)前 token 與前面所有 token 的關(guān)系。
1. 關(guān)于自注意力機(jī)制
要理解KV緩存,首先需要掌握注意力機(jī)制的基本原理。在著名的論文《Attention Is All You Need》中,提出了用于Transformer模型的注意力機(jī)制公式,該公式幫助模型在生成每個(gè)新token時(shí)確定應(yīng)關(guān)注哪些先前的token。
注意力機(jī)制的核心在于通過(guò)一系列計(jì)算來(lái)量化不同token之間的關(guān)聯(lián)強(qiáng)度。具體來(lái)說(shuō),當(dāng)模型生成一個(gè)新的token時(shí),它會(huì)根據(jù)輸入序列中的所有token計(jì)算出一組Query(查詢)、Key(鍵)和Value(值)向量。這些向量通過(guò)特定的公式相互作用,以決定當(dāng)前上下文中每個(gè)token的重要性。
讓我們一步一步來(lái)看看這個(gè)方程是怎么在工程上實(shí)現(xiàn)的。
2. 提示詞階段(預(yù)填充階段)
我們從一個(gè)示例輸入開(kāi)始,稱為“提示詞”(prompt)。這個(gè)輸入首先會(huì)被分詞器(tokenizer)切分為一個(gè)個(gè)的 token,也即模型可處理的基本單位。例如,短語(yǔ) "The quick brown fox" 在使用 OpenAI 的 o200kbase 分詞器時(shí)會(huì)被拆分為 4 個(gè) token。隨后,每個(gè) token 被轉(zhuǎn)換為一個(gè)嵌入向量,記作 x? 到 x?,這些向量的維度由模型決定,通常表示為 d_model
。在原始 Transformer 論文中,d_model 設(shè)定為 512。
為了高效計(jì)算,我們可以將所有 n 個(gè)嵌入向量堆疊成一個(gè)矩陣 X,其形狀為 [n × d_model
]。接下來(lái),為了執(zhí)行注意力機(jī)制,我們需要通過(guò)三個(gè)可學(xué)習(xí)的投影矩陣 Wq、Wk 和 Wv 來(lái)分別生成查詢(Query)、鍵(Key)和值(Value)向量。
具體來(lái)說(shuō):
- Wq 的形狀為 [
d_model × d_k
],用于將嵌入向量映射到查詢空間,得到 q 矩陣,其形狀為 [n × d_k
]; - Wk 的形狀也為 [
d_model × d_k
],生成 k 矩陣,形狀為 [n × d_k
]; - Wv 的形狀為 [
d_model × d_v]
,生成 v 矩陣,形狀為 [n × d_v
]。
這三個(gè)矩陣在訓(xùn)練過(guò)程中不斷優(yōu)化,以捕捉不同 token 之間的依賴關(guān)系。其中,dk 和 dv 是設(shè)計(jì)模型結(jié)構(gòu)時(shí)設(shè)定的超參數(shù),在最初的 Transformer 模型中被設(shè)為 64。
這種線性變換不僅將高維嵌入壓縮到更易操作的空間,還保留了關(guān)鍵語(yǔ)義信息,為后續(xù)的注意力計(jì)算奠定了基礎(chǔ)。
圖片
接下來(lái),我們通過(guò)計(jì)算查詢矩陣 q 與其對(duì)應(yīng)鍵矩陣 k 的轉(zhuǎn)置的乘積,得到一個(gè)大小為 [n × n
] 的自注意力分?jǐn)?shù)矩陣。這個(gè)矩陣中的每個(gè)元素代表了某個(gè)查詢向量與所有鍵向量之間的相似性得分,反映了在生成當(dāng)前 token 時(shí),模型應(yīng)將多少注意力分配給前面每一個(gè) token。
對(duì)于解碼器類型的大型語(yǔ)言模型(LLM),為了避免模型在生成過(guò)程中“偷看”未來(lái)的信息,我們采用一種稱為“掩碼自注意力”的機(jī)制——即將該矩陣的上三角部分設(shè)為負(fù)無(wú)窮(-inf)。這樣一來(lái),在后續(xù)的 softmax 操作中,這些位置的值會(huì)趨近于零,從而確保每個(gè) token 在預(yù)測(cè)時(shí)只能關(guān)注到它之前的歷史 token,而不會(huì)看到未來(lái)的輸入。
處理后的自注意力分?jǐn)?shù)矩陣因此成為一個(gè)下三角結(jié)構(gòu),體現(xiàn)了嚴(yán)格的因果關(guān)系。由于這一操作,模型在生成序列的過(guò)程中能夠保持邏輯連貫性和時(shí)間順序的正確性。
為了得到最終的注意力輸出,我們首先對(duì)這個(gè)帶有掩碼的注意力分?jǐn)?shù)矩陣進(jìn)行縮放,除以 sqrt(d_k)以防止點(diǎn)積結(jié)果過(guò)大導(dǎo)致梯度飽和;然后應(yīng)用 softmax 函數(shù)將其轉(zhuǎn)化為概率分布;最后將該分布與值矩陣 v 相乘,得到加權(quán)聚合后的上下文信息。這一過(guò)程即完成了對(duì)輸入序列 "The quick brown fox" 的注意力輸出計(jì)算,為下一步的 token 預(yù)測(cè)提供了基礎(chǔ)。
圖片
3. 自回歸生成階段(解碼階段)
現(xiàn)在,當(dāng)我們生成下一個(gè) token(例如 "jumps")并計(jì)算其對(duì)應(yīng)的注意力輸出時(shí),在自回歸生成階段(即解碼過(guò)程中),模型是如何工作的呢?
在生成“jumps”這一新 token 時(shí),模型會(huì)為它生成新的查詢向量 q?、鍵向量 k? 和值向量 v?。但值得注意的是,并不需要重新計(jì)算之前所有 token 的 k 和 v,因?yàn)檫@些中間結(jié)果已經(jīng)保存在 KV-Cache 中。
圖片
如圖所示,我們只需關(guān)注注意力矩陣中新增的最后一行,其余部分可以從緩存中直接復(fù)用。圖中以灰色突出顯示的 k? 到 k? 和 v? 到 v? 都是之前步驟中已計(jì)算并緩存的結(jié)果。這意味著,只有與當(dāng)前 token 相關(guān)的 q?、k? 和 v? 是新生成的。
最終,第 n 個(gè) token 的注意力輸出只需以下計(jì)算(忽略 softmax 和縮放因子以便理解):
圖片
這正是 KV-Cache 的價(jià)值所在:通過(guò)緩存先前 token 的 Key 和 Value 向量,模型無(wú)需重復(fù)計(jì)算整個(gè)歷史序列,從而顯著減少冗余運(yùn)算,提升推理效率。
隨著生成序列變長(zhǎng),KV-Cache 的作用愈發(fā)重要——它不僅減少了計(jì)算負(fù)擔(dān),也降低了內(nèi)存帶寬的壓力,使得 LLM 在實(shí)際應(yīng)用中能夠?qū)崿F(xiàn)高效、流暢的文本生成。
4. 速度與內(nèi)存的權(quán)衡
KV-Cache 的引入顯著提升了大型語(yǔ)言模型(LLM)的推理速度。其核心思想在于緩存每一步計(jì)算生成的 Key 和 Value 向量,使得在生成新 token 時(shí),模型無(wú)需重復(fù)計(jì)算歷史上下文中的 K 和 V 值,從而大幅減少冗余計(jì)算,加快響應(yīng)生成。
然而,這種性能提升也帶來(lái)了內(nèi)存上的代價(jià)。由于 KV-Cache 需要為每個(gè)已生成的 token 保存對(duì)應(yīng)的 K 和 V 向量,它會(huì)持續(xù)占用 GPU 顯存。對(duì)于本身就需要大量資源的 LLM 來(lái)說(shuō),這進(jìn)一步加劇了顯存壓力,尤其是在處理長(zhǎng)序列時(shí)更為明顯。
因此,在實(shí)際應(yīng)用中存在一個(gè)權(quán)衡:一方面,使用 KV-Cache 可以加速生成過(guò)程,但另一方面,它也會(huì)增加內(nèi)存消耗。當(dāng)顯存緊張時(shí),開(kāi)發(fā)者可能需要選擇犧牲部分生成速度來(lái)節(jié)省內(nèi)存,或接受更高的硬件資源開(kāi)銷以換取更快的推理表現(xiàn)。這種性能與資源之間的平衡,是部署 LLM 時(shí)必須仔細(xì)考量的關(guān)鍵因素之一。
5. KV-Cache 實(shí)踐
KV-Cache 是現(xiàn)代大型語(yǔ)言模型(LLM)推理引擎中至關(guān)重要的一項(xiàng)優(yōu)化技術(shù)。在 Hugging Face 的 Transformers 庫(kù)中,當(dāng)你調(diào)用 model.generate()
函數(shù)生成文本時(shí),默認(rèn)會(huì)啟用 use_cache=True
參數(shù),也就是自動(dòng)使用 KV-Cache 來(lái)提升生成效率。
類似的技術(shù)細(xì)節(jié)在 Hugging Face 官方博客的文章《KV Caching Explained: Optimizing Transformer Inference Efficiency》中也有深入解析。該文通過(guò)實(shí)驗(yàn)展示了啟用 KV-Cache 帶來(lái)的顯著加速效果:在 T4 GPU 上對(duì)模型 HuggingFaceTB/SmoLLM2-1.7B 進(jìn)行測(cè)試,使用 KV-Cache 相比不使用時(shí),推理速度提升了 5.21 倍。
這一性能提升充分說(shuō)明了 KV-Cache 在實(shí)際應(yīng)用中的重要性,測(cè)試的參考代碼如下:
from transformers import AutoTokenizer, AutoModelForCausaLLM
import torch
import time
# Choose model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausaLLM.from_pretrained("gpt2")
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
# Prepare input
input_sentence = "The red cat was"
inputs = tokenizer(input_sentence, return_tensors="pt").to(device)
# Function to measure generation time
def generate_and_time(use_cache: bool):
torch.cuda.empty_cache()
start_time = time.time()
with torch.no_grad():
output_ids = model.generate(
**inputs,
max_new_tokens=300,
use_cache=use_cache
)
end_time = time.time()
duration = end_time - start_time
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
return duration, output_text
# Measure with KV-cache enabled
time_with_cache, text_with_cache = generate_and_time(use_cache=True)
print(f"\n[use_cache=True] Time taken: {time_with_cache:.4f} seconds")
print(f"Generated Text:\n{text_with_cache}\n")
# Measure with KV-cache disabled
time_without_cache, text_without_cache = generate_and_time(use_cache=False)
print(f"[use_cache=False] Time taken: {time_without_cache:.4f} seconds")
print(f"Generated Text:\n{text_without_cache}\n")
# Speedup factor
speedup = time_without_cache / time_with_cache
print(f"Speedup from using KV-cache: {speedup:.2f}×")
KV-Cache的運(yùn)行速度實(shí)際上受到多種因素的綜合影響,其中包括模型的規(guī)模(具體體現(xiàn)在注意力層數(shù)的多少)、輸入文本的長(zhǎng)度n、所使用的硬件設(shè)備以及具體的實(shí)現(xiàn)細(xì)節(jié)等。
6. 小結(jié)
KV-cache作為一種極為強(qiáng)大的性能優(yōu)化手段,能夠顯著提升語(yǔ)言模型(LLM)生成文本的速度。其核心機(jī)制在于,在生成文本的過(guò)程中,通過(guò)重用前面步驟中的注意力計(jì)算結(jié)果,避免重復(fù)計(jì)算,從而實(shí)現(xiàn)更高效的文本生成。具體而言,當(dāng)計(jì)算下一個(gè)標(biāo)記的注意力輸出時(shí),系統(tǒng)會(huì)緩存并重用之前步驟中所產(chǎn)生的鍵和值。
不過(guò),需要指出的是,這種性能上的提升并非沒(méi)有代價(jià)。由于需要存儲(chǔ)這些向量,會(huì)占用大量的GPU內(nèi)存資源,而這些被占用的內(nèi)存就無(wú)法再用于其他任務(wù)了。
【參考閱讀與關(guān)聯(lián)閱讀】