用 Java 訓練深度學習模型,原來可以這么簡單!
本文適合有 Java 基礎(chǔ)的人群
HelloGitHub 推出的 《講解開源項目》 系列。這一期是由亞馬遜工程師: Keerthan Vasist ,為我們講解 DJL(完全由 Java 構(gòu)建的深度學習平臺)系列的第 4 篇。
一、前言
很長時間以來,Java 都是一個很受企業(yè)歡迎的編程語言。得益于豐富的生態(tài)以及完善維護的包和框架,Java 擁有著龐大的開發(fā)者社區(qū)。盡管深度學習應用的不斷演進和落地,提供給 Java 開發(fā)者的框架和庫卻十分短缺。現(xiàn)今主要流行的深度學習模型都是用 Python 編譯和訓練的。對于 Java 開發(fā)者而言,如果要進軍深度學習界,就需要重新學習并接受一門新的編程語言同時還要學習深度學習的復雜知識。這使得大部分 Java 開發(fā)者學習和轉(zhuǎn)型深度學習開發(fā)變得困難重重。
為了減少 Java 開發(fā)者學習深度學習的成本,AWS 構(gòu)建了 Deep Java Library (DJL),一個為 Java 開發(fā)者定制的開源深度學習框架。它為 Java 開發(fā)者對接主流深度學習框架提供了一個橋梁。

在這篇文章中,我們會嘗試用 DJL 構(gòu)建一個深度學習模型并用它訓練 MNIST 手寫數(shù)字識別任務(wù)。
二、什么是深度學習?
在我們正式開始之前,我們先來了解一下機器學習和深度學習的基本概念。
機器學習是一個通過利用統(tǒng)計學知識,將數(shù)據(jù)輸入到計算機中進行訓練并完成特定目標任務(wù)的過程。這種歸納學習的方法可以讓計算機學習一些特征并進行一系列復雜的任務(wù),比如識別照片中的物體。由于需要寫復雜的邏輯以及測量標準,這些任務(wù)在傳統(tǒng)計算科學領(lǐng)域中很難實現(xiàn)。
深度學習是機器學習的一個分支,主要側(cè)重于對于人工神經(jīng)網(wǎng)絡(luò)的開發(fā)。人工神經(jīng)網(wǎng)絡(luò)是通過研究人腦如何學習和實現(xiàn)目標的過程中歸納而得出一套計算邏輯。它通過模擬部分人腦神經(jīng)間信息傳遞的過程,從而實現(xiàn)各類復雜的任務(wù)。深度學習中的“深度”來源于我們會在人工神經(jīng)網(wǎng)絡(luò)中編織構(gòu)建出許多層(layer)從而進一步對數(shù)據(jù)信息進行更深層的傳導。深度學習技術(shù)應用范圍十分廣泛,現(xiàn)在被用來做目標檢測、動作識別、機器翻譯、語意分析等各類現(xiàn)實應用中。
三、訓練 MNIST 手寫數(shù)字識別
3.1 項目配置
你可以用如下的 gradle 配置來引入依賴項。在這個案例中,我們用 DJL 的 api 包 (核心 DJL 組件) 和 basicdataset 包 (DJL 數(shù)據(jù)集) 來構(gòu)建神經(jīng)網(wǎng)絡(luò)和數(shù)據(jù)集。這個案例中我們使用了 MXNet 作為深度學習引擎,所以我們會引入 mxnet-engine 和 mxnet-native-auto 兩個包。這個案例也可以運行在 PyTorch 引擎下,只需要替換成對應的軟件包即可。
- plugins {
 - id 'java'
 - }
 - repositories {
 - jcenter()
 - }
 - dependencies {
 - implementation platform("ai.djl:bom:0.8.0")
 - implementation "ai.djl:api"
 - implementation "ai.djl:basicdataset"
 - // MXNet
 - runtimeOnly "ai.djl.mxnet:mxnet-engine"
 - runtimeOnly "ai.djl.mxnet:mxnet-native-auto"
 - }
 
