多頭潛在注意力:手把手用數(shù)學公式推導(dǎo)
多頭潛在注意力,這可是個大家伙,得好好聊聊!
想象一下,你正在處理一堆復(fù)雜的數(shù)據(jù),就像是面對著一座混亂不堪的寶藏山,每個數(shù)據(jù)點都是一塊閃閃發(fā)光的寶石,但它們卻雜亂無章地堆放在一起。這時,多頭潛在注意力機制就像是一位身懷絕技的探險家,它帶著一堆神奇的分身(也就是那些“頭”),準備深入這座寶藏山,尋找那些隱藏的寶藏。
這些分身,哦不,這些“頭”們,每個都有自己獨特的視角和技能。它們會將這座混亂的寶藏山分割成多個小塊,每個小塊都由一個“頭”負責探索。每個“頭”都會在自己的潛在空間里自由翱翔,尋找那些與當前任務(wù)最相關(guān)的特征信息。這就像是在玩一場尋寶游戲,每個“頭”都在努力尋找自己的線索。
當所有的“頭”都找到了自己的寶藏后,它們就會將這些寶藏(也就是那些加權(quán)后的輸出)收集起來,然后通過一個神奇的整合儀式(也就是那個線性變換層),將這些寶藏融合成一個完整的寶藏圖。這張寶藏圖不僅包含了所有“頭”們的發(fā)現(xiàn),還以一種更加有序和易于理解的方式呈現(xiàn)了出來。
所以,你看,多頭潛在注意力機制就像是那位身懷絕技的探險家和它的分身們一起,通過分工合作和集體智慧,成功地探索了那座混亂的寶藏山,找到了一份珍貴的寶藏圖。而這份寶藏圖,就是我們所需要的、更加準確和豐富的數(shù)據(jù)表示。
怎么樣?通過這樣幽默風趣的語言和深入淺出的講解,你是不是對多頭潛在注意力機制有了更加清晰的認識呢?
減兵而不減勢:多頭潛在注意力
想象一下,你正在看一場盛大的魔術(shù)表演,魔術(shù)師手里拿著一疊撲克牌,準備給你展示一場令人驚嘆的魔術(shù)。這時,多頭潛在注意力機制就像那位魔術(shù)師,而撲克牌就是我們的輸入序列。魔術(shù)師(多頭潛在注意力機制)會將這疊撲克牌(輸入序列)分成好幾摞(小塊),每摞都由他的一個助手(專家網(wǎng)絡(luò))來處理。
這些助手(專家網(wǎng)絡(luò))啊,可都是魔術(shù)師(多頭潛在注意力機制)精心挑選的,每個都有自己的獨門絕技。他們會仔細觀察自己手里的那摞撲克牌(小塊輸入序列),找出其中最關(guān)鍵的幾張牌(特征信息),然后施展魔法(進行處理),將這些牌變成更加炫酷的魔術(shù)道具(加權(quán)后的輸出)。
當所有的助手(專家網(wǎng)絡(luò))都完成自己的任務(wù)后,魔術(shù)師(多頭潛在注意力機制)就會將這些魔術(shù)道具(加權(quán)后的輸出)收集起來,然后通過一個神奇的魔法陣(線性變換層),將這些道具融合成一個超級炫酷的魔術(shù)表演(最終的輸出)。
你看,多頭潛在注意力機制就像那位魔術(shù)師一樣,通過分工合作和集體智慧,成功地將一堆普通的撲克牌變成了一場令人驚嘆的魔術(shù)表演。而這場魔術(shù)表演的背后,其實是多個專家網(wǎng)絡(luò)在共同工作,通過動態(tài)地選擇和加權(quán),將輸入序列中的關(guān)鍵信息提取出來,并融合成一個更加豐富和準確的數(shù)據(jù)表示。
所以,下次當你再看到多頭潛在注意力機制時,不妨想象一下那位魔術(shù)師和他的助手們正在為你上演一場精彩的魔術(shù)表演吧!這樣,你就能更加輕松地理解這個復(fù)雜而有趣的機制了。
多頭注意力
我們在聊多頭潛在注意力的時候,當然不能忘了多頭注意力。
- 一、多頭注意力:讓模型“多管齊下”的絕技
我們先設(shè)想一下你和三位朋友走進一場相親大會,目標是幫你找對象(當然我不是催婚啦)。
在這個看臉的時代,你當然首先專注的是顏值啦。不過聰明如你,當然不希望僅僅被顏值吸引,所以聰明的你請來了三個朋友:
朋友A看學歷和家庭背景;
朋友B看性格和興趣愛好;
朋友C看長遠發(fā)展。
這樣,每個人都有自己負責關(guān)注一個維度,你們就這樣信心百倍的走進了相親現(xiàn)場。這就是 多頭注意力(Multi-Head Attention) 的核心思路:“不同的注意力頭,關(guān)注不同的重點。最后,你們坐下來,把各自的觀察結(jié)果一合計——完美!這比你一個人單打獨斗靠譜多了!

