揭秘大模型的魔法:實現(xiàn)帶可訓練權(quán)重的自注意力機制

大家好,我是寫代碼的中年人。
上一篇我們實現(xiàn)了一個“無可訓練參數(shù)”的注意力機制,讓每個詞都能“看看別人”,計算出自己的上下文理解。
雖然實現(xiàn)起來不難,但它只是個“玩具級”的注意力,離真正的大模型還差了幾個“億”個參數(shù)。今天,我們來實現(xiàn)一個可訓練版本的自注意力機制,這可是 Transformer 的核心!
01、什么叫“可訓練”的注意力?
在大模型里,注意力機制不是寫死的,而是學出來的。
為了讓每個詞都能“智能提問、精準關(guān)注”,我們需要三個可訓練的權(quán)重矩陣:

每個詞自己造問題,然后去問別的詞,看看誰最“對味”,然后決定聽誰的意見。
為什么自注意力機制(Self-Attention)中需要三個可訓練的權(quán)重矩陣,也就是常說的:
Wq:Query 權(quán)重矩陣
Wk:Key 權(quán)重矩陣
Wv:Value 權(quán)重矩陣
這個設(shè)計最早出現(xiàn)在 2017 年 Google 的論文《Attention is All You Need》中,也就是Transformer架構(gòu)的原始論文。這三個矩陣的引入不是隨便“拍腦袋”的設(shè)計,而是有明確動機的:

# ONE
這段論文奠定了 Transformer 的注意力計算基礎(chǔ)。Transformer 后續(xù)所有的 Multi-Head Attention、Encoder-Decoder Attention,都是基于這個 Scaled Dot-Product Attention 構(gòu)建的。
02、我是誰?我在哪?我要關(guān)注誰?
其實自注意力就是一種帶可訓練權(quán)重的加權(quán)平均機制,它做了三件事:
把每個詞向量分別變成三個形態(tài):Query(查詢)、Key(鍵)、Value(值);
計算 Query 和所有 Key 的相似度(注意力權(quán)重);
用這個權(quán)重加權(quán) Value 向量,得出最終的輸出向量。每個詞都在用“自己的 Query”去看“別人的 Key”,然后決定“我到底該關(guān)注誰”。
如果我們想理解這些內(nèi)容,最好以代碼的形式來逐步理解:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# ------- 定義可訓練的自注意力模塊 -------
class SelfAttention(nn.Module):
    def __init__(self, embed_dim, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.last_attn_weights = None
    def forward(self, x):
        B, T, C = x.size()
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.embed_dim ** 0.5)
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
        self.last_attn_weights = attn_weights.detach()
        out = torch.matmul(attn_weights, V)
        out = self.out_proj(out)
        return out
# ------- Create a Simulated Dataset -------
# Simulate a small vocabulary and word embeddings
vocab = {"寫": 0, "代碼": 1, "的": 2, "中年人": 3, "天天": 4, "<PAD>": 5}
embed_dim = 16
vocab_size = len(vocab)
embedding = nn.Embedding(vocab_size, embed_dim)  # Randomly initialized word embeddings
# Sentence data
sentences = [
    ["寫", "代碼", "的", "中年人"],
    ["天天", "寫", "代碼", "<PAD>"]  # Pad the second sentence to match length
]
batch_size = len(sentences)
seq_len = len(sentences[0])  # Sentences have the same length (4)
# Convert sentences to indices
input_ids = torch.tensor([[vocab[word] for word in sent] for sent in sentences])  # (batch_size, seq_len)
# ------- Parameter Settings -------
epochs = 200
dropout = 0.1
model = SelfAttention(embed_dim, dropout)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# ------- Train the Model -------
for epoch in range(epochs):
    model.train()
    # Compute input inside the loop to create a fresh computation graph
    x = embedding(input_ids)  # (batch_size, seq_len, embed_dim)
    target = x.clone()  # Target is the same as input for this task
    out = model(x)
    loss = criterion(out, target)
    optimizer.zero_grad()
    loss.backward()  # Compute gradients
    optimizer.step()  # Update model parameters
    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch+1:3d}, Loss: {loss.item():.6f}")
# ------- Visualize Attention Weights -------
# Visualize attention matrix for the first sentence
attention = model.last_attn_weights[0].numpy()  # (seq_len, seq_len)
sentence = sentences[0]  # ["寫", "代碼", "的", "中年人"]
plt.figure(figsize=(8, 6))
plt.imshow(attention, cmap='viridis')
plt.title(f"Attention Matrix for Sentence: {' '.join(sentence)}")
plt.xticks(ticks=np.arange(seq_len), labels=sentence)
plt.yticks(ticks=np.arange(seq_len), labels=sentence)
plt.xlabel("Key (Word)")
plt.ylabel("Query (Word)")
plt.colorbar(label="Attention Strength")
for i in range(seq_len):
    for j in range(seq_len):
        plt.text(j, i, f"{attention[i,j]:.2f}", ha="center", va="center", color="white")
