訓(xùn)練的神經(jīng)網(wǎng)絡(luò)不工作?一文帶你跨過這37個(gè)坑
近日,Slav Ivanov 在 Medium 上發(fā)表了一篇題為《37 Reasons why your Neural Network is not working》的文章,從四個(gè)方面(數(shù)據(jù)集、數(shù)據(jù)歸一化/增強(qiáng)、實(shí)現(xiàn)、訓(xùn)練),對自己長久以來的神經(jīng)網(wǎng)絡(luò)調(diào)試經(jīng)驗(yàn)做了 37 條總結(jié),并穿插了不少出色的個(gè)人想法和思考,希望能幫助你跨過神經(jīng)網(wǎng)絡(luò)訓(xùn)練中的 37 個(gè)大坑。機(jī)器之心對該文進(jìn)行了編譯。
神經(jīng)網(wǎng)絡(luò)已經(jīng)持續(xù)訓(xùn)練了 12 個(gè)小時(shí)。它看起來很好:梯度在變化,損失也在下降。但是預(yù)測結(jié)果出來了:全部都是零值,全部都是背景,什么也檢測不到。我質(zhì)問我的計(jì)算機(jī):「我做錯(cuò)了什么?」,它卻無法回答。
如果你的模型正在輸出垃圾(比如預(yù)測所有輸出的平均值,或者它的精確度真的很低),那么你從哪里開始檢查呢?
無法訓(xùn)練神經(jīng)網(wǎng)絡(luò)的原因有很多,因此通過總結(jié)諸多調(diào)試,作者發(fā)現(xiàn)有一些檢查是經(jīng)常做的。這張列表匯總了作者的經(jīng)驗(yàn)以及***的想法,希望也對讀者有所幫助。
一、數(shù)據(jù)集問題
1. 檢查你的輸入數(shù)據(jù)
檢查饋送到網(wǎng)絡(luò)的輸入數(shù)據(jù)是否正確。例如,我不止一次混淆了圖像的寬度和高度。有時(shí),我錯(cuò)誤地令輸入數(shù)據(jù)全部為零,或者一遍遍地使用同一批數(shù)據(jù)執(zhí)行梯度下降。因此打印/顯示若干批量的輸入和目標(biāo)輸出,并確保它們正確。
2. 嘗試隨機(jī)輸入
嘗試傳遞隨機(jī)數(shù)而不是真實(shí)數(shù)據(jù),看看錯(cuò)誤的產(chǎn)生方式是否相同。如果是,說明在某些時(shí)候你的網(wǎng)絡(luò)把數(shù)據(jù)轉(zhuǎn)化為了垃圾。試著逐層調(diào)試,并查看出錯(cuò)的地方。
3. 檢查數(shù)據(jù)加載器
你的數(shù)據(jù)也許很好,但是讀取輸入數(shù)據(jù)到網(wǎng)絡(luò)的代碼可能有問題,所以我們應(yīng)該在所有操作之前打印***層的輸入并進(jìn)行檢查。
4. 確保輸入與輸出相關(guān)聯(lián)
檢查少許輸入樣本是否有正確的標(biāo)簽,同樣也確保 shuffling 輸入樣本同樣對輸出標(biāo)簽有效。
5. 輸入與輸出之間的關(guān)系是否太隨機(jī)?
相較于隨機(jī)的部分(可以認(rèn)為股票價(jià)格也是這種情況),輸入與輸出之間的非隨機(jī)部分也許太小,即輸入與輸出的關(guān)聯(lián)度太低。沒有一個(gè)統(tǒng)一的方法來檢測它,因?yàn)檫@要看數(shù)據(jù)的性質(zhì)。
6. 數(shù)據(jù)集中是否有太多的噪音?
我曾經(jīng)遇到過這種情況,當(dāng)我從一個(gè)食品網(wǎng)站抓取一個(gè)圖像數(shù)據(jù)集時(shí),錯(cuò)誤標(biāo)簽太多以至于網(wǎng)絡(luò)無法學(xué)習(xí)。手動檢查一些輸入樣本并查看標(biāo)簽是否大致正確。
7. Shuffle 數(shù)據(jù)集
如果你的數(shù)據(jù)集沒有被 shuffle,并且有特定的序列(按標(biāo)簽排序),這可能給學(xué)習(xí)帶來不利影響。你可以 shuffle 數(shù)據(jù)集來避免它,并確保輸入和標(biāo)簽都被重新排列。
8. 減少類別失衡
一張類別 B 圖像和 1000 張類別 A 圖像?如果是這種情況,那么你也許需要平衡你的損失函數(shù)或者嘗試其他解決類別失衡的方法。
9. 你有足夠的訓(xùn)練實(shí)例嗎?
如果你在從頭開始訓(xùn)練一個(gè)網(wǎng)絡(luò)(即不是調(diào)試),你很可能需要大量數(shù)據(jù)。對于圖像分類,每個(gè)類別你需要 1000 張圖像甚至更多。
10. 確保你采用的批量數(shù)據(jù)不是單一標(biāo)簽
這可能發(fā)生在排序數(shù)據(jù)集中(即前 10000 個(gè)樣本屬于同一個(gè)分類)。可通過 shuffle 數(shù)據(jù)集輕松修復(fù)。
11. 縮減批量大小
巨大的批量大小會降低模型的泛化能力(參閱:https://arxiv.org/abs/1609.04836)
二、數(shù)據(jù)歸一化/增強(qiáng)
12. 歸一化特征
你的輸入已經(jīng)歸一化到零均值和單位方差了嗎?
13. 你是否應(yīng)用了過量的數(shù)據(jù)增強(qiáng)?
數(shù)據(jù)增強(qiáng)有正則化效果(regularizing effect)。過量的數(shù)據(jù)增強(qiáng),加上其它形式的正則化(權(quán)重 L2,中途退出效應(yīng)等)可能會導(dǎo)致網(wǎng)絡(luò)欠擬合(underfit)。
14. 檢查你的預(yù)訓(xùn)練模型的預(yù)處理過程
如果你正在使用一個(gè)已經(jīng)預(yù)訓(xùn)練過的模型,確保你現(xiàn)在正在使用的歸一化和預(yù)處理與之前訓(xùn)練模型時(shí)的情況相同。例如,一個(gè)圖像像素應(yīng)該在 [0, 1],[-1, 1] 或 [0, 255] 的范圍內(nèi)嗎?
15. 檢查訓(xùn)練、驗(yàn)證、測試集的預(yù)處理
CS231n 指出了一個(gè)常見的陷阱:「任何預(yù)處理數(shù)據(jù)(例如數(shù)據(jù)均值)必須只在訓(xùn)練數(shù)據(jù)上進(jìn)行計(jì)算,然后再應(yīng)用到驗(yàn)證、測試數(shù)據(jù)中。例如計(jì)算均值,然后在整個(gè)數(shù)據(jù)集的每個(gè)圖像中都減去它,再把數(shù)據(jù)分發(fā)進(jìn)訓(xùn)練、驗(yàn)證、測試集中,這是一個(gè)典型的錯(cuò)誤?!勾送猓诿恳粋€(gè)樣本或批量(batch)中檢查不同的預(yù)處理。
三、實(shí)現(xiàn)的問題
16. 試著解決某一問題的更簡易的版本。
這將會有助于找到問題的根源究竟在哪里。例如,如果目標(biāo)輸出是一個(gè)物體類別和坐標(biāo),那就試著把預(yù)測結(jié)果僅限制在物體類別當(dāng)中(嘗試去掉坐標(biāo))。
17.「碰巧」尋找正確的損失
還是來源于 CS231n 的技巧:用小參數(shù)進(jìn)行初始化,不使用正則化。例如,如果我們有 10 個(gè)類別,「碰巧」就意味著我們將會在 10% 的時(shí)間里得到正確類別,Softmax 損失是正確類別的負(fù) log 概率: -ln(0.1) = 2.302。然后,試著增加正則化的強(qiáng)度,這樣應(yīng)該會增加損失。
18. 檢查你的損失函數(shù)
如果你執(zhí)行的是你自己的損失函數(shù),那么就要檢查錯(cuò)誤,并且添加單元測試。通常情況下,損失可能會有些不正確,并且損害網(wǎng)絡(luò)的性能表現(xiàn)。
19. 核實(shí)損失輸入
如果你正在使用的是框架提供的損失函數(shù),那么要確保你傳遞給它的東西是它所期望的。例如,在 PyTorch 中,我會混淆 NLLLoss 和 CrossEntropyLoss,因?yàn)橐粋€(gè)需要 softmax 輸入,而另一個(gè)不需要。
20. 調(diào)整損失權(quán)重
如果你的損失由幾個(gè)更小的損失函數(shù)組成,那么確保它們每一個(gè)的相應(yīng)幅值都是正確的。這可能會涉及到測試損失權(quán)重的不同組合。
21. 監(jiān)控其它指標(biāo)
有時(shí)損失并不是衡量你的網(wǎng)絡(luò)是否被正確訓(xùn)練的***預(yù)測器。如果可以的話,使用其它指標(biāo)來幫助你,比如精度。
22. 測試任意的自定義層
你自己在網(wǎng)絡(luò)中實(shí)現(xiàn)過任意層嗎?檢查并且復(fù)核以確保它們的運(yùn)行符合預(yù)期。
23. 檢查「冷凍」層或變量
檢查你是否無意中阻止了一些層或變量的梯度更新,這些層或變量本來應(yīng)該是可學(xué)的。
24. 擴(kuò)大網(wǎng)絡(luò)規(guī)模
可能你的網(wǎng)絡(luò)的表現(xiàn)力不足以采集目標(biāo)函數(shù)。試著加入更多的層,或在全連層中增加更多的隱藏單元。
25. 檢查隱維度誤差
如果你的輸入看上去像(k,H,W)= (64, 64, 64),那么很容易錯(cuò)過與錯(cuò)誤維度相關(guān)的誤差。給輸入維度使用一些「奇怪」的數(shù)值(例如,每一個(gè)維度使用不同的質(zhì)數(shù)),并且檢查它們是如何通過網(wǎng)絡(luò)傳播的。
26. 探索梯度檢查(Gradient checking)
如果你手動實(shí)現(xiàn)梯度下降,梯度檢查會確保你的反向傳播(backpropagation)能像預(yù)期中一樣工作。
四、訓(xùn)練問題
27. 一個(gè)真正小的數(shù)據(jù)集
過擬合數(shù)據(jù)的一個(gè)小子集,并確保其工作。例如,僅使用 1 或 2 個(gè)實(shí)例訓(xùn)練,并查看你的網(wǎng)絡(luò)是否學(xué)習(xí)了區(qū)分它們。然后再訓(xùn)練每個(gè)分類的更多實(shí)例。
28. 檢查權(quán)重初始化
如果不確定,請使用 Xavier 或 He 初始化。同樣,初始化也許會給你帶來壞的局部最小值,因此嘗試不同的初始化,看看是否有效。
29. 改變你的超參數(shù)
或許你正在使用一個(gè)很糟糕的超參數(shù)集。如果可行,嘗試一下網(wǎng)格搜索。
30. 減少正則化
太多的正則化可致使網(wǎng)絡(luò)嚴(yán)重地欠擬合。減少正則化,比如 dropout、批規(guī)范、權(quán)重/偏差 L2 正則化等。在優(yōu)秀課程《編程人員的深度學(xué)習(xí)實(shí)戰(zhàn)》(http://course.fast.ai)中,Jeremy Howard 建議首先解決欠擬合。這意味著你充分地過擬合數(shù)據(jù),并且只有在那時(shí)處理過擬合。
31. 給它一些時(shí)間
也許你的網(wǎng)絡(luò)需要更多的時(shí)間來訓(xùn)練,在它能做出有意義的預(yù)測之前。如果你的損失在穩(wěn)步下降,那就再多訓(xùn)練一會兒。
32. 從訓(xùn)練模式轉(zhuǎn)換為測試模式
一些框架的層很像批規(guī)范、Dropout,而其他的層在訓(xùn)練和測試時(shí)表現(xiàn)并不同。轉(zhuǎn)換到適當(dāng)?shù)哪J接兄诰W(wǎng)絡(luò)更好地預(yù)測。
33. 可視化訓(xùn)練
- 監(jiān)督每一層的激活值、權(quán)重和更新。確保它們的大小匹配。例如,參數(shù)更新的大小(權(quán)重和偏差)應(yīng)該是 1-e3。
 - 考慮可視化庫,比如 Tensorboard 和 Crayon。緊要時(shí)你也可以打印權(quán)重/偏差/激活值。
 - 尋找平均值遠(yuǎn)大于 0 的層激活。嘗試批規(guī)范或者 ELUs。
 