IMG_256
??換成模型的說法:
一個“注意力頭”關(guān)注詞和詞的關(guān)系(比如主謂關(guān)系);
一個頭專門看情感走向(這句話開心還是生氣?);
一個頭盯著長遠關(guān)系(比如“前男友”這個詞和結(jié)尾的“后悔”有關(guān));
最后,把這些不同的視角綜合起來,模型就能給出更聰明的判斷。
多頭潛在注意力
還記得剛才的相親大會嗎?
這次你帶的朋友都太“聰明”了,他們不直接告訴你觀察結(jié)果,而是自己先偷偷開了個小會,先互相交流一下誰的觀察結(jié)果更有用。
朋友A說:“我覺得性格最重要,這一輪就盯性格好了?!?/p>
朋友B說:“算了,家庭背景這次先不管,我們看性格和顏值。”
他們根據(jù)現(xiàn)場的情況動態(tài)調(diào)整每個人的關(guān)注重點,而不是一開始就分工固定。
這就是所謂的 潛在多頭注意力:“我不直接分配任務(wù),而是讓每個注意力頭自己決定最值得關(guān)注的方向?!?/p>

換成模型的說法:每個注意力頭自己“思考”當前任務(wù)下,應(yīng)該重點觀察哪些方面,而不是固定關(guān)注點。
這樣模型變得更加靈活,不會死板地每次都盯著同一類信息。

小結(jié)
讓我們對比一下,如:
概念 | 生活比喻 | 模型里的作用 |
多頭注意力 | 帶朋友去相親,每人盯一個重點 | 每個“頭”關(guān)注不同特征,提高模型看問題的全面性 |
潛在多頭注意力 | 朋友們先開小會,再決定誰關(guān)注什么 | 模型動態(tài)調(diào)整注意力,更靈活、更聰明 |
一句話總結(jié)就是:
多頭注意力讓模型“分頭行動”,潛在多頭注意力讓模型“會看場合調(diào)整戰(zhàn)術(shù)”!

紙上推演:多頭潛在注意力的數(shù)學推演
在探索多頭潛在注意力的數(shù)學推演之路上,我們即將啟程,深入理解這一復(fù)雜而迷人的領(lǐng)域。多頭潛在注意力模型,作為一種先進的深度學習架構(gòu),它在處理序列數(shù)據(jù)、圖像識別、自然語言處理等眾多領(lǐng)域中展現(xiàn)出了卓越的性能。我們通過結(jié)合 Excel 表格進行推演,揭開其背后的原理,理解其如何通過多頭機制捕捉數(shù)據(jù)中的豐富信息,以及如何通過潛在空間的變換來增強模型的表達能力。
下圖來自于多頭潛在注意力論文中,描述的是多頭潛在注意力的機制。

下面讓我們先看看我們的任務(wù)背景:
我們繼續(xù)依據(jù)之前的例子,輸入序列中包括6 個 Token,每個是 5 維向量。潛在向量有4 個 ,可學習潛在向量有4個,每個是 5 維,頭的數(shù)量是2個。
每頭維度 dk=dv=2 (因為我們把 5 維分成 2 個頭)
步驟1: 輸入隱藏狀態(tài)
輸入: h_t ∈ ?^{T × d}

步驟 2: 初始化計算latent的權(quán)重

計算


wpsoffice

同樣,我們計算出


步驟 3: RoPE計算
對latent進行線性變換,得到注意力組件:
RoPE的計算如下:

wpsoffice
為了計算

,我們采用下面的公式:

wpsoffice
我們先計算R1-R6:

/Users/i/Library/Containers/com.kingsoft.wpsoffice.mac/Data/tmp/wpsoffice.KixErtwpsoffice
R1 的上下左右4個腳正是上圖中對應(yīng)的值:

θ我們這分別取10,20,30,40.50,60。
計算Query 的潛在向量,公式如下:

wpsoffice
下面是我們在 Excel中計算


位置向量(rotary):

wpsoffice