plt.tight_layout()
plt.savefig("attention_matrix_sentence1.png")
plt.show()
# Visualize attention matrix for the second sentence
attention = model.last_attn_weights[1].numpy()
sentence = sentences[1]  # ["天天", "寫", "代碼", "<PAD>"]
plt.figure(figsize=(8, 6))
plt.imshow(attention, cmap='viridis')
plt.title(f"Attention Matrix for Sentence: {' '.join(sentence)}")
plt.xticks(ticks=np.arange(seq_len), labels=sentence)
plt.yticks(ticks=np.arange(seq_len), labels=sentence)
plt.xlabel("Key (Word)")
plt.ylabel("Query (Word)")
plt.colorbar(label="Attention Strength")
for i in range(seq_len):
    for j in range(seq_len):
        plt.text(j, i, f"{attention[i,j]:.2f}", ha="center", va="center", color="white")
plt.tight_layout()
plt.savefig("attention_matrix_sentence2.png")
plt.show()上面的代碼執(zhí)行后輸出:


代碼詳解:
這段代碼實現(xiàn)了一個簡單的自注意力(Self-Attention)模型,并通過一個模擬的中文數(shù)據(jù)集進行訓練,展示自注意力機制如何捕捉句子中詞語之間的關(guān)系。以下是代碼的詳細解釋,以及對自注意力機制的深入分析。
這段代碼的核心目標是:實現(xiàn)自注意力模塊:通過定義一個SelfAttention類,實現(xiàn)自注意力機制,模擬Transformer模型中的核心組件。訓練模型:使用一個簡單的中文詞匯數(shù)據(jù)集,訓練自注意力模型,使其學習詞語之間的注意力分布??梢暬⒁饬?quán)重:通過繪制注意力矩陣,直觀展示模型如何關(guān)注句子中不同詞語之間的關(guān)系。
代碼主要分為以下幾個部分:數(shù)據(jù)集構(gòu)建:構(gòu)造一個小型中文詞匯表和兩個短句,模擬自然語言處理任務。模型定義:實現(xiàn)自注意力模塊,包含查詢(Query)、鍵(Key)、值(Value)的線性變換和注意力計算。訓練過程:通過優(yōu)化模型,使其輸出盡可能接近輸入(一種簡單的自監(jiān)督學習任務)??梢暬豪L制注意力矩陣,展示模型對不同詞語的關(guān)注程度。
03、自注意力機制詳解
自注意力機制的核心思想
自注意力是Transformer模型的核心組件,用于捕捉序列中元素(詞、字符等)之間的關(guān)系。
其核心思想是:
每個輸入元素(如詞)同時扮演查詢(Query)、鍵(Key)和值(Value)三個角色。通過計算查詢與鍵的相似度,生成注意力權(quán)重,決定每個元素對其他元素的關(guān)注程度。使用注意力權(quán)重對值進行加權(quán)求和,生成上下文感知的表示。
數(shù)學公式:

# ONE
訓練權(quán)重的作用:
在訓練過程中,自注意力機制的權(quán)重(W_q, W_k, W_v, W_out)通過優(yōu)化器更新,目標是使模型輸出盡可能接近輸入(MSE損失)。
具體作用:
學習語義關(guān)系:通過調(diào)整W_q和W_k,模型學習詞之間的語義關(guān)聯(lián)。例如,“寫”和“代碼”可能有較高的注意力權(quán)重,因為它們在語義上相關(guān)。
增強表示:通過W_v和W_out,模型生成更豐富的上下文表示,捕捉句子中詞語的相互影響。
動態(tài)關(guān)注:注意力權(quán)重是動態(tài)計算的,允許模型根據(jù)輸入內(nèi)容靈活調(diào)整關(guān)注重點。
通過深入剖析自注意力機制及其可訓練權(quán)重的核心作用,我們揭開了大模型處理復雜任務時那份“魔力”的關(guān)鍵一角。自注意力以其獨特的方式,讓模型能靈活聚焦于輸入序列中的重要信息,大幅提升了上下文理解的能力。但這只是開端。在下一章,我們將進一步探討多頭注意力機制,看它如何通過并行處理多組注意力,為模型帶來更強的表達力和靈活性。















 
 
 
















 
 
 
 