Deeplearning4j 指出了權(quán)重和偏差柱狀圖中的期望值:對于權(quán)重,一些時(shí)間之后這些柱狀圖應(yīng)該有一個(gè)近似高斯的(正常)分布。對于偏差,這些柱狀圖通常會從 0 開始,并經(jīng)常以近似高斯(這種情況的一個(gè)例外是 LSTM)結(jié)束。留意那些向 +/- ***發(fā)散的參數(shù)。留意那些變得很大的偏差。這有時(shí)可能發(fā)生在分類的輸出層,如果類別的分布不均勻。
- 檢查層更新,它們應(yīng)該有一個(gè)高斯分布。
 
34. 嘗試不同的優(yōu)化器
優(yōu)化器的選擇不應(yīng)當(dāng)妨礙網(wǎng)絡(luò)的訓(xùn)練,除非你選擇了一個(gè)特別糟糕的參數(shù)。但是,為任務(wù)選擇一個(gè)合適的優(yōu)化器非常有助于在最短的時(shí)間內(nèi)獲得最多的訓(xùn)練。描述你正在使用的算法的論文應(yīng)當(dāng)指定優(yōu)化器;如果沒有,我傾向于選擇 Adam 或者帶有動量的樸素 SGD。
35. 梯度爆炸、梯度消失
- 檢查隱蔽層的***情況,過大的值可能代表梯度爆炸。這時(shí),梯度截?cái)?Gradient clipping)可能會有所幫助。
 - 檢查隱蔽層的激活值。Deeplearning4j 中有一個(gè)很好的指導(dǎo)方針:「一個(gè)好的激活值標(biāo)準(zhǔn)差大約在 0.5 到 2.0 之間。明顯超過這一范圍可能就代表著激活值消失或爆炸?!?/li>
 
