Transformers基本原理—Decoder如何進行解碼?
一、Transformers整體架構(gòu)概述
Transformers 是一種基于自注意力機制的架構(gòu),最初在2017年由Vaswani等人在論文《Attention Is All You Need》中提出。這種架構(gòu)徹底改變了自然語言處理(NLP)領域,因為它能夠有效地處理序列數(shù)據(jù),并且能夠捕捉長距離依賴關系。
Transformers整體架構(gòu)如下:
主要架構(gòu)由左側(cè)的編碼器(Encoder)和右側(cè)的解碼器(Decoder)構(gòu)成。
對于Decoder主要做了一件什么事兒,可參考之前的文章。
本次我們主要來看解碼器如何工作。
二、編碼器的輸出
假設有這樣一個任務,我們要將德文翻譯為英文,這個就會使用到Encoder-Decoder整個結(jié)構(gòu)。
德文:ich mochte ein bier
英文:i want a beer
為了更好地模擬真實情況(每一句話長度不可能相同),我們需要將德文進行Padding,填充1個P符號;英文目標增加一個End,表示這句話翻譯完了,用E符號表示。那么這兩句話就變?yōu)?/span>
德文:ich mochte ein bier P
英文:i want a beer E
這個Encoder結(jié)束后,會給我們一個輸出向量,這個向量將會在Decoder結(jié)構(gòu)中會形成K和V進行使用。
Encoder的輸出可以當作真實序列的高級表達,或者說是對真實序列的高級拆分。
三、解碼器如何工作
在Transformer模型中,解碼器(Decoder)的主要作用是通過自回歸模型生成目標序列。
自回歸的含義是這樣的:
在預測 “東方紅,太陽升?!?的時候:
會優(yōu)先傳入 “東”,然后預測出來“方”,
接著傳入“東方”, 然后預測出來“東方紅”,
接著傳入“東方紅”,然后預測出來“,”,
接著再傳入“東方紅,”,然后預測出來“太”
...
直到最后預測出來整句話:“東方紅,太陽升。”。
所以,自回歸就是基于已出現(xiàn)的詞預測未來的詞
接下來我們先來搞明白解碼器具體做了一件什么事兒,忽略如何做的細節(jié)。
四、解碼器輸入
4.1 (Shifted)Outputs
解碼器的輸入基本上與編碼器輸入類似,如果忘記了可以去查看:Transformers基本原理—Encoder如何進行編碼(1)
4.1.1 如何理解(Shifted)Outputs
(Shifted)Outputs表示輸入序列詞的偏移。
對于Transformer結(jié)構(gòu)來說,我們的1組數(shù)據(jù)應包含3個:Encoder輸入是1個,Decoder輸入是1個,目標是1個。
比如:
Encoder輸入enc_inputs為德文 “ich mochte ein bier P”
Decoder輸入dec_inputs為英文 “S i want a beer”
優(yōu)化目標輸入target_inputs為英文 “i want a beer E”
對于Decoder來說,就是要通過“S”預測目標“i”,通過“S i”預測目標“want”,通過“S i want”預測“a”,通過“S i want a”預測“beer”,通過“S i want a beer”預測“E”。
S表示給序列一個開始預測的提示,E表示給序列一個停止預測的提示。
4.1.2 如何完成(Shifted)Outputs
這里的處理方法與Encoder是一樣,但是一定要注意輸入的是帶偏移的序列。
我們定義1個大詞表,它包含了輸入所有的詞的位置,如下所示:
tgt_vocab = {'P': 0, 'i': 1, 'want': 2, 'a': 3, 'beer': 4, 'S': 5, 'E': 6}
那對應于 英文 “S i want a beer” ,其編碼端的輸入就是:
dec_inputs: tensor([[5, 1, 2, 3, 4]])
接著我們將這5個詞通過nn.Embedding給映射到高維空間,形成詞向量,為了方便展示,我們映射到6維空間(一般映射到2的n次方維度,最高512維)。
此處的計算邏輯與Encoder一模一樣,不再贅述,直接附結(jié)果。
這樣我們就完成了輸入文本的詞嵌入。
4.2 位置嵌入
位置嵌入原理與Encoder中一模一樣,此處不再贅述,可參考:Transformers基本原理—Encoder如何進行編碼(1)?
這樣我們就完成了輸入文本的位置嵌入。
4.3 融合詞嵌入和位置編碼
由于兩個向量shape都是一樣的,因此直接相加即可,結(jié)果如下:
五、帶掩碼的多頭注意力機制
在Decoder的多頭注意力機制中,QKV的計算方式與Encoder中一模一樣,重點只需要關注掩碼的方式即可。
由于Encoder自回歸的原因,所以,此處的掩碼應遵循預測時不看到未來信息且不關注Padding符號的原則。
5.1 不看未來信息
以“S i want a beer”為例:
在通過“S”預測時,不應該注意后面的“i want a beer”;在通過“S i”預測時,不應該注意后面的“want a beer”。
因此這個注意力矩陣就應該遵循這個規(guī)律:將看不到的詞的注意力設置為0。
而這個我們恰巧可以通過上三角矩陣來完成,具體代碼如下:
def get_attn_subsequent_mask(seq):
# 取出傳進來的向量的三個維度,分別是batch_size,seq_length,d_model
attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
# 通過np.triu形成一個上三角矩陣,填充為1
subsequent_mask = np.triu(np.ones(attn_shape), k=1)
# 由array轉(zhuǎn)為torch的tensor格式
subsequent_mask = torch.from_numpy(subsequent_mask).byte()
return subsequent_mask
5.2 不看Padding的詞
這里的處理方式與Encoder一致,主要為了讓對Padding的注意力降為0.
以“S i want P P”為例,具體代碼如下:
# 需要判斷哪些位置是填充的
def get_attn_pad_mask(seq_q, seq_k): # enc_inputs, enc_inputs 告訴后面的句子后面的層 那些是被pad符號填充的 pad的目的是讓batch里面的每一行長度一致
# 獲取seq_q和seq_k的batch_size和長度
batch_size, len_q = seq_q.size()
batch_size, len_k = seq_k.size()
# 判斷輸入的位置編碼是否是0并升維度,0表明這個并不是有效的詞,只是填充或者標點等,后續(xù)計算自注意力時不關注這些詞
pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)
# batch_size x 1 x len_k(=len_q), one is maskingtensor([[[False, False, False, False, True]]]) True代表為pad符號
# 擴展mask矩陣,使其維度為batch_size x len_q x len_k
# (1,1,5) ->> (1,5,5)
return pad_attn_mask.expand(batch_size, len_q, len_k)
如果序列為“S i want a beer”為例,則掩碼矩陣如下:
5.3 合并的masked矩陣
顯然,我們在對序列進行編碼的時候,肯定希望這個序列既不能偷看未來信息,也不能關注Padding信息,因此,需要將兩個掩碼進行融),得到self-attention的mask。
具體做法為:兩個mask矩陣直接求和,大于0的地方說明要么是未來信息,要么是Padding信息。
“S i want P P”的合并mask矩陣就如下所示:
“S i want a beer”的合并mask矩陣就如下所示:
這樣,我們就得到了最終的mask矩陣,只需要將融合的mask矩陣的在計算注意力時填充為-1e10即可,這樣經(jīng)過Softmax之后得到的注意力就約等于0,不會去關注這部分信息。
5.4 多頭注意力機制
將【4.3 融合詞嵌入和位置編碼】的結(jié)果當作X輸入Decoder的多頭進行計算即可,只需要在計算自注意力的時候根據(jù)合并的masked矩陣進行掩碼填充,防止關注不需要關注的信息即可。
第1步:QKV的計算
X 為我們輸入的詞嵌入,假設此處的shape為(2,4)。
WQ、WK、WV矩陣為優(yōu)化目標,WQ、WK矩陣的維度一定相同,此處均為(4,3);WV矩陣維度可以相同也可以不同,此處為(4,3)。
Q、K、V經(jīng)過矩陣運算后,shape均為(2,3)。
第2步:計算最終輸出
1、Q矩陣與K矩陣的轉(zhuǎn)置相乘,得到詞與詞之間的注意力分數(shù)矩陣,位于(i,j)位置表示第i個詞與第j個詞之間的注意力分數(shù)。
矩陣運算:(2,3) * (3,2) ->> (2,2)
2、歸一化注意力分數(shù)后,矩陣shape不會改變,只改變內(nèi)部數(shù)值大小。【根據(jù)mask的填充就發(fā)生在Softmax前,確保不關注未來信息與padding信息】
矩陣大小:(2,2)
3、與V矩陣相乘,得到最終輸出Z矩陣。
矩陣運算:(2,2) * (2,3) ->> (2,3)
Z矩陣就是原始輸入經(jīng)過變換后的輸出,且整個過程中只需要優(yōu)化WQ、WK、WV矩陣,計算量大大減少。
這樣,我們就完成了Decoder的第一個多頭計算。
六、殘差連接和層歸一化
6.1 殘差連接
為了解決深層網(wǎng)絡退化問題,殘差連接是經(jīng)常被使用的,在代碼內(nèi)也就一行:
dec_outputs = output + residual
dec_outputs:最終輸出,判斷原始輸出好還是經(jīng)過多頭后的輸出好
output:經(jīng)過多頭后的輸出
residual:原始輸入
6.2 層歸一化
層歸一化只需要將最終輸出做一下LayerNorm即可,也是一行代碼即可完成。
dec_outputs = nn.LayerNorm(output + residual)
這樣,我們就完成了殘差連接與層歸一化,如圖所示:
需要注意的就是本層的輸出將會作為Q,傳遞給下一個多頭。
七、第二層多頭自注意力機制
這一層的數(shù)據(jù)來自于兩個位置:一個是Encoder的輸出,一個是Decoder的第一層多頭輸出。
7.1 Encoder的輸出
在編碼端,會產(chǎn)生一個對原始輸入的高級表達,假設變量名為:enc_outputs。
這個編碼端的輸出將會被當作K矩陣和V矩陣使用。
其中,enc_outputs既是K又是V,兩者完全一樣。
同時,還需要將Encoder的原始輸入傳遞給第二層多頭,需要根據(jù)這個原始輸入形成掩碼矩陣,讓解碼時關注K,以確保解碼器在生成每個詞時,只利用編碼器輸出中有效的、相關的信息。
7.2 Decoder第一層多頭輸出
Decoder第一層多頭輸出的結(jié)果就是對原始解碼輸入的抽象表達,會被當作Q輸入第二層多頭內(nèi)。
7.3 Encoder—Decoder邏輯
當Q給多頭后,會通過與Encoder的K進行點乘,得到與要查詢(Q)的注意力,判斷哪些內(nèi)容與Q類似,最后再與V計算,得到最終的輸出。
一個通俗的理解就是你買了一個變形金剛(transformer是變形金剛)的玩具,拿到的東西會有:各個零件、組裝說明書。
Encoder的結(jié)果就是提供了各個零件的組裝說明書;Decoder就是各個零件的查詢。
當我們想要將其組裝好時,會在Decoder內(nèi)查詢(Q)每個零件的用法,通過組裝說明書(Encoder的結(jié)果),尋找與當前零件最相似的說明(尋找高注意力的零件:Q · K^T),最后根據(jù)零件用法進行組裝(輸出最后結(jié)果)成一個變形金剛。
當然,這里的QKV的流程是人類賦予的,只是與現(xiàn)實中的數(shù)據(jù)庫查表過程類似,與組裝變形金剛類似,讓我們更加容易接受的一種說法。
個人對Encoder—Decoder的理解:
Encoder就是對原始輸入的高階抽象表達(低級表達的高級抽象),比如不同語言雖然表面上不一樣(各種各樣的寫法),但是在更高維的向量空間是相似的,存在一套大一統(tǒng)高階向量,只是我們無法用語言來描述。
就像有些數(shù)據(jù)在二維不可分,但是落在高維空間就可分了一樣,我們無法觀察更高維的空間,但它們是存在的。
Decoder就是當我們在高維空間查詢(Q)時,就能找到相似(注意力強弱)的高級抽象表達(Q · K^T),然后通過V輸出為低級表達,讓我們能看得懂。
正如作為三維生物的我們無法理解更高維空間,需要一個中介將高維空間轉(zhuǎn)化為低維空間表達;或者我們作為三維生物,如果二維生物可溝通,我們亦可當作中介將高維空間翻譯給低維生物,我們就是這個多頭機制,我們就是這個中介。
所以,Encoder—Decoder機制更像是一個大的翻譯,大的中介,形成了高維與低維通信的方法。
7.4 Decoder第二層多頭的掩碼機制
由于Q代表當前要生成的詞,而K代表編碼器中的詞,我們要去零件庫尋找與Q相似的零件,那么我們希望的肯定是這個零件是有效的,以確保解碼器在生成每個詞時,只利用編碼器輸出中有效的、相關的信息。
因此,我們在形成掩碼矩陣時,主要根據(jù)Encoder的輸入序列來形成掩碼矩陣,而不是Decoder輸入的序列,當然,Q肯定也有Padding,但這里我們不必過度關注。
所以,在第二層多頭掩碼時,我們更關注K的Padding掩碼,確保解碼得到的信息也是有效的。
具體代碼如下:
def get_attn_pad_mask(seq_q, seq_k):
batch_size, len_q = seq_q.size()
batch_size, len_k = seq_k.size()
# eq(zero) is PAD token
pad_attn_mask = seq_k.data.eq(0).unsqueeze(1) # batch_size x 1 x len_k(=len_q), one is masking
return pad_attn_mask.expand(batch_size, len_q, len_k) # batch_size x len_q x len_k
比如,輸入序列為:[1,2,3,4,0],最后一位是Padding,那么掩碼矩陣如下圖示例:
7.5 Decoder第二層多頭的輸出
至此,我們就將Q進行了解碼輸出,得到了與Q最為相似的表達。
如圖所示步驟:
八、殘差連接和層歸一化
接著就是再次殘差連接與層歸一化,與上述步驟完全一樣。(為了保持連貫性,還是再寫一遍)
8.1 殘差連接
為了解決深層網(wǎng)絡退化問題,殘差連接是經(jīng)常被使用的,在代碼內(nèi)也就一行:
dec_outputs = output + residual
dec_outputs:最終輸出,判斷原始輸出好還是經(jīng)過多頭后的輸出好
output:經(jīng)過多頭后的輸出
residual:原始輸入
8.2 層歸一化
層歸一化只需要將最終輸出做一下LayerNorm即可,也是一行代碼即可完成。
dec_outputs = nn.LayerNorm(output + residual)
這樣,我們就完成了殘差連接與層歸一化,如圖所示:
九、前饋神經(jīng)網(wǎng)絡
該前饋神經(jīng)網(wǎng)絡接受來自歸一化后的輸出,shape為(batch_size, seq_length, dmodel)。
batch_size:批處理大小
seq_length:序列長度,比如64個Token
dmodel:詞嵌入映射維度
在Transformer中,由于我們輸入的是序列數(shù)據(jù),因此Conv1d就被優(yōu)先考慮使用了,但是這里需要一定的前置知識(點擊可看Conv1d原理),才能懂得Transformer中FNN的工作原理。
9.1 Conv1d連接
由于在Pytorch內(nèi)Conv1d要求維度優(yōu)先,因此需要對輸入數(shù)據(jù)進行維度轉(zhuǎn)換。
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1,bias=False)
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1,bias=False)
# 維度優(yōu)先由6維升維到d_ff維
output = nn.ReLU()(self.conv1(inputs.transpose(1, 2))) # 這里因為一維卷積需要維度再前所以要交換特征位置
# 再降維到d_model維度,一般為詞嵌入維度
output = self.conv2(output).transpose(1, 2)
假設輸入的數(shù)據(jù)為(1, 5 , 6),d_ff為2048,那么整個過程的數(shù)據(jù)shape流動就為:
(1,5,6) ->轉(zhuǎn)置>> (1,6,5)->卷積>> (1,2048,5)->卷積>>(1,6,5)->轉(zhuǎn)置>>(1,5,6)
經(jīng)過調(diào)整提取之后,輸出shape仍為(1,5,6)。
至此,前饋神經(jīng)網(wǎng)絡完成。
十、殘差連接和層歸一化
接著就是再次殘差連接與層歸一化,與上述步驟完全一樣。(為了保持連貫性,還是再寫一遍)
10.1 殘差連接
為了解決深層網(wǎng)絡退化問題,殘差連接是經(jīng)常被使用的,在代碼內(nèi)也就一行:
dec_outputs = output + residual
dec_outputs:最終輸出,判斷原始輸出好還是經(jīng)過多頭后的輸出好
output:經(jīng)過多頭后的輸出
residual:原始輸入
10.2 層歸一化
層歸一化只需要將最終輸出做一下LayerNorm即可,也是一行代碼即可完成。
dec_outputs = nn.LayerNorm(output + residual)
這樣,我們就完成了殘差連接與層歸一化,如圖所示:
十一、線性層及Softmax
在得到歸一化的輸出之后,經(jīng)過線性層將輸出維度映射為指定類別,比如,你的目標詞表有1000個詞,那么就會通過softmax預測輸出的這1000個詞哪個詞概率最高,最高的即是目標輸出。
代碼如下:
self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)
dec_logits = self.projection(dec_outputs).view(-1, dec_logits.size(-1))
可以看到,經(jīng)過50個epoch,是能夠?qū)⒌挛某晒Ψg為英文的,至此,我們已經(jīng)將Transformer剖析完成。
圖片