3.2 NDArray 和 NDManager
NDArray 是 DJL 存儲數(shù)據(jù)結(jié)構(gòu)和數(shù)學運算的基本結(jié)構(gòu)。一個 NDArray 表達了一個定長的多維數(shù)組。NDArray 的使用方法類似于 Python 中的 numpy.ndarray 。
NDManager 是 NDArray 的老板。它負責管理 NDArray 的產(chǎn)生和回收過程,這樣可以幫助我們更好的對 Java 內(nèi)存進行優(yōu)化。每一個 NDArray 都會是由一個 NDManager 創(chuàng)造出來,同時它們會在 NDManager 關(guān)閉時一同關(guān)閉。NDManager 和 NDArray 都是由 Java 的 AutoClosable 構(gòu)建,這樣可以確保在運行結(jié)束時及時進行回收。想了解更多關(guān)于它們的用法和實踐,請參閱我們前一期文章:
DJL 之 Java 玩轉(zhuǎn)多維數(shù)組,就像 NumPy 一樣
Model
在 DJL 中,訓練和推理都是從 Model class 開始構(gòu)建的。我們在這里主要講訓練過程中的構(gòu)建方法。下面我們?yōu)?Model 創(chuàng)建一個新的目標。因為 Model 也是繼承了 AutoClosable 結(jié)構(gòu)體,我們會用一個 try block 實現(xiàn):
- try (Model model = Model.newInstance()) {
 - ...
 - // 主體訓練代碼
 - ...
 - }
 
準備數(shù)據(jù)
MNIST(Modified National Institute of Standards and Technology)數(shù)據(jù)庫包含大量手寫數(shù)字的圖,通常被用來訓練圖像處理系統(tǒng)。DJL 已經(jīng)將 MNIST 的數(shù)據(jù)集收錄到了 basicdataset 數(shù)據(jù)集里,每個 MNIST 的圖的大小是 28 x 28 。如果你有自己的數(shù)據(jù)集,你也可以通過 DJL 數(shù)據(jù)集導入教程來導入數(shù)據(jù)集到你的訓練任務(wù)中。
數(shù)據(jù)集導入教程: http://docs.djl.ai/docs/development/how_to_use_dataset.html#how-to-create-your-own-dataset
- int batchSize = 32; // 批大小
 - Mnist trainingDataset = Mnist.builder()
 - .optUsage(Usage.TRAIN) // 訓練集
 - .setSampling(batchSize, true)
 - .build();
 - Mnist validationDataset = Mnist.builder()
 - .optUsage(Usage.TEST) // 驗證集
 - .setSampling(batchSize, true)
 - .build();
 
這段代碼分別制作出了訓練和驗證集。同時我們也隨機排列了數(shù)據(jù)集從而更好的訓練。除了這些配置以外,你也可以添加對于圖片的進一步處理,比如設(shè)置圖片大小,對圖片進行歸一化等處理。
制作 model(建立 Block)
當你的數(shù)據(jù)集準備就緒后,我們就可以構(gòu)建神經(jīng)網(wǎng)絡(luò)了。在 DJL 中,神經(jīng)網(wǎng)絡(luò)是由 Block(代碼塊)構(gòu)成的。一個 Block 是一個具備多種神經(jīng)網(wǎng)絡(luò)特性的結(jié)構(gòu)。它們可以代表 一個操作, 神經(jīng)網(wǎng)絡(luò)的一部分,甚至是一個完整的神經(jīng)網(wǎng)絡(luò)。然后 Block 可以順序執(zhí)行或者并行。同時 Block 本身也可以帶參數(shù)和子 Block。這種嵌套結(jié)構(gòu)可以幫助我們構(gòu)造一個復雜但又不失維護性的神經(jīng)網(wǎng)絡(luò)。在訓練過程中,每個 Block 中附帶的參數(shù)會被實時更新,同時也包括它們的各個子 Block。這種遞歸更新的過程可以確保整個神經(jīng)網(wǎng)絡(luò)得到充分訓練。
當我們構(gòu)建這些 Block 的過程中,最簡單的方式就是將它們一個一個的嵌套起來。直接使用準備好 DJL 的 Block 種類,我們就可以快速制作出各類神經(jīng)網(wǎng)絡(luò)。
根據(jù)幾種基本的神經(jīng)網(wǎng)絡(luò)工作模式,我們提供了幾種 Block 的變體。SequentialBlock 是為了應對順序執(zhí)行每一個子 Block 構(gòu)造而成的。它會將前一個子 Block 的輸出作為下一個 Block 的輸入 繼續(xù)執(zhí)行到底。與之對應的,是 ParallelBlock 它用于將一個輸入并行輸入到每一個子 Block 中,同時將輸出結(jié)果根據(jù)特定的合并方程合并起來。最后我們說一下 LambdaBlock,它是幫助用戶進行快速操作的一個 Block,其中并不具備任何參數(shù),所以也沒有任何部分在訓練過程中更新。

