譯者 | 布加迪
審校 | 重樓
近年來,Keras和Tensorflow組合遇到了一種與之競爭的框架:JAX,它在深度學(xué)習(xí)開發(fā)者社區(qū)逐漸變得很重要。那么JAX到底是什么?它有哪些功能?它與Keras API又有什么相似和不同之處?Keras API一直是使用Tensorflow(最龐大的Python深度學(xué)習(xí)庫)的幾乎通用的方法。本文逐一解答了這些問題。
Keras是什么?
Keras誕生于2015年,這種接口用于簡化使用成熟的庫來構(gòu)建神經(jīng)網(wǎng)絡(luò)架構(gòu),比如Tensorflow。盡管Keras最初作為一種獨(dú)立的框架而創(chuàng)建,但它最終成為了與Tensorflow結(jié)合使用的框架:Tensorflow是用于高效訓(xùn)練和使用可擴(kuò)展深度神經(jīng)網(wǎng)絡(luò)的主要Python庫。隨后,Keras成為Tensorflow上面的抽象層:換句話說,它使“原始”Tensorflow用起來變得容易多了。
Keras便于實(shí)現(xiàn)神經(jīng)網(wǎng)絡(luò)架構(gòu)最常見的構(gòu)建模塊:神經(jīng)元層、目標(biāo)及激活函數(shù)以及優(yōu)化器等等。特殊類型的深度神經(jīng)網(wǎng)絡(luò)架構(gòu)使用Keras抽象類和方法可以輕松構(gòu)建,比如卷積神經(jīng)網(wǎng)絡(luò)(CNN)和循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)。
JAX是什么?
JAX是一種比較新的框架,不僅適用于深度學(xué)習(xí),還適用于整個(gè)機(jī)器學(xué)習(xí)開發(fā)。它于2018年由谷歌發(fā)布,側(cè)重于高性能數(shù)值計(jì)算。具體來說,JAX使Python和numpy(其最大的數(shù)值計(jì)算庫)用起來更簡單更快捷,同時(shí)無縫支持GPU和TPU高性能處理。就科學(xué)計(jì)算和數(shù)值計(jì)算而言,這是相對普通numpy的一個(gè)重要優(yōu)勢,因?yàn)?/span>numpy只支持CPU執(zhí)行。
由于兼顧高性能執(zhí)行模式的直觀性和多功能,JAX正迅速名聲大噪,成為機(jī)器學(xué)習(xí)和深度學(xué)習(xí)開發(fā)的最先進(jìn)框架,有機(jī)會最終取代Tensorflow和PyTorch等其他框架。它的自動微分特性有助于高效地執(zhí)行訓(xùn)練深度神經(jīng)網(wǎng)絡(luò)背后的基于梯度的復(fù)雜計(jì)算。
簡而言之,JAX將科學(xué)計(jì)算和高性能計(jì)算的功能整合到單單一個(gè)框架中。
Keras和JAX的異同
現(xiàn)在我們對Keras和JAX已有了大致的了解,下面列出這兩種框架共有的一些特性和諸多不同的方面。
相似之處:
- 深度學(xué)習(xí)模型開發(fā):這兩種框架都被廣泛用于構(gòu)建和訓(xùn)練深度學(xué)習(xí)模型。
- GPU/TPU加速:Keras和JAX都可以利用GPU和TPU等加速硬件高效地訓(xùn)練模型。
- 自動微分:這兩種框架結(jié)合了自動計(jì)算梯度的機(jī)制,梯度計(jì)算是模型在訓(xùn)練過程中優(yōu)化的關(guān)鍵過程。
- 與深度學(xué)習(xí)庫的互操作性:這兩種框架都與流行的深度學(xué)習(xí)庫TensorFlow兼容。
差異之處:
- 抽象級別:雖然兩種解決方案都提供了一定程度的抽象,但Keras更適合尋求高級API且易于使用的用戶;而JAX更注重控制的靈活性,停留在較低的抽象級別,專注于數(shù)值計(jì)算。
- 后端:Keras完全基于并依賴Tensorflow作為后端。同時(shí),JAX不依賴Tensorflow,而是使用一種名為適時(shí)(JIT)編譯的方法。話雖如此,JAX和Tensorflow可以結(jié)合使用,它們在某些情況下可以很好地互補(bǔ),比如將高級數(shù)學(xué)轉(zhuǎn)換整合到高級深度學(xué)習(xí)架構(gòu)中。
- 易用性:與抽象級別密切相關(guān),Keras旨在易于快速使用。雖然JAX功能更強(qiáng)大,但需要更深入的技術(shù)知識才能順利地使用它。
- 函數(shù)轉(zhuǎn)換:這是JAX獨(dú)有的特性,允許高級轉(zhuǎn)換功能,比如自動向量化和并行執(zhí)行。
- 自動優(yōu)化:同樣,JAX在這方面很突出,它更加靈活,便于在神經(jīng)網(wǎng)絡(luò)范圍之外優(yōu)化各種函數(shù)(這就是為什么它也適用于其他機(jī)器學(xué)習(xí)方法,比如集成學(xué)習(xí)),Keras專門專注于深度學(xué)習(xí)模型。
我該選擇哪種框架?
了解了這兩種框架之間的異同之后,根據(jù)手頭的問題或場景決定選擇哪種框架就不是什么麻煩事了。
如果用戶尋求易用性、更平緩的學(xué)習(xí)曲線和更高的抽象級別,Keras是不二的選擇。這個(gè)基于Tensorflow庫的API將使用戶能夠在短時(shí)間內(nèi)構(gòu)建原型,并利用各種深度學(xué)習(xí)模型處理預(yù)測和推理任務(wù)。
另一方面,對于經(jīng)驗(yàn)豐富的開發(fā)人員來說,JAX是一種更強(qiáng)大、更通用的選擇,可以獲得優(yōu)化計(jì)算和高級函數(shù)轉(zhuǎn)換之類的附加功能,而不是嚴(yán)格局限于Tensorflow或深度學(xué)習(xí)建模,不過它需要用戶更大的控制度和低級工程決策。
原文標(biāo)題:Keras vs. JAX: A Comparison,作者:Iván Palomares Carrascosa