RNN 從理論到PyTorch
讓我向您展示什么是RNN,在哪里使用,它們?nèi)绾蜗蚯昂拖蚝髠鞑ヒ约叭绾卧赑yTorch中使用它們。
大多數(shù)類型的神經(jīng)網(wǎng)絡(luò)都可以對(duì)要對(duì)其進(jìn)行訓(xùn)練的樣本進(jìn)行預(yù)測(cè)。一個(gè)主要的例子是MNIST數(shù)據(jù)集。像MLP這樣的常規(guī)神經(jīng)網(wǎng)絡(luò)知道有10位數(shù)字,即使圖像與訓(xùn)練網(wǎng)絡(luò)上的圖像非常不同,它也僅基于它們進(jìn)行預(yù)測(cè)。
現(xiàn)在,假設(shè)我們可以通過(guò)提供9個(gè)有序數(shù)字的序列,并讓網(wǎng)絡(luò)猜測(cè)第10個(gè)數(shù)字,來(lái)利用這種網(wǎng)絡(luò)進(jìn)行順序分析。網(wǎng)絡(luò)不僅會(huì)知道如何區(qū)分10位數(shù)字,而且還會(huì)知道從0到8的順序,下一位數(shù)字很可能是9。
在分析序列數(shù)據(jù)時(shí),我們了解到,序列中的元素通常以某種方式相關(guān),這意味著它們彼此依賴。因此,我們需要考慮每個(gè)元素以了解序列的想法。
劍橋大學(xué)出版社將序列定義為"事物或事件彼此跟隨的順序",或者最重要的是,"一系列相關(guān)事物或事件"。為了將此定義調(diào)整到深度學(xué)習(xí)的范圍內(nèi),順序是一組包含可訓(xùn)練上下文的數(shù)據(jù),刪除一些元素可能會(huì)使它無(wú)用。
但是序列包含什么?哪些分組數(shù)據(jù)可以具有上下文?以及如何提取上下文來(lái)利用神經(jīng)網(wǎng)絡(luò)的力量?在進(jìn)入神經(jīng)網(wǎng)絡(luò)本身之前,讓我向您展示使用遞歸神經(jīng)網(wǎng)絡(luò)(RNN)經(jīng)常解決的兩種類型的問(wèn)題。
時(shí)間序列預(yù)測(cè)
第一個(gè)示例是時(shí)間序列預(yù)測(cè)問(wèn)題,其中我們用一系列現(xiàn)有數(shù)值(藍(lán)色)訓(xùn)練神經(jīng)網(wǎng)絡(luò),以便預(yù)測(cè)未來(lái)的時(shí)間步長(zhǎng)(紅色)。
如果我們按照這些家庭多年來(lái)的每月精力支出進(jìn)行排序,我們可以看到正弦曲線趨勢(shì)呈上升趨勢(shì),而突然下降。
正弦曲線部分的背景可能是整個(gè)夏季(從夏季到冬季)再到夏季的不同能量需求。精力充沛的支出增長(zhǎng)可能來(lái)自使用更多的電器和設(shè)備,或者轉(zhuǎn)換為可能需要更多能源的更強(qiáng)大的電器和設(shè)備。突然跌倒的背景可能意味著一個(gè)人長(zhǎng)大了足以離開(kāi)家,而那個(gè)人所需的能量不再在那里。
您越了解上下文,通常可以通過(guò)連接輸入向量將更多信息提供給網(wǎng)絡(luò),以幫助網(wǎng)絡(luò)理解數(shù)據(jù)。在這種情況下,對(duì)于每個(gè)月,我們可以將三個(gè)更多的值與能源聯(lián)系起來(lái),這些價(jià)值包括電器和設(shè)備的數(shù)量,其能量效率以及家庭容納的人數(shù)。
自然語(yǔ)言處理
瑪麗騎自行車,自行車是____。
第二個(gè)例子是自然語(yǔ)言處理問(wèn)題。這也是一個(gè)很好的例子,因?yàn)樯窠?jīng)網(wǎng)絡(luò)必須考慮現(xiàn)有句子提供的上下文來(lái)完成它。
假設(shè)我們的網(wǎng)絡(luò)經(jīng)過(guò)訓(xùn)練,可以用所有格代詞完成句子。一個(gè)受過(guò)良好訓(xùn)練的網(wǎng)絡(luò)將理解該句子是用第三人稱單數(shù)構(gòu)成的,并且Mary最有可能是女性名字。因此,預(yù)測(cè)代詞應(yīng)該是"她的"而不是男性的"他的"或復(fù)數(shù)的"他們的"。
現(xiàn)在,我們已經(jīng)看到了兩個(gè)排序數(shù)據(jù)的例子,讓我們探索網(wǎng)絡(luò)向前和向后傳播的過(guò)程。
RNN配置
如我們所見(jiàn),RNN從序列中提取信息以提高其預(yù)測(cè)能力。
> Simple recurrent network diagram. Figure by author.
上面顯示了一個(gè)簡(jiǎn)單的RNN圖。綠色節(jié)點(diǎn)輸入一些輸入x ^ t并輸出一些值h ^ t,該值也被饋送到該節(jié)點(diǎn),再次包含從輸入中收集的信息。不管饋入節(jié)點(diǎn)的內(nèi)容有什么模式,它都會(huì)學(xué)習(xí)并保留該信息以供下一次輸入。上標(biāo)t代表時(shí)間步長(zhǎng)。
> Recurrent network configurations. Figure by author.
根據(jù)輸入或輸出的形狀,神經(jīng)網(wǎng)絡(luò)的配置會(huì)有一些變化,稍后我們將了解節(jié)點(diǎn)內(nèi)部會(huì)發(fā)生什么。
多對(duì)一配置是指我們以不同的時(shí)間步長(zhǎng)輸入多個(gè)輸入以獲得一個(gè)輸出時(shí),這可能是在電影場(chǎng)景的各個(gè)幀中捕獲的情感分析。
一對(duì)多使用一個(gè)輸入來(lái)獲取多個(gè)輸出。例如,我們可以使用多對(duì)一配置對(duì)表達(dá)某種情感的詩(shī)歌進(jìn)行編碼,并使用一對(duì)多配置來(lái)創(chuàng)建具有相同情感的新詩(shī)行。
多對(duì)多使用多個(gè)輸入來(lái)獲取多個(gè)輸出,例如使用一系列值(例如在能量使用中)并預(yù)測(cè)未來(lái)的十二個(gè)月而不是一個(gè)月。
堆疊配置只是一個(gè)具有多個(gè)隱藏節(jié)點(diǎn)層的網(wǎng)絡(luò)。
RNN前傳
為了了解神經(jīng)網(wǎng)絡(luò)節(jié)點(diǎn)內(nèi)部發(fā)生的情況,我們將使用一個(gè)簡(jiǎn)單的數(shù)據(jù)集作為"時(shí)間序列預(yù)測(cè)"示例。貝婁是價(jià)值的完整序列,它是作為培訓(xùn)和測(cè)試數(shù)據(jù)集而進(jìn)行的重組。
我從這個(gè)網(wǎng)站上拿了這個(gè)例子,這是一般而言深度學(xué)習(xí)的重要資源。現(xiàn)在,讓我們將數(shù)據(jù)集分成批次。
我在這里沒(méi)有顯示它,但不要忘記應(yīng)該對(duì)數(shù)據(jù)集進(jìn)行規(guī)范化。這很重要,因?yàn)樯窠?jīng)網(wǎng)絡(luò)對(duì)數(shù)據(jù)集值的大小很敏感。
這個(gè)想法是預(yù)測(cè)未來(lái)的價(jià)值。因此,假設(shè)我們選擇了該批次的第一行:[10 20 30],在訓(xùn)練了我們的網(wǎng)絡(luò)之后,我們應(yīng)該得到40的值。要測(cè)試神經(jīng)網(wǎng)絡(luò),可以將向量輸入[70 80 90],并期望獲得一個(gè)如果網(wǎng)絡(luò)訓(xùn)練有素,則值接近100。
我們將使用多對(duì)一配置,分別提供每個(gè)序列的三個(gè)時(shí)間步。當(dāng)使用遞歸網(wǎng)絡(luò)時(shí),輸入值不是進(jìn)入網(wǎng)絡(luò)的唯一值,還有一個(gè)隱藏的數(shù)組,該數(shù)組的結(jié)構(gòu)將在節(jié)點(diǎn)之間傳遞序列的上下文。我們將其初始化為零數(shù)組,并將其連接到輸入。它的尺寸(1 x 2)是個(gè)人選擇,只是使用與1 x 1步進(jìn)輸入不同的尺寸。
> Recurrent forward pass. Figure by author.
仔細(xì)觀察,我們可以看到權(quán)重矩陣分為兩部分。第一個(gè)處理輸入創(chuàng)建兩個(gè)輸出,第二個(gè)處理隱藏?cái)?shù)組創(chuàng)建兩個(gè)輸出。然后將這兩組輸出加在一起,并獲得一個(gè)新的隱藏?cái)?shù)組,其中包含來(lái)自第一個(gè)輸入(10)的信息,并將其饋送到下一個(gè)時(shí)間步輸入(20)。應(yīng)當(dāng)注意,權(quán)重和偏差矩陣在時(shí)間步長(zhǎng)之間是相同的。
上面表示了全局輸入向量X ^ t,權(quán)重矩陣W和偏置矩陣B以及隱藏?cái)?shù)組的計(jì)算。僅剩一個(gè)步驟才能完成前進(jìn)。我們正在嘗試預(yù)測(cè)一個(gè)未來(lái)的價(jià)值,我們有三個(gè)隱藏的數(shù)組,每個(gè)輸入的信息作為輸出,因此我們需要將它們轉(zhuǎn)換為單個(gè)值,希望在經(jīng)過(guò)許多次培訓(xùn)后才是正確的值。
通過(guò)連接并重塑數(shù)組,我們可以附加一個(gè)線性層來(lái)計(jì)算最終結(jié)果。完整的網(wǎng)絡(luò)具有以下形式:
> Full recurrent diagram. Figure by author.
您可以看到多對(duì)一配置嗎?我們從一個(gè)序列中饋入三個(gè)輸入,它們的上下文由權(quán)重和偏差矩陣捕獲,并存儲(chǔ)在一個(gè)隱藏的數(shù)組中,該數(shù)組在每個(gè)時(shí)間步都用新信息更新。最終,存儲(chǔ)在隱藏?cái)?shù)組中的上下文將經(jīng)歷另一組權(quán)重和偏差,并且在將序列的所有時(shí)間步長(zhǎng)輸入到網(wǎng)絡(luò)之后,將輸出一個(gè)值。
我們可以看到隱藏狀態(tài)的線性形式以及線性層的權(quán)重和偏差矩陣,以及預(yù)測(cè)值的計(jì)算(y hat)。
現(xiàn)在,這是RNN的前向傳播,但我們?nèi)匀粵](méi)有看到向后傳遞。
RNN向后傳遞
向后傳播是訓(xùn)練每個(gè)神經(jīng)網(wǎng)絡(luò)的非常重要的一步。在此,預(yù)測(cè)輸出和實(shí)際值之間的誤差朝著神經(jīng)網(wǎng)絡(luò)傳播,目的是改善權(quán)重和偏差,以便每次迭代都能獲得更好的預(yù)測(cè)。
在大多數(shù)情況下,此步驟由于其復(fù)雜性而被忽略了。在提及重要內(nèi)容的同時(shí),我將向您提供一個(gè)盡可能簡(jiǎn)單的解釋。
向后傳遞是使用微積分的鏈法則從損耗到所有權(quán)重和偏差參數(shù)的一系列推導(dǎo)。這意味著我們最終需要以下值(如果是多維的,則為數(shù)組):
如果您不熟悉一階導(dǎo)數(shù)的數(shù)學(xué)含義,但實(shí)際上,當(dāng)一階導(dǎo)數(shù)為零時(shí),我通常建議您閱讀有關(guān)梯度下降的文章,這通常意味著我們?cè)谙到y(tǒng)中找到了一個(gè)最小值,并且理想情況下我們將無(wú)法進(jìn)一步改善它。
這里需要注意的一點(diǎn)是:零也可能是最大值,它是不穩(wěn)定的,不應(yīng)在那里進(jìn)行優(yōu)化,或者是鞍點(diǎn),其本身也不是很穩(wěn)定。最小值可以是全局值(函數(shù)的最小值)或局部值。這對(duì)我的解釋并不重要,但是如果您想了解更多信息,可以查找一下!
> Gradient Descent example. Two balls rolling down the hill. Figure by author.
您在圖片中看到的是兩個(gè)球從山谷中滾下來(lái)。從視覺(jué)上看,一階導(dǎo)數(shù)給了我們山坡的大小。如果我們沿W軸增加的方向(從左到右)行進(jìn),則對(duì)綠色球的傾斜度為負(fù)(向下),對(duì)于紅色球的傾斜度為正(向上)。
仔細(xì)閱讀下一段,然后根據(jù)需要返回到該圖。
如果我們希望損失最小,我們希望球到達(dá)山谷的最低點(diǎn)。W代表權(quán)重和偏差的值,因此,如果我們處于綠色球的位置,我們將減去負(fù)導(dǎo)數(shù)的一部分(使其為正值)到綠色球的W位置,將其向右移動(dòng)并減去一部分將紅色球的W位置向左移動(dòng)的正導(dǎo)數(shù)(使其變?yōu)樨?fù)數(shù)),以使兩個(gè)球都接近最小值。
從數(shù)學(xué)上講,我們有以下內(nèi)容:
η調(diào)整我們用來(lái)更新權(quán)重和偏差的導(dǎo)數(shù)的比例。
現(xiàn)在,繼續(xù)進(jìn)行向后傳遞問(wèn)題。我將介紹從損失到所有參數(shù)的鏈?zhǔn)綄?dǎo)數(shù),我們將看到每種導(dǎo)數(shù)代表什么。重要的是要牢記上面介紹的各層的方程式以及它們的參數(shù)矩陣。
要記住的一件事是,我們要查找的四個(gè)一階導(dǎo)數(shù)數(shù)組的形狀必須與我們要更新的參數(shù)相同。例如,陣列dL / dW_h的形狀必須與權(quán)重陣列W_h相同。上標(biāo)T表示矩陣已轉(zhuǎn)置。
我們一直追溯到線性層的參數(shù)。因?yàn)槲覀儗㈦[藏狀態(tài)數(shù)組重塑為線性向量,所以我們應(yīng)將dL / dH ^ t重塑為串聯(lián)的隱藏狀態(tài)數(shù)組的原始形狀。目前,它是一個(gè)6 x 1的數(shù)組,但從循環(huán)圖層計(jì)算得出的隱藏?cái)?shù)組的形狀是3 x 2。我們還將所有全局輸入連接在一起(t = 1、2和3),現(xiàn)在我們可以繼續(xù)進(jìn)行反向傳遞了。
現(xiàn)在剩下要做的就是應(yīng)用我們之前看到的Gradient Descent方程來(lái)更新參數(shù),并且模型可以進(jìn)行下一次迭代了。讓我們看看如何使用PyTorch構(gòu)建簡(jiǎn)單的RNN。
PyTorch的RNN
使用PyTorch非常簡(jiǎn)單,因?yàn)槲覀冋娴牟恍枰獡?dān)心向后傳遞。但是,即使我們不直接使用它,我仍然相信了解它的工作原理很重要。
繼續(xù),如果我們參考PyTorch的文檔,我們可以看到它們已經(jīng)具有可以使用的RNN對(duì)象。定義它時(shí),有兩個(gè)基本參數(shù):
- input_size —輸入x中預(yù)期要素的數(shù)量
- hidden_size —處于隱藏狀態(tài)h的特征數(shù)
input_size為1,因?yàn)槲覀円淮问褂妹總€(gè)序列的一個(gè)時(shí)間步長(zhǎng)(例如,序列10、20、30中的10),而hidden_size為2,因?yàn)槲覀儷@得了包含兩個(gè)值的隱藏狀態(tài)。
將n_layers參數(shù)定義為2意味著我們有一個(gè)帶有兩個(gè)隱藏層的堆疊RNN。
另外,我們還將參數(shù)batch_first定義為True。這意味著輸入和輸出中的批次尺寸排在首位(輸入和輸出不要錯(cuò))
輸入:輸入,h_0
- 形狀的輸入(seq_len,batch,input_size):包含輸入序列特征的張量。
- h_0的形狀(num_layers * num_directions,batch,hidden_size):張量,包含批次中每個(gè)元素的初始隱藏狀態(tài)。
RNN的輸入應(yīng)該是形狀為1 x 3 x 1的輸入數(shù)組。該序列包含三個(gè)時(shí)間步長(zhǎng),分別是數(shù)據(jù)集的第一批10、20和30。從每個(gè)批次中,大小為1的輸入將作為該序列的三個(gè)時(shí)間步長(zhǎng)被饋送到網(wǎng)絡(luò)三遍。
隱藏狀態(tài)h_0是我們的第一個(gè)隱藏?cái)?shù)組,我們將其與形狀為1 x 1 x 2的第一時(shí)間步輸入一起饋入網(wǎng)絡(luò)。
輸出:輸出,h_n
- 形狀的輸出(seq_len,batch,num_directions * hidden_size):張量,包含每個(gè)t的RNN的最后一層的輸出特征(h_t)。
- h_n的形狀(num_layers * num_directions,batch,hidden_size):張量包含t = seq_len的隱藏狀態(tài)。
輸出包含形狀為1 x 3 x 2的每個(gè)時(shí)間步長(zhǎng)由神經(jīng)網(wǎng)絡(luò)計(jì)算的所有隱藏狀態(tài),h_n是最后一個(gè)時(shí)間步長(zhǎng)的隱藏狀態(tài)。這對(duì)于保持有用很有用,因?yàn)槿绻覀冞x擇使用堆疊式遞歸網(wǎng)絡(luò),則這將是隱藏狀態(tài),該狀態(tài)將在第一時(shí)間步進(jìn)給,形狀為1 x 1 x 2。
所有這些數(shù)組都在上面的示例中表示,并且可以在RNN圖中看到。需要注意的另一件事是,使用遞歸網(wǎng)絡(luò)和"時(shí)間序列預(yù)測(cè)"的特定示例,將num_directions設(shè)置為2將意味著預(yù)測(cè)未來(lái)和過(guò)去。此處將不考慮這種類型的配置。
我將在實(shí)現(xiàn)RNN以及如何對(duì)其進(jìn)行培訓(xùn)的過(guò)程中留下一段代碼。我還將將其留給您使用,以根據(jù)需要與所需的數(shù)據(jù)集一起使用。在使用網(wǎng)絡(luò)之前,請(qǐng)不要忘記規(guī)范化數(shù)據(jù)并創(chuàng)建數(shù)據(jù)集和數(shù)據(jù)加載器。
總結(jié)思想
為了以一個(gè)簡(jiǎn)短的總結(jié)來(lái)結(jié)束這個(gè)故事,我們首先看到了通常使用遞歸網(wǎng)絡(luò)解決的兩種類型的問(wèn)題,即時(shí)間序列預(yù)測(cè)和自然語(yǔ)言處理。
后來(lái),我們看到了一些典型配置的示例,以及一個(gè)實(shí)際示例,其目的是使用多對(duì)一配置預(yù)測(cè)未來(lái)的一步。
在前向傳遞中,我們了解了輸入和隱藏狀態(tài)如何與遞歸層的權(quán)重和偏差交互作用,以及如何使用隱藏狀態(tài)中包含的信息來(lái)預(yù)測(cè)下一個(gè)時(shí)間步長(zhǎng)值。
反向傳遞只是鏈規(guī)則的應(yīng)用,從損失梯度相對(duì)于預(yù)測(cè)的關(guān)系到相對(duì)于我們要優(yōu)化的參數(shù)的變化。
最后,我們?yōu)g覽了有關(guān)RNN的PyTorch文檔的一部分,并討論了用于構(gòu)建基本循環(huán)網(wǎng)絡(luò)的最重要部分。
感謝您的閱讀!也許您從這個(gè)冗長(zhǎng)的故事中得到了一些啟示。我寫它們是為了幫助我理解新概念,并希望也能幫助其他人。
原文鏈接:https://towardsdatascience.com/rnns-from-theory-to-pytorch-f0af30b610e1