我們來嘗試創(chuàng)建一個基本的 多層感知機(MLP)神經(jīng)網(wǎng)絡(luò)吧。多層感知機是一個簡單的前向型神經(jīng)網(wǎng)絡(luò),它只包含了幾個全連接層 (LinearBlock)。那么構(gòu)建這個網(wǎng)絡(luò),我們可以直接使用 SequentialBlock。
- int input = 28 * 28; // 輸入層大小
 - int output = 10; // 輸出層大小
 - int[] hidden = new int[] {128, 64}; // 隱藏層大小
 - SequentialBlock sequentialBlock = new SequentialBlock();
 - sequentialBlock.add(Blocks.batchFlattenBlock(input));
 - for (int hiddenSize : hidden) {
 - // 全連接層
 - sequentialBlock.add(Linear.builder().setUnits(hiddenSize).build());
 - // 激活函數(shù)
 - sequentialBlock.add(activation);
 - }
 - sequentialBlock.add(Linear.builder().setUnits(output).build());
 
當然 DJL 也提供了直接就可以拿來用的 MLP Block :
- Block block = new Mlp(
 - Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH,
 - Mnist.NUM_CLASSES,
 - new int[] {128, 64});
 
訓練
當我們準備好數(shù)據(jù)集和神經(jīng)網(wǎng)絡(luò)之后,就可以開始訓練模型了。在深度學習中,一般會由下面幾步來完成一個訓練過程:

- 初始化:我們會對每一個 Block 的參數(shù)進行初始化,初始化每個參數(shù)的函數(shù)都是由 設(shè)定的 Initializer 決定的。
 - 前向傳播:這一步將輸入數(shù)據(jù)在神經(jīng)網(wǎng)絡(luò)中逐層傳遞,然后產(chǎn)生輸出數(shù)據(jù)。
 - 計算損失:我們會根據(jù)特定的損失函數(shù) Loss 來計算輸出和標記結(jié)果的偏差。
 - 反向傳播:在這一步中,你可以利用損失反向求導算出每一個參數(shù)的梯度。
 - 更新權(quán)重:我們會根據(jù)選擇的優(yōu)化器(Optimizer)更新每一個在 Block 上參數(shù)的值。
 
DJL 利用了 Trainer 結(jié)構(gòu)體精簡了整個過程。開發(fā)者只需要創(chuàng)建 Trainer 并指定對應的 Initializer、Loss 和 Optimizer 即可。這些參數(shù)都是由 TrainingConfig 設(shè)定的。下面我們來看一下具體的參數(shù)設(shè)置:
TrainingListener:這個是對訓練過程設(shè)定的監(jiān)聽器。它可以實時反饋每個階段的訓練結(jié)果。這些結(jié)果可以用于記錄訓練過程或者幫助 debug 神經(jīng)網(wǎng)絡(luò)訓練過程中的問題。用戶也可以定制自己的 TrainingListener 來對訓練過程進行監(jiān)聽。
- DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
 - .addEvaluator(new Accuracy())
 - .addTrainingListeners(TrainingListener.Defaults.logging());
 - try (Trainer trainer = model.newTrainer(config)){
 - // 訓練代碼
 - }
 