RoPE(Key)的計算方式類似,我們就不一一計算了,在此一并給出計算結(jié)果:

步驟4:: 單頭注意力計算

我們先計算頭1的值。首先,我們計算頭1的

,Excel中的公式為:
=MMULT(Q47:S50,V16#)

同理,我們計算出


,最終結(jié)果如圖:

步驟5:

和

聯(lián)合起來,得到以下矩陣:

步驟6: 計算注意力

/Users/i/Library/Containers/com.kingsoft.wpsoffice.mac/Data/tmp/wpsoffice.JiORRtwpsoffice
這里的dk指的是:單個注意力頭(head)中,Query 和 Key 向量的特征維度大小。
也就是說:我們整體模型維度d是6,注意力有3個頭h(heads),

/Users/i/Library/Containers/com.kingsoft.wpsoffice.mac/Data/tmp/wpsoffice.rWNXPIwpsoffice
在我們這個例子中,dk = 2

用同樣的方法,我們計算出其他頭:


步驟7: 連接所有頭

這是連接所有頭后,計算輸出的公式:

wpsoffice
這是我們 Excel 中的計算:

如此就完成了我們多頭潛在注意里的計算。怎么樣,是不是感覺收獲滿滿?
以少勝多:多頭潛在注意力的代碼實現(xiàn)
接下來用 PyTorch 實現(xiàn)多頭潛在注意力。
多頭潛在注意力的代碼實現(xiàn)邏輯其實并不復(fù)雜,主要步驟包括初始化參數(shù)、計算潛在向量、應(yīng)用RoPE變換、計算單頭注意力、將多個頭的輸出進行拼接等。下面,我們就來一步步揭開它的神秘面紗。
首先,我們需要定義一個類來實現(xiàn)多頭潛在注意力機制,比如叫??MultiHeadLatentAttention??。在這個類中,我們需要初始化一些必要的參數(shù),比如頭的數(shù)量、輸入和輸出的維度、可學習的潛在向量等。
然后,我們來實現(xiàn)前向傳播函數(shù)。在這個函數(shù)中,我們首先需要對輸入進行線性變換,得到Query、Key和Value。接著,我們計算每個頭的潛在向量,并應(yīng)用RoPE變換。然后,我們按照標準的多頭注意力機制來計算每個頭的注意力得分,并將這些得分進行softmax歸一化。
接下來用歸一化后的得分對Value進行加權(quán)求和,得到每個頭的輸出。最后,我們將所有頭的輸出進行拼接,并通過一個線性變換層得到最終的輸出。
在代碼中,我們還需要注意一些細節(jié),比如如何保持維度的一致性、如何高效地計算等。不過,只要理解了多頭潛在注意力機制的基本原理,這些細節(jié)問題就迎刃而解了。
現(xiàn)在,你是不是已經(jīng)迫不及待地想要看看具體的代碼實現(xiàn)了呢?別急,下面我們就來給出完整的代碼實現(xiàn),并逐行進行解釋。相信通過這段代碼,你一定能更加深入地理解多頭潛在注意力機制的實現(xiàn)原理。
代碼實現(xiàn)
我們先來修改 MultiHeadLatentAttention ,用它來輸出注意力權(quán)重:
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadLatentAttention(nn.Module):
def __init__(self, input_dim, latent_dim, num_latents, num_heads=1, dropout=0.1):
super().__init__()
self.num_heads = num_heads
self.head_dim = latent_dim // num_heads
assert latent_dim % num_heads == 0, "latent_dim must be divisible by num_heads"
self.latent = nn.Parameter(torch.randn(num_latents, latent_dim)) # 改為英文命名
self.q_proj = nn.Linear(latent_dim, latent_dim)
self.k_proj = nn.Linear(input_dim, latent_dim)
self.v_proj = nn.Linear(input_dim, latent_dim)
self.out_proj = nn.Linear(latent_dim, latent_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
batch_size = x.size(0)
latent = self.latent.unsqueeze(0).expand(batch_size, -1, -1) # 修改為英文變量名
Q = self.q_proj(latent)
K = self.k_proj(x)
V = self.v_proj(x)
def reshape(t):
return t.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
Q, K, V = map(reshape, (Q, K, V))
# 計算注意力分數(shù)
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
attn_weights = F.softmax(attn_scores, dim=-1)
attended = torch.matmul(self.dropout(attn_weights), V)
# 合并頭
attended = attended.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
output = self.out_proj(attended)
# 返回輸出和注意力權(quán)重(去除 batch 維度)
return output, attn_weights.mean(dim=0) # 或者返回 attn_weights.detach().cpu() 用于可視化MultiHeadLatentAttention 逐行解釋
? 類定義和初始化
??class MultiHeadLatentAttention(nn.Module):??
定義一個繼承自 torch.nn.Module 的神經(jīng)網(wǎng)絡(luò)模塊:
def __init__(self, input_dim, latent_dim, num_latents, num_heads=1, dropout=0.1):
super().__init__()參數(shù)含義:
- input_dim: 輸入特征維度(序列的特征大小)。
- latent_dim: 潛在變量特征維度(通常等于隱藏層維度)。
- num_latents: 潛在向量的數(shù)量,用于注意力摘要。
- num_heads: 注意力頭的數(shù)量。
- dropout: Dropout 比例,防止過擬合。
self.num_heads = num_heads
self.head_dim = latent_dim // num_heads
assert latent_dim % num_heads == 0計算每個注意力頭的維度 head_dim。
保證 latent_dim 能被 num_heads 整除,以便均勻拆分維度:
??self.latent = nn.Parameter(torch.randn(num_latents, latent_dim))??
初始化潛在向量,形狀 [num_latents, latent_dim]。
關(guān)鍵點:這些是可學習參數(shù),用于從輸入中提取摘要信息。
self.q_proj = nn.Linear(latent_dim, latent_dim)
self.k_proj = nn.Linear(input_dim, latent_dim)
self.v_proj = nn.Linear(input_dim, latent_dim)
self.out_proj = nn.Linear(latent_dim, latent_dim)
self.dropout = nn.Dropout(dropout)定義線性投影層:
- q_proj: 將潛在向量映射為 Query。
- k_proj: 將輸入映射為 Key。
- v_proj: 將輸入映射為 Value。
- out_proj: 將多頭注意力結(jié)果映射回潛在空間。
Dropout 防止過擬合。
? 前向傳播 Forward
def forward(self, x):
batch_size = x.size(0)獲取輸入的 batch 大小。??latent = self.latent.unsqueeze(0).expand(batch_size, -1, -1)??
將潛在向量擴展到當前 batch 大小,形狀變?yōu)?[batch_size, num_latents, latent_dim]。
Q = self.q_proj(latent)
K = self.k_proj(x)
V = self.v_proj(x)計算 Query、Key、Value:
Q: 從潛在向量中得到。
K 和 V: 從輸入序列中得到。
def reshape(t):
return t.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)將張量 reshape 為多頭注意力所需的形狀:
從 [batch_size, seq_len, latent_dim] → [batch_size, num_heads, seq_len, head_dim]
Q, K, V = map(reshape, (Q, K, V))應(yīng)用 reshape 到 Q、K、V。
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)計算注意力分數(shù):
公式:Q @ K^T / sqrt(head_dim)
縮放因子防止梯度爆炸。
??attn_weights = F.softmax(attn_scores, dim=-1)??
通過 softmax 獲取注意力權(quán)重。
??attended = torch.matmul(self.dropout(attn_weights), V)??
根據(jù)注意力權(quán)重計算上下文向量(信息聚合)。
??attended = attended.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)??
將多頭結(jié)果重新拼接回 [batch_size, num_latents, latent_dim]。
??output = self.out_proj(attended)??
最終線性映射到 latent_dim 空間。
??return output, attn_weights.squeeze(0)??
返回:
- output: 潛在向量的更新結(jié)果。
- attn_weights: 注意力權(quán)重 [num_heads, num_latents, seq_len],可用于可視化注意力分布。
核心原理總結(jié)
用固定數(shù)量的潛在向量代替直接對長序列計算注意力,減少計算量。
每個潛在向量通過注意力機制從輸入序列中摘要關(guān)鍵信息。
類似于 Perceiver、Set Transformer 中的 Inducing Points 思路,高效且適用于大規(guī)模輸入。
代碼應(yīng)用示例
為 Latent指定語義標簽
latent_labels = ["Subject", "Verb", "Object", "Time", "Emotion", "Action"]
可視化 注意力權(quán)重(以熱圖顯示)
導(dǎo)入繪圖庫 matplotlib 和 seaborn,用于繪制熱力圖
import matplotlib.pyplot as plt
import seaborn as sns
def visualize_attention(attn_weights, Token_ids, latent_labels, id_to_Token):可視化多頭注意力權(quán)重的熱力圖。
參數(shù)說明:
- attn_weights: 注意力權(quán)重張量,形狀 [num_heads, num_latents, seq_len]
- Token_ids: 輸入序列的 Token ID 列表,長度為 seq_len
- latent_labels: 潛在向量的標簽列表,長度為 num_latents
- id_to_Token: 字典,用于將 Token_id 轉(zhuǎn)換成實際的 Token 文本
"""
# 獲取注意力頭的數(shù)量
num_heads = attn_weights.shape[0]
# 將 Token_ids 轉(zhuǎn)換為對應(yīng)的可讀 Token 文本
Tokens = [id_to_Token[idx] for idx in Token_ids]
# 遍歷每一個注意力頭,分別繪制熱力圖
for h in range(num_heads):
# 新建一個圖像窗口,設(shè)置大小
plt.figure(figsize=(10, 6))
# 繪制當前注意力頭的熱力圖
sns.heatmap(
attn_weights[h].detach().numpy(), # 將當前注意力頭的張量轉(zhuǎn)成 NumPy 數(shù)組
xticklabels=Tokens, # X 軸標簽為輸入的 Token 文本
yticklabels=latent_labels, # Y 軸標簽為潛在向量標簽
cmap="viridis", # 配色風格為 "viridis"
annot=True, # 在熱力圖上顯示具體數(shù)值
fmt=".2f"# 數(shù)值保留兩位小數(shù)
)
# 設(shè)置圖表標題,標明當前是第幾個注意力頭
plt.title(f"Attention Heatmap (Head {h})")
# 設(shè)置 X 軸和 Y 軸的標簽
plt.xlabel("Input Tokens")
plt.ylabel("潛在Vectors")
# 顯示圖表
plt.show()將這一切連起來:
# 定義輸入句子
sentence = ["i", "drink", "and", "know", "things"]
# 將單詞轉(zhuǎn)為對應(yīng)的 Token ID(假設(shè) Token_to_id 是一個字典)
Token_ids = [Token_to_id[t] for t in sentence]
# 將 Token ID 列表轉(zhuǎn)換成張量,并增加 batch 維度(形狀變?yōu)?[1, seq_len])
Token_tensor = torch.tensor(Token_ids).unsqueeze(0)
# 將 Token_tensor 中的 Token IDs 映射到對應(yīng)的嵌入向量(假設(shè) embedding_tensor 已經(jīng)預(yù)定義)
# 結(jié)果 embedded 的形狀為 [1, seq_len, embedding_dim]
embedded = embedding_tensor[Token_tensor]
# 創(chuàng)建 MultiHeadLatentAttention 模型實例
# input_dim = embedding_dim(單詞嵌入的維度)
# latent_dim = 2(潛在空間維度,通常不這么低,這里是演示用)
# num_latents = 6(使用 6 個潛在向量)
# num_heads = 1(單頭注意力)
mhla = MultiHeadLatentAttention(input_dim=embedding_dim, latent_dim=2, num_latents=6, num_heads=1)
# 執(zhí)行前向傳播
# embedded 輸入形狀:[batch_size=1, seq_len, embedding_dim]
# output: 更新后的潛在向量表示 [1, num_latents, latent_dim]
# attn_weights: 注意力權(quán)重 [num_heads, num_latents, seq_len]
output, attn_weights = mhla(embedded)
# 使用可視化函數(shù)展示注意力分布
# latent_labels 是潛在向量的名稱或編號(如 ["L1", "L2", ..., "L6"])
# id_to_Token 是 Token ID 到單詞的映射字典
visualize_attention(attn_weights, Token_ids, latent_labels, id_to_Token)小結(jié)
多頭潛在注意力的創(chuàng)新在于它結(jié)合了多頭注意力和潛在向量的思想,實現(xiàn)了對長序列的高效處理。通過固定數(shù)量的潛在向量來代表輸入序列的關(guān)鍵信息,多頭注意力機制能夠捕捉不同方面的依賴關(guān)系。
這種方法不僅減少了計算量,還提高了模型的泛化能力,使其能夠處理更復(fù)雜的任務(wù)。
此外,通過為潛在向量指定語義標簽,可以進一步增強模型的可解釋性,使得我們能夠更好地理解模型是如何從輸入序列中提取關(guān)鍵信息的。
總之,多頭潛在注意力是一種高效且強大的注意力機制,為自然語言處理等領(lǐng)域的研究和應(yīng)用提供了新的思路和方法。
文本轉(zhuǎn)載自 ???AI大模型世界??,作者:roclv

