36. 增加、減少學(xué)習(xí)速率
- 低學(xué)習(xí)速率將會導(dǎo)致你的模型收斂很慢;
 - 高學(xué)習(xí)速率將會在開始階段減少你的損失,但是可能會導(dǎo)致你很難找到一個(gè)好的解決方案。
 - 試著把你當(dāng)前的學(xué)習(xí)速率乘以 0.1 或 10。
 
37. 克服 NaNs
據(jù)我所知,在訓(xùn)練 RNNs 時(shí)得到 NaN(Non-a-Number)是一個(gè)很大的問題。一些解決它的方法:
- 減小學(xué)習(xí)速率,尤其是如果你在前 100 次迭代中就得到了 NaNs。
 - NaNs 的出現(xiàn)可能是由于用零作了除數(shù),或用零或負(fù)數(shù)作了自然對數(shù)。
 - Russell Stewart 對如何處理 NaNs 很有心得(http://russellsstewart.com/notes/0.html)。
 - 嘗試逐層評估你的網(wǎng)絡(luò),這樣就會看見 NaNs 到底出現(xiàn)在了哪里。
 
原文:https://medium.com/@slavivanov/4020854bd607
【本文是51CTO專欄機(jī)構(gòu)“機(jī)器之心”的原創(chuàng)譯文,微信公眾號“機(jī)器之心( id: almosthuman2014)”】
















 
 
 














 
 
 
 