當訓練器產(chǎn)生后,我們可以定義輸入的 Shape。之后就可以調(diào)用 fit 函數(shù)來進行訓練。fit 函數(shù)會對輸入數(shù)據(jù),訓練多個 epoch 是并最終將結(jié)果存儲在本地目錄下。
- /*
 - * MNIST 包含 28x28 灰度圖片并導入成 28 * 28 NDArray。
 - * 第一個維度是批大小, 在這里我們設(shè)置批大小為 1 用于初始化。
 - */
 - Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH);
 - int numEpoch = 5;
 - String outputDir = "/build/model";
 - // 用輸入初始化 trainer
 - trainer.initialize(inputShape);
 - TrainingUtils.fit(trainer, numEpoch, trainingSet, validateSet, outputDir, "mlp");
 
這就是訓練過程的全部流程了!用 DJL 訓練是不是還是很輕松的?之后看一下輸出每一步的訓練結(jié)果。如果你用了我們默認的監(jiān)聽器,那么輸出是類似于下圖:
- [INFO ] - Downloading libmxnet.dylib ...
 - [INFO ] - Training on: cpu().
 - [INFO ] - Load MXNet Engine Version 1.7.0 in 0.131 ms.
 - Training: 100% |████████████████████████████████████████| Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.24, speed: 1235.20 items/sec
 - Validating: 100% |████████████████████████████████████████|
 - [INFO ] - Epoch 1 finished.
 - [INFO ] - Train: Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.24
 - [INFO ] - Validate: Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14
 - Training: 100% |████████████████████████████████████████| Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.10, speed: 2851.06 items/sec
 - Validating: 100% |████████████████████████████████████████|
 - [INFO ] - Epoch 2 finished.NG [1m 41s]
 - [INFO ] - Train: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.10
 - [INFO ] - Validate: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.09
 - [INFO ] - train P50: 12.756 ms, P90: 21.044 ms
 - [INFO ] - forward P50: 0.375 ms, P90: 0.607 ms
 - [INFO ] - training-metrics P50: 0.021 ms, P90: 0.034 ms
 - [INFO ] - backward P50: 0.608 ms, P90: 0.973 ms
 - [INFO ] - step P50: 0.543 ms, P90: 0.869 ms
 - [INFO ] - epoch P50: 35.989 s, P90: 35.989 s
 
當訓練結(jié)果完成后,我們可以用剛才的模型進行推理來識別手寫數(shù)字。如果剛才的內(nèi)容哪里有不是很清楚的,可以參照下面兩個鏈接直接嘗試訓練。
手寫數(shù)據(jù)集訓練:
https://docs.djl.ai/examples/docs/train_mnist_mlp.html
手寫數(shù)據(jù)集推理:
https://docs.djl.ai/jupyter/tutorial/03_image_classification_with_your_model.html
四、最后
在這個文章中,我們介紹了深度學習的基本概念,同時還有如何優(yōu)雅的利用 DJL 構(gòu)建深度學習模型并進行訓練。DJL 也提供了更加多樣的數(shù)據(jù)集和神經(jīng)網(wǎng)絡(luò)。如果有興趣學習深度學習,可以參閱我們的 Java 深度學習書。
Java 深度學習書: https://zh.d2l.ai/

Deep Java Library(DJL)是一個基于 Java 的深度學習框架,同時支持訓練以及推理。DJL 博取眾長,構(gòu)建在多個深度學習框架之上 (TenserFlow、PyTorch、MXNet 等) 也同時具備多個框架的優(yōu)良特性。你可以輕松使用 DJL 來進行訓練然后部署你的模型。
它同時擁有著強大的模型庫支持:只需一行便可以輕松讀取各種預訓練的模型。現(xiàn)在 DJL 的模型庫同時支持高達 70 多個來自 GluonCV、 HuggingFace、TorchHub 以及 Keras 的模型。
項目地址: https://github.com/awslabs/djl/
















 
 
 









 
 
 
 