LLM可解釋性的未來(lái)希望?稀疏自編碼器是如何工作的,這里有一份直觀說(shuō)明
在解釋機(jī)器學(xué)習(xí)模型方面,稀疏自編碼器(SAE)是一種越來(lái)越常用的工具(雖然 SAE 在 1997 年左右就已經(jīng)問(wèn)世了)。
機(jī)器學(xué)習(xí)模型和 LLM 正變得越來(lái)越強(qiáng)大、越來(lái)越有用,但它們?nèi)耘f是黑箱,我們并不理解它們完成任務(wù)的方式。理解它們的工作方式應(yīng)當(dāng)大有助益。
SAE 可幫助我們將模型的計(jì)算分解成可以理解的組件。近日,LLM 可解釋性研究者 Adam Karvonen 發(fā)布了一篇博客文章,直觀地解釋了 SAE 的工作方式。
可解釋性的難題
神經(jīng)網(wǎng)絡(luò)最自然的組件是各個(gè)神經(jīng)元。不幸的是,單個(gè)神經(jīng)元并不能便捷地與單個(gè)概念相對(duì)應(yīng),比如學(xué)術(shù)引用、英語(yǔ)對(duì)話、HTTP 請(qǐng)求和韓語(yǔ)文本。在神經(jīng)網(wǎng)絡(luò)中,概念是通過(guò)神經(jīng)元的組合表示的,這被稱(chēng)為疊加(superposition)。
之所以會(huì)這樣,是因?yàn)槭澜缟虾芏嘧兞刻烊痪褪窍∈璧摹?/span>
舉個(gè)例子,某位名人的出生地可能出現(xiàn)在不到十億分之一的訓(xùn)練 token 中,但現(xiàn)代 LLM 依然能學(xué)到這一事實(shí)以及有關(guān)這個(gè)世界的大量其它知識(shí)。訓(xùn)練數(shù)據(jù)中單個(gè)事實(shí)和概念的數(shù)量多于模型中神經(jīng)元的數(shù)量,這可能就是疊加出現(xiàn)的原因。
近段時(shí)間,稀疏自編碼器(SAE)技術(shù)越來(lái)越常被用于將神經(jīng)網(wǎng)絡(luò)分解成可理解的組件。SAE 的設(shè)計(jì)靈感來(lái)自神經(jīng)科學(xué)領(lǐng)域的稀疏編碼假設(shè)?,F(xiàn)在,SAE 已成為解讀人工神經(jīng)網(wǎng)絡(luò)方面最有潛力的工具之一。SAE 與標(biāo)準(zhǔn)自編碼器類(lèi)似。
常規(guī)自編碼器是一種用于壓縮并重建輸入數(shù)據(jù)的神經(jīng)網(wǎng)絡(luò)。
舉個(gè)例子,如果輸入是一個(gè) 100 維的向量(包含 100 個(gè)數(shù)值的列表);自編碼器首先會(huì)讓該輸入通過(guò)一個(gè)編碼器層,讓其被壓縮成一個(gè) 50 維的向量,然后將這個(gè)壓縮后的編碼表示饋送給解碼器,得到 100 維的輸出向量。其重建過(guò)程通常并不完美,因?yàn)閴嚎s過(guò)程會(huì)讓重建任務(wù)變得非常困難。
一個(gè)標(biāo)準(zhǔn)自編碼器的示意圖,其有 1x4 的輸入向量、1x2 的中間狀態(tài)向量和 1x4 的輸出向量。單元格的顏色表示激活值。輸出是輸入的不完美重建結(jié)果。
解釋稀疏自編碼器
稀疏自編碼器的工作方式
稀疏自編碼器會(huì)將輸入向量轉(zhuǎn)換成中間向量,該中間向量的維度可能高于、等于或低于輸入的維度。在用于 LLM 時(shí),中間向量的維度通常高于輸入。在這種情況下,如果不加額外的約束條件,那么該任務(wù)就很簡(jiǎn)單,SAE 可以使用單位矩陣來(lái)完美地重建出輸入,不會(huì)出現(xiàn)任何意料之外的東西。但我們會(huì)添加約束條件,其中之一是為訓(xùn)練損失添加稀疏度懲罰,這會(huì)促使 SAE 創(chuàng)建稀疏的中間向量。
舉個(gè)例子,我們可以將 100 維的輸入擴(kuò)展成 200 維的已編碼表征向量,并且我們可以訓(xùn)練 SAE 使其在已編碼表征中僅有大約 20 個(gè)非零元素。
稀疏自編碼器示意圖。請(qǐng)注意,中間激活是稀疏的,僅有 2 個(gè)非零值。
我們將 SAE 用于神經(jīng)網(wǎng)絡(luò)內(nèi)的中間激活,而神經(jīng)網(wǎng)絡(luò)可能包含許多層。在前向通過(guò)過(guò)程中,每一層中和每一層之間都有中間激活。
舉個(gè)例子,GPT-3 有 96 層。在前向通過(guò)過(guò)程中,輸入中的每個(gè) token 都有一個(gè) 12,288 維向量(一個(gè)包含 12,288 個(gè)數(shù)值的列表)。此向量會(huì)累積模型在每一層處理時(shí)用于預(yù)測(cè)下一 token 的所有信息,但它并不透明,讓人難以理解其中究竟包含什么信息。
我們可以使用 SAE 來(lái)理解這種中間激活。SAE 基本上就是「矩陣 → ReLU 激活 → 矩陣」。
舉個(gè)例子,如果 GPT-3 SAE 的擴(kuò)展因子為 4,其輸入激活有 12,288 維,則其 SAE 編碼的表征有 49,512 維(12,288 x 4)。第一個(gè)矩陣是形狀為 (12,288, 49,512) 的編碼器矩陣,第二個(gè)矩陣是形狀為 (49,512, 12,288) 的解碼器矩陣。通過(guò)讓 GPT 的激活與編碼器相乘并使用 ReLU,可以得到 49,512 維的 SAE 編碼的稀疏表征,因?yàn)?SAE 的損失函數(shù)會(huì)促使實(shí)現(xiàn)稀疏性。
通常來(lái)說(shuō),我們的目標(biāo)讓 SAE 的表征中非零值的數(shù)量少于 100 個(gè)。通過(guò)將 SAE 的表征與解碼器相乘,可得到一個(gè) 12,288 維的重建的模型激活。這個(gè)重建結(jié)果并不能與原始的 GPT 激活完美匹配,因?yàn)橄∈栊约s束條件會(huì)讓完美匹配難以實(shí)現(xiàn)。
一般來(lái)說(shuō),一個(gè) SAE 僅用于模型中的一個(gè)位置舉個(gè)例子,我們可以在 26 和 27 層之間的中間激活上訓(xùn)練一個(gè) SAE。為了分析 GPT-3 的全部 96 層的輸出中包含的信息,可以訓(xùn)練 96 個(gè)分立的 SAE—— 每層的輸出都有一個(gè)。如果我們也想分析每一層內(nèi)各種不同的中間激活,那就需要數(shù)百個(gè) SAE。為了獲取這些 SAE 的訓(xùn)練數(shù)據(jù),需要向這個(gè) GPT 模型輸入大量不同的文本,然后收集每個(gè)選定位置的中間激活。
下面提供了一個(gè) SAE 的 PyTorch 參考實(shí)現(xiàn)。其中的變量帶有形狀注釋?zhuān)@個(gè)點(diǎn)子來(lái)自 Noam Shazeer,參見(jiàn):https://medium.com/@NoamShazeer/shape-suffixes-good-coding-style-f836e72e24fd 。請(qǐng)注意,為了盡可能地提升性能,不同的 SAE 實(shí)現(xiàn)往往會(huì)有不同的偏置項(xiàng)、歸一化方案或初始化方案。最常見(jiàn)的一種附加項(xiàng)是某種對(duì)解碼器向量范數(shù)的約束。更多細(xì)節(jié)請(qǐng)?jiān)L問(wèn)以下實(shí)現(xiàn):
- OpenAI:https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/model.py#L16
- SAELens:https://github.com/jbloomAus/SAELens/blob/main/sae_lens/sae.py#L97
- dictionary_learning:https://github.com/saprmarks/dictionary_learning/blob/main/dictionary.py#L30
import torch
import torch.nn as nn
# D = d_model, F = dictionary_size
# e.g. if d_model = 12288 and dictionary_size = 49152
# then model_activations_D.shape = (12288,) and encoder_DF.weight.shape = (12288, 49152)
class SparseAutoEncoder (nn.Module):
"""
A one-layer autoencoder.
"""
def __init__(self, activation_dim: int, dict_size: int):
super ().__init__()
self.activation_dim = activation_dim
self.dict_size = dict_size
self.encoder_DF = nn.Linear (activation_dim, dict_size, bias=True)
self.decoder_FD = nn.Linear (dict_size, activation_dim, bias=True)
def encode (self, model_activations_D: torch.Tensor) -> torch.Tensor:
return nn.ReLU ()(self.encoder_DF (model_activations_D))
def decode (self, encoded_representation_F: torch.Tensor) -> torch.Tensor:
return self.decoder_FD (encoded_representation_F)
def forward_pass (self, model_activations_D: torch.Tensor) -> tuple [torch.Tensor, torch.Tensor]:
encoded_representation_F = self.encode (model_activations_D)
reconstructed_model_activations_D = self.decode (encoded_representation_F)
return reconstructed_model_activations_D, encoded_representation_F
標(biāo)準(zhǔn)自編碼器的損失函數(shù)基于輸入重建結(jié)果的準(zhǔn)確度。為了引入稀疏性,最直接的方法是向 SAE 的損失函數(shù)添加一個(gè)稀疏度懲罰項(xiàng)。對(duì)于這個(gè)懲罰項(xiàng),最常見(jiàn)的計(jì)算方式是取這個(gè) SAE 的已編碼表征(而非 SAE 權(quán)重)的 L1 損失并將其乘以一個(gè) L1 系數(shù)。這個(gè) L1 系數(shù)是 SAE 訓(xùn)練中的一個(gè)關(guān)鍵超參數(shù),因?yàn)樗纱_定實(shí)現(xiàn)稀疏度與維持重建準(zhǔn)確度之間的權(quán)衡。
請(qǐng)注意,這里并沒(méi)有針對(duì)可解釋性進(jìn)行優(yōu)化。相反,可解釋的 SAE 特征是優(yōu)化稀疏度和重建的一個(gè)附帶效果。下面是一個(gè)參考損失函數(shù)。
# B = batch size, D = d_model, F = dictionary_size
def calculate_loss (autoencoder: SparseAutoEncoder, model_activations_BD: torch.Tensor, l1_coeffient: float) -> torch.Tensor:
reconstructed_model_activations_BD, encoded_representation_BF = autoencoder.forward_pass (model_activations_BD)
reconstruction_error_BD = (reconstructed_model_activations_BD - model_activations_BD).pow (2)
reconstruction_error_B = einops.reduce (reconstruction_error_BD, 'B D -> B', 'sum')
l2_loss = reconstruction_error_B.mean ()
l1_loss = l1_coefficient * encoded_representation_BF.sum ()
loss = l2_loss + l1_loss
return loss
稀疏自編碼器的前向通過(guò)示意圖。
這是稀疏自編碼器的單次前向通過(guò)過(guò)程。首先是 1x4 大小的模型向量。然后將其乘以一個(gè) 4x8 的編碼器矩陣,得到一個(gè) 1x8 的已編碼向量,然后應(yīng)用 ReLU 將負(fù)值變成零。這個(gè)編碼后的向量就是稀疏的。之后,再讓其乘以一個(gè) 8x4 的解碼器矩陣,得到一個(gè) 1x4 的不完美重建的模型激活。
假想的 SAE 特征演示
理想情況下,SAE 表征中的每個(gè)有效數(shù)值都對(duì)應(yīng)于某個(gè)可理解的組件。
這里假設(shè)一個(gè)案例進(jìn)行說(shuō)明。假設(shè)一個(gè) 12,288 維向量 [1.5, 0.2, -1.2, ...] 在 GPT-3 看來(lái)是表示「Golden Retriever」(金毛犬)。SAE 是一個(gè)形狀為 (49,512, 12,288) 的矩陣,但我們也可以將其看作是 49,512 個(gè)向量的集合,其中每個(gè)向量的形狀都是 (1, 12,288)。如果該 SAE 解碼器的 317 向量學(xué)習(xí)到了與 GPT-3 那一樣的「Golden Retriever」概念,那么該解碼器向量大致也等于 [1.5, 0.2, -1.2, ...]。
無(wú)論何時(shí) SAE 的激活的 317 元素是非零的,那么對(duì)應(yīng)于「Golden Retriever」的向量(并根據(jù) 317 元素的幅度)會(huì)被添加到重建激活中。用機(jī)械可解釋性的術(shù)語(yǔ)來(lái)說(shuō),這可以簡(jiǎn)潔地描述為「解碼器向量對(duì)應(yīng)于殘差流空間中特征的線性表征」。
也可以說(shuō)有 49,512 維的已編碼表征的 SAE 有 49,512 個(gè)特征。特征由對(duì)應(yīng)的編碼器和解碼器向量構(gòu)成。編碼器向量的作用是檢測(cè)模型的內(nèi)部概念,同時(shí)最小化其它概念的干擾,盡管解碼器向量的作用是表示「真實(shí)的」特征方向。研究者的實(shí)驗(yàn)發(fā)現(xiàn),每個(gè)特征的編碼器和解碼器特征是不一樣的,并且余弦相似度的中位數(shù)為 0.5。在下圖中,三個(gè)紅框?qū)?yīng)于單個(gè)特征。
稀疏自編碼器示意圖,其中三個(gè)紅框?qū)?yīng)于 SAE 特征 1,綠框?qū)?yīng)于特征 4。每個(gè)特征都有一個(gè) 1x4 的編碼器向量、1x1 的特征激活和 1x4 的解碼器向量。重建的激活的構(gòu)建僅使用了來(lái)自 SAE 特征 1 和 4 的解碼器向量。如果紅框表示「紅顏色」,綠框表示「球」,那么該模型可能表示「紅球」。
那么我們?cè)撊绾蔚弥僭O(shè)的特征 317 表示什么呢?目前而言,人們的實(shí)踐方法是尋找能最大程度激活特征并對(duì)它們的可解釋性給出直覺(jué)反應(yīng)的輸入。能讓每個(gè)特征激活的輸入通常是可解釋的。
舉個(gè)例子,Anthropic 在 Claude Sonnet 上訓(xùn)練了 SAE,結(jié)果發(fā)現(xiàn):與金門(mén)大橋、神經(jīng)科學(xué)和熱門(mén)旅游景點(diǎn)相關(guān)的文本和圖像會(huì)激活不同的 SAE 特征。其它一些特征會(huì)被并不顯而易見(jiàn)的概念激活,比如在 Pythia 上訓(xùn)練的一個(gè) SAE 的一個(gè)特征會(huì)被這樣的概念激活,即「用于修飾句子主語(yǔ)的關(guān)系從句或介詞短語(yǔ)的最終 token」。
由于 SAE 解碼器向量的形狀與 LLM 的中間激活一樣,因此可簡(jiǎn)單地通過(guò)將解碼器向量加入到模型激活來(lái)執(zhí)行因果干預(yù)。通過(guò)讓該解碼器向量乘以一個(gè)擴(kuò)展因子,可以調(diào)整這種干預(yù)的強(qiáng)度。當(dāng) Anthropic 研究者將「金門(mén)大橋」SAE 解碼器向量添加到 Claude 的激活時(shí),Claude 會(huì)被迫在每個(gè)響應(yīng)中都提及「金門(mén)大橋」。
下面是使用假設(shè)的特征 317 得到的因果干預(yù)的參考實(shí)現(xiàn)。類(lèi)似于「金門(mén)大橋」Claude,這種非常簡(jiǎn)單的干預(yù)會(huì)迫使 GPT-3 模型在每個(gè)響應(yīng)中都提及「金毛犬」。
def perform_intervention (model_activations_D: torch.Tensor, decoder_FD: torch.Tensor, scale: float) -> torch.Tensor:
intervention_vector_D = decoder_FD [317, :]
scaled_intervention_vector_D = intervention_vector_D * scale
modified_model_activations_D = model_activations_D + scaled_intervention_vector_D
return modified_model_activations_D
稀疏自編碼器的評(píng)估難題
使用 SAE 的一大主要難題是評(píng)估。我們可以訓(xùn)練稀疏自編碼器來(lái)解釋語(yǔ)言模型,但我們沒(méi)有自然語(yǔ)言表示的可度量的底層 ground truth。目前而言,評(píng)估都很主觀,基本也就是「我們研究一系列特征的激活輸入,然后憑直覺(jué)闡述這些特征的可解釋性。」這是可解釋性領(lǐng)域的主要限制。
研究者已經(jīng)發(fā)現(xiàn)了一些似乎與特征可解釋性相對(duì)應(yīng)的常見(jiàn)代理指標(biāo)。最常用的是 L0 和 Loss Recovered。L0 是 SAE 的已編碼中間表征中非零元素的平均數(shù)量。Loss Recovered 是使用重建的激活替換 GPT 的原始激活,并測(cè)量不完美重建結(jié)果的額外損失。這兩個(gè)指標(biāo)通常需要權(quán)衡考慮,因?yàn)?SAE 可能會(huì)為了提升稀疏性而選擇一個(gè)會(huì)導(dǎo)致重建準(zhǔn)確度下降的解。
在比較 SAE 時(shí),一種常用方法是繪制這兩個(gè)變量的圖表,然后檢查它們之間的權(quán)衡。為了實(shí)現(xiàn)更好的權(quán)衡,許多新的 SAE 方法(如 DeepMind 的 Gated SAE 和 OpenAI 的 TopK SAE)對(duì)稀疏度懲罰做了修改。下圖來(lái)自 DeepMind 的 Gated SAE 論文。Gated SAE 由紅線表示,位于圖中左上方,這表明其在這種權(quán)衡上表現(xiàn)更好。
Gated SAE L0 與 Loss Recovered
SAE 的度量存在多個(gè)難度層級(jí)。L0 和 Loss Recovered 是兩個(gè)代理指標(biāo)。但是,在訓(xùn)練時(shí)我們并不會(huì)使用它們,因?yàn)?L0 不可微分,而在 SAE 訓(xùn)練期間計(jì)算 Loss Recovered 的計(jì)算成本非常高。相反,我們的訓(xùn)練損失由一個(gè) L1 懲罰項(xiàng)和重建內(nèi)部激活的準(zhǔn)確度決定,而非其對(duì)下游損失的影響。
訓(xùn)練損失函數(shù)并不與代理指標(biāo)直接對(duì)應(yīng),并且代理指標(biāo)只是對(duì)特征可解釋性的主觀評(píng)估的代理。由于我們的真正目標(biāo)是「了解模型的工作方式」,主觀可解釋性評(píng)估只是代理,因此還會(huì)有另一層不匹配。LLM 中的一些重要概念可能并不容易解釋?zhuān)椅覀兛赡軙?huì)在盲目?jī)?yōu)化可解釋性時(shí)忽視這些概念。
總結(jié)
可解釋性領(lǐng)域還有很長(zhǎng)的路要走,但 SAE 是真正的進(jìn)步。SAE 能實(shí)現(xiàn)有趣的新應(yīng)用,比如一種用于查找「金門(mén)大橋」導(dǎo)向向量(steering vector)這樣的導(dǎo)向向量的無(wú)監(jiān)督方法。SAE 也能幫助我們更輕松地查找語(yǔ)言模型中的回路,這或可用于移除模型內(nèi)部不必要的偏置。
SAE 能找到可解釋的特征(即便目標(biāo)僅僅是識(shí)別激活中的模式),這一事實(shí)說(shuō)明它們能夠揭示一些有意義的東西。還有證據(jù)表明 LLM 確實(shí)能學(xué)習(xí)到一些有意義的東西,而不僅僅是記憶表層的統(tǒng)計(jì)規(guī)律。
SAE 也能代表 Anthropic 等公司曾引以為目標(biāo)的早期里程碑,即「用于機(jī)器學(xué)習(xí)模型的 MRI(磁共振成像)」。SAE 目前還不能提供完美的理解能力,但卻可用于檢測(cè)不良行為。SAE 和 SAE 評(píng)估的主要挑戰(zhàn)并非不可克服,并且現(xiàn)在已有很多研究者在攻堅(jiān)這一課題。
有關(guān)稀疏自編碼器的進(jìn)一步介紹,可參閱 Callum McDougal 的 Colab 筆記本:https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab