Differential Transformer:通過(guò)差分注意力機(jī)制提升大語(yǔ)言模型性能
Transformer模型已經(jīng)成為大語(yǔ)言模型(LLMs)的標(biāo)準(zhǔn)架構(gòu),但研究表明這些模型在準(zhǔn)確檢索關(guān)鍵信息方面仍面臨挑戰(zhàn)。今天介紹一篇名叫Differential Transformer的論文,論文的作者觀察到一個(gè)關(guān)鍵問(wèn)題:傳統(tǒng)Transformer模型傾向于過(guò)分關(guān)注不相關(guān)的上下文信息,這種"注意力噪聲"會(huì)影響模型的性能。
在這篇論文中,作者注意到transformer模型傾向于關(guān)注不相關(guān)的上下文。為了放大相關(guān)上下文的注意力分?jǐn)?shù),他們提出了一個(gè)新的注意力模型,稱為差分注意力模型。在這個(gè)模型中,他們將查詢和鍵值向量分成兩組,并計(jì)算兩個(gè)子注意力分?jǐn)?shù)。
差分注意力機(jī)制
差分注意力機(jī)制(Differential Attention)的核心思想是通過(guò)計(jì)算兩個(gè)獨(dú)立的注意力圖譜之差來(lái)消除注意力噪聲。這種設(shè)計(jì)借鑒了電氣工程中差分放大器的原理,通過(guò)對(duì)比兩個(gè)信號(hào)的差異來(lái)消除共模噪聲。
讓我們看看論文中的第一個(gè)方程:
方程(1)
方程(1)顯示,我們首先像標(biāo)準(zhǔn)注意力計(jì)算一樣計(jì)算Q、K和V張量。關(guān)鍵點(diǎn)是我們將Q和K張量分成Q1、Q2和K1、K2子張量。
論文中輸入X、Q1、Q2、K1、K2和V張量的形狀
根據(jù)論文,Q和K張量的形狀應(yīng)該是Nx2d,因?yàn)镼1、Q2、K1和K2將是Nxd。輸入X的形狀是Nxd_model,這是論文中的嵌入維度。這就是為什么W_Q、W_K和W_V的可學(xué)習(xí)參數(shù)的形狀必須是d_modelx2d。
論文中用于lambda計(jì)算的方程(2)
方程(2)展示了如何計(jì)算可學(xué)習(xí)參數(shù)lambda。在這個(gè)方程中有一個(gè)初始lambda參數(shù)。lambda是一個(gè)標(biāo)量參數(shù),但lambda_q1、lambda_k1、lambda_q2和lambda_k2是向量。這一點(diǎn)很關(guān)鍵。向量lambda_q和lambda_k的運(yùn)算是點(diǎn)積。
用于lambda初始化的方程(3)
實(shí)驗(yàn)結(jié)果與性能提升
論文的實(shí)驗(yàn)表明,相比傳統(tǒng)Transformer:
DIFF Transformer只需要約65%的模型參數(shù)量即可達(dá)到相同的性能,在訓(xùn)練token數(shù)量方面也只需要約65%就能達(dá)到相同效果
在Needle-In-A-Haystack測(cè)試中:4K上下文長(zhǎng)度:DIFF Transformer在多目標(biāo)檢索任務(wù)中保持85%準(zhǔn)確率;64K上下文長(zhǎng)度:在深度為25%的位置檢測(cè)時(shí),比傳統(tǒng)Transformer提升了76%的準(zhǔn)確率
Python實(shí)現(xiàn)
下面我們根據(jù)論文的公式來(lái)做一個(gè)簡(jiǎn)單的實(shí)現(xiàn),首先方程(3)展示了我們?nèi)绾斡?jì)算lambda_initial變量?,F(xiàn)在讓我們把方程轉(zhuǎn)換成Python代碼:
def lambda_init_fn(depth):
return 0.8 - 0.6 * math.exp(-0.3 * depth)
然后再寫一個(gè)簡(jiǎn)單的Python函數(shù),使用方程(3)。
class DifferentialAttention(nn.Module):
def __init__(self, dim_model: int, head_nums: int, depth: int):
super().__init__()
self.head_dim = dim_model // head_nums
self.Q = nn.Linear(dim_model, 2 * self.head_dim, bias=False)
self.K = nn.Linear(dim_model, 2 * self.head_dim, bias=False)
self.V = nn.Linear(dim_model, 2 * self.head_dim, bias=False)
self.scale = self.head_dim ** -0.5
self.depth = depth
self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
self.rotary_emb = RotaryEmbedding(self.head_dim * 2)
在DifferentialAttention類中,我們實(shí)現(xiàn)了一個(gè)多頭差分注意力機(jī)制。有dim_model(嵌入維度)、head_nums和depth參數(shù)。為Q1、Q2、K1和K2聲明了四個(gè)lambda可學(xué)習(xí)參數(shù),并使用均值為0、標(biāo)準(zhǔn)差為0.1的隨機(jī)正態(tài)分布初始化它們。
def forward(self, x):
lambda_init = lambda_init_fn(self.depth)
Q = self.Q(x)
K = self.K(x)
seq_len = x.shape[1]
cos, sin = self.rotary_emb(seq_len, device=x.device)
Q, K = apply_rotary_pos_emb(Q, K, cos, sin)
Q1, Q2 = Q.chunk(2, dim=-1)
K1, K2 = K.chunk(2, dim=-1)
V = self.V(x)
A1 = Q1 @ K1.transpose(-2, -1) * self.scale
A2 = Q2 @ K2.transpose(-2, -1) * self.scale
lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(Q1)
lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(Q2)
lambda_ = lambda_1 - lambda_2 + lambda_init
return (F.softmax(A1, dim=-1) - lambda_ * F.softmax(A2, dim=-1)) @ V
forward方法很直觀。我分別實(shí)現(xiàn)了方程(1)和方程(2)。forward方法直接實(shí)現(xiàn)了論文中的偽代碼。
多頭差分注意力架構(gòu)和偽代碼
class MultiHeadDifferentialAttention(nn.Module):
def __init__(self, dim_model: int, head_nums: int, depth: int):
super().__init__()
self.heads = nn.ModuleList([DifferentialAttention(dim_model, head_nums, depth) for _ in range(head_nums)])
self.group_norm = RMSNorm(dim_model)
self.output = nn.Linear(2 * dim_model, dim_model, bias=False)
self.lambda_init = lambda_init_fn(depth)
def forward(self, x):
o = torch.cat([self.group_norm(h(x)) for h in self.heads], dim=-1)
o = o * (1 - self.lambda_init)
return self.output(o)
MultiHeadDifferentialAttention類是根據(jù)論文中的偽代碼編寫的。這里使用了RMSNorm而不是GroupNorm。
論文中使用多頭差分注意力機(jī)制的語(yǔ)言模型的方程
最后使用實(shí)現(xiàn)的MultiHeadDifferentialAttention機(jī)制構(gòu)建一個(gè)transformer解碼器。
class DifferentialTransformer(nn.Module):
def __init__(self, dim: int, depth: int, heads: int = 8, head_dim: int = 64, vocab_size: int = 10000):
super().__init__()
self.vocab_size = vocab_size
self.layers = nn.ModuleList([
MultiHeadDifferentialAttention(dim, heads, depth_idx)
for depth_idx in range(depth)
])
self.ln1 = RMSNorm(dim)
self.ln2 = RMSNorm(dim)
self.ffn = FeedForward(dim, (dim // 3) * 8)
self.output = nn.Linear(dim, self.vocab_size)
def forward(self, x):
for attn in self.layers:
y = attn(self.ln1(x)) + x
x = self.ffn(self.ln2(y)) + y
return self.output(x)
性能優(yōu)化
論文還提供了兩種FlashAttention實(shí)現(xiàn)方式:
1、支持不同維度的實(shí)現(xiàn):
def FlashDiffAttn_1(X, W_q, W_k, W_v, λ):
Q1, Q2 = split(X @ W_q)
K1, K2 = split(X @ W_k)
V = X @ W_v
A1 = flash_attn(Q1, K1, V)
A2 = flash_attn(Q2, K2, V)
return A1 - λ A2
固定維度的實(shí)現(xiàn):
def FlashDiffAttn_2(X, W_q, W_k, W_v, λ):
Q1, Q2 = split(X @ W_q)
K1, K2 = split(X @ W_k)
V1, V2 = split(X @ W_v)
A11 = flash_attn(Q1, K1, V1)
A12 = flash_attn(Q1, K1, V2)
A1 = Concat(A11, A12)
A21 = flash_attn(Q2, K2, V1)
A22 = flash_attn(Q2, K2, V2)
A2 = Concat(A21, A22)
return A1 - λ A2
Differential Transformer論文提出的兩種FlashAttention實(shí)現(xiàn)方案各有特色。第一種實(shí)現(xiàn)(FlashDiffAttn_1)采用直接計(jì)算策略,允許Q、K、V具有不同的維度,這種靈活性使其更適合需要?jiǎng)討B(tài)調(diào)整維度的場(chǎng)景,但可能在某些硬件上的優(yōu)化效果不如第二種方案。第二種實(shí)現(xiàn)(FlashDiffAttn_2)通過(guò)將計(jì)算分解為多個(gè)相同維度的子運(yùn)算,雖然計(jì)算步驟增多,但每個(gè)步驟都能充分利用硬件優(yōu)化,特別是在支持張量核心的現(xiàn)代GPU上表現(xiàn)更好。
這兩種實(shí)現(xiàn)的選擇主要取決于具體應(yīng)用場(chǎng)景:如果模型架構(gòu)需要頻繁調(diào)整維度或者需要更靈活的注意力機(jī)制,建議使用第一種實(shí)現(xiàn);如果追求極致的計(jì)算效率且維度相對(duì)固定,第二種實(shí)現(xiàn)可能是更好的選擇。從工程實(shí)踐角度看,第二種實(shí)現(xiàn)與現(xiàn)有的FlashAttention優(yōu)化庫(kù)的兼容性更好,更容易在現(xiàn)有基礎(chǔ)設(shè)施上部署和優(yōu)化。
局限性和未來(lái)研究方向
Differential Transformer雖然在多個(gè)方面展現(xiàn)出優(yōu)秀的性能,但仍然存在一些值得關(guān)注的局限性。首要的挑戰(zhàn)來(lái)自其計(jì)算效率方面。由于模型需要同時(shí)計(jì)算兩個(gè)獨(dú)立的注意力圖譜,這不可避免地增加了計(jì)算開(kāi)銷。在實(shí)際測(cè)試中,相比傳統(tǒng)Transformer,DIFF Transformer在3B規(guī)模模型上的計(jì)算吞吐量降低了約9%,這種性能損失雖然可以通過(guò)更少的參數(shù)量來(lái)部分抵消,但在大規(guī)模部署場(chǎng)景中仍然需要認(rèn)真考慮。
內(nèi)存使用是另一個(gè)重要的局限性。模型需要存儲(chǔ)兩組獨(dú)立的查詢和鍵值向量,這導(dǎo)致了更高的內(nèi)存占用。盡管這種設(shè)計(jì)對(duì)于提升模型性能是必要的,但在資源受限的環(huán)境下可能會(huì)造成部署困難。特別是在處理超長(zhǎng)序列時(shí),內(nèi)存壓力會(huì)進(jìn)一步加大。
訓(xùn)練穩(wěn)定性也是一個(gè)需要特別關(guān)注的問(wèn)題。模型中λ參數(shù)的初始化策略對(duì)訓(xùn)練過(guò)程的穩(wěn)定性有顯著影響。研究發(fā)現(xiàn),不同的λinit取值會(huì)導(dǎo)致訓(xùn)練收斂速度和最終性能的差異。雖然論文提出了一個(gè)基于層深度的初始化策略,但這種方案并非在所有場(chǎng)景下都能取得最優(yōu)效果,有時(shí)需要根據(jù)具體任務(wù)進(jìn)行調(diào)整。
基于這些局限性,論文提出未來(lái)的研究可以沿著幾個(gè)重要方向展開(kāi)。首先在計(jì)算效率優(yōu)化方面,可以探索更高效的注意力核心實(shí)現(xiàn)。這包括研究如何更好地利用現(xiàn)代硬件特性,例如開(kāi)發(fā)專門的CUDA核心來(lái)加速差分注意力的計(jì)算。同時(shí)考慮到模型產(chǎn)生的稀疏注意力模式,可以設(shè)計(jì)特定的稀疏計(jì)算優(yōu)化策略,這不僅能提升計(jì)算效率,還能減少內(nèi)存占用。
λ參數(shù)的動(dòng)態(tài)調(diào)整機(jī)制是另一個(gè)值得深入研究的方向。當(dāng)前的參數(shù)計(jì)算方案雖然有效,但仍有優(yōu)化空間??梢钥紤]設(shè)計(jì)更靈活的自適應(yīng)機(jī)制,使λ參數(shù)能夠根據(jù)輸入內(nèi)容和任務(wù)特點(diǎn)動(dòng)態(tài)調(diào)整,從而在不同場(chǎng)景下都能獲得最佳性能。這可能需要引入額外的上下文感知機(jī)制,或者設(shè)計(jì)新的參數(shù)更新策略。
在內(nèi)存優(yōu)化方面,量化技術(shù)提供了一個(gè)有前景的研究方向??紤]到DIFF Transformer在處理激活值異常方面的優(yōu)勢(shì),可以探索專門的量化策略。比如,研究如何在保持模型性能的同時(shí),對(duì)注意力權(quán)重和中間狀態(tài)進(jìn)行更激進(jìn)的量化,從而減少內(nèi)存占用。這對(duì)于模型在邊緣設(shè)備上的部署具有重要意義。
長(zhǎng)文本建模能力的進(jìn)一步提升也是一個(gè)重要研究方向。雖然當(dāng)前模型在64K長(zhǎng)度的實(shí)驗(yàn)中表現(xiàn)出色,但隨著應(yīng)用需求的增長(zhǎng),可能需要處理更長(zhǎng)的序列。這要求研究如何在更長(zhǎng)序列上保持模型的效率和性能,可能需要開(kāi)發(fā)新的注意力機(jī)制變體或優(yōu)化策略。
總結(jié)
DIFF Transformer通過(guò)創(chuàng)新的差分注意力機(jī)制成功提升了模型性能,特別是在長(zhǎng)文本理解、關(guān)鍵信息檢索和模型魯棒性等方面。雖然存在一些計(jì)算效率和內(nèi)存使用的權(quán)衡,但考慮到顯著的性能提升和更少的參數(shù)需求,這是一個(gè)非常有價(jià)值的改進(jìn)。這項(xiàng)工作為大語(yǔ)言模型的架構(gòu)設(shè)計(jì)提供了新的思路,也為后續(xù)研究指明了幾個(gè)重要的優(yōu)化方向。