TensorFlow、PyTorch和JAX:哪一款深度學習架構比較適合你?

WBOY
發布: 2023-04-09 22:01:04
轉載
1498 人瀏覽過

TensorFlow、PyTorch和JAX:哪一款深度學習架構比較適合你?

譯者| 朱先忠

#審查| 墨色

深度學習每天都在以各種形式影響著我們的生活。無論是基於用戶語音命令的Siri、Alexa、手機上的即時翻譯應用程序,還是支援智慧拖拉機、倉庫機器人和自動駕駛汽車的電腦視覺技術,每個月似乎都會迎來新的進展。幾乎所有這些深度學習應用程式的編寫都來自於這三種框架:TensorFlow、PyTorch或JAX。

那麼,你到底該使用哪些深度學習框架呢?在本文中,我們將對TensorFlow、PyTorch和JAX進行進階比較。我們的目標是讓你了解發揮其優勢的應用程式類型,當然也要考慮社群支援和易用性等因素。

你應該使用TensorFlow嗎?

「從來沒有人因為購買IBM而被解僱」是20世紀70年代和80年代電腦界的口號。在本世紀初,使用TensorFlow進行深度學習也是如此。但眾所周知,進入1990年代時,IBM已被「擱置一旁」。那麼,TensorFlow在2015年首次發布後7年的今天以及未來新的十年中仍然具有競爭力嗎?

當然。 TensorFlow並不是一直都在原地踏步。首先,TensorFlow 1.x是以非Python的方式建立靜態圖的,但是在TensorFlow 2.x中,還可以使用動態圖模式(eager mode)建立模型,以便立即評估操作,這讓人感覺它更像PyTorch。在高層,TensorFlow提供了Keras以便於開發;在底層,它提供了XLA(Accelerated Linear Algebra,加速線性代數)最佳化編譯器以提高速度。 XLA在提高GPU效能方面發揮了神奇作用,它是利用GoogleTPU(Tensor Processing Units,張量處理單元)能力的主要方法,為大規模模型訓練提供了無與倫比的效能。

其次,多年來TensorFlow一直努力盡可能在所有方面做得很好。例如,你是否想要在成熟的平台上以定義良好且可重複的方式為模型提供服務? TensorFlow隨時可以提供服務。你是否想要將模型部署重新定位到web、智慧型手機等低功耗運算或物聯網等資源受限設備?在這一點上,TensorFlow.js和TensorFlow Lite都已經非常成熟。

顯然,考慮到Google仍在百分之百地使用TensorFlow運行其生產部署,就可以確信TensorFlow一定能夠應使用者的規模需求。

但是,近來確實有一些專案中的因素不容忽視。簡而言之,把專案從TensorFlow 1.x升級到TensorFlow 2.x其實是非常殘酷的。一些公司考慮到更新程式碼後在新的版本上正常工作所需的努力,乾脆決定將程式碼移植到PyTorch框架下。此外,TensorFlow在科研領域也失去了動力,幾年前已開始傾向於PyTorch提供的彈性,這導致TensorFlow在研究論文中的使用不斷減少。

此外,「Keras事件」也沒有起到任何作用。 Keras在兩年前成為TensorFlow發行版的一個整合部分,但最近又被拉回到一個單獨的庫中,並確定了自己的發行計劃。當然,排除Keras不會影響開發人員的日常生活,但在框架的一個小更新版本中出現如此引人注目的變化,並不會激發程式設計師使用TensorFlow框架的信心。

話雖如此,TensorFlow的確還是一個可靠的框架,它擁有廣泛的深度學習生態系統,使用者可以在TensorFlow上建立適用於所有規模的應用程式和模型。如果真的這樣做,將會有很多不錯的合作公司。但如今,TensorFlow可能還不是首選。

你應該使用PyTorch嗎?

PyTorch不再是緊跟TensorFlow之後的“新貴”,而是當今深度學習領域的主要力量,可能主要用於研究,但也越來越多地用於生產應用。隨著動態圖模式(eager mode)成為TensorFlow和PyTorch中開發的預設方法,PyTorch的自動微分(autograd)提供的更具Python風格的方法似乎贏得了與靜態圖的戰爭。

與TensorFlow不同的是,自從0.4版本中不建議使用變數API以來,PyTorch的核心程式碼沒有經歷任何重大的中斷。以前,變數需要使用自動產生張量,而現在,一切都是張量。但這並不是說無論在哪裡都不存在錯誤。例如,如果你一直在使用PyTorch跨多個GPU進行訓練,可能會遇到DataParallel和較新的DistributedDataParaller之間的差異。你應該經常使用DistributedDataParallel,但實際上並不反對使用DataParaller。

雖然PyTorch在XLA/TPU支援方面一直落後於TensorFlow和JAX,但截至2022年,情況已經有了很大改善。 PyTorch現在支援存取TPU虛擬機,支援老式TPU節點支持,以及支援在CPU、GPU或TPU上運行程式碼的簡單命令列部署,而無需更改程式碼。如果你不想處理PyTorch經常讓你寫的一些樣板程式碼,那麼你可以求助於Pytorche Lightning這樣更高層級的擴充程序,它讓你專注於實際工作,而不是重寫訓練循環。而另一方面,雖然PyTorch Mobile的工作仍在繼續,但它遠不如TensorFlow Lite那麼成熟。

在生產方面,PyTorch現在可以與Kubeflow等框架無關平台進行集成,而且TorchServe項目可以處理擴展、度量和批量推理等部署細節——在PyTorch開發人員自己維護的小軟體包中能夠提供所有MLOps優點。另一方面,PyTorch支援規模縮放嗎?沒有問題! Meta公司多年來一直在生產領域運作PyTorch;所以,任何人告訴你PyTorch無法處理大規模的工作負載其實都是謊言。儘管如此,有一種情況是,PyTorch可能不像JAX那麼友好,特別是在需要大量GPU或TPU進行非常大量的訓練方面。

最後,依然存在著一個人們不願提及的棘手問題——PyTorch在過去幾年的受歡迎程度幾乎離不開Hugging Face公司的Transformers庫的成功。是的,Transformers現在也支援TensorFlow和JAX,但它最初是一個PyTorch項目,仍然與框架緊密結合。隨著Transformer架構的興起,PyTorch在研究方面的靈活性,以及​​透過Hugging Face的模型中心在發布後幾天或幾個小時內引入如此多的新模型的能力,很容易看出為什麼PyTorch在這些領域如此流行。

你該使用JAX嗎?

如果你對TensorFlow不感興趣,那麼Google可能會為你提供其他服務。 JAX是一個由Google建構、維護和使用的深度學習框架,但它不是官方的Google產品。然而,如果你留意過去一年左右Google/DeepMind的論文和產品發布,你會注意到Google的許多研究已經轉移到了JAX。因此,儘管JAX並不是Google的「官方」產品,但它是Google研究人員用來推動邊界的東西。

到底什麼是JAX呢?理解JAX的一個簡單方法是:想像一個GPU/TPU加速版本的NumPy,它可以用「一根魔杖」神奇地將Python函數向量化,並處理所有這些函數的導數計算。最後,它提供了一個即時(JIT:Just-In-Time)元件,用於獲取程式碼並為XLA(Accelerated Linear Algebra,即加速線性代數)編譯器進行最佳化,從而大幅提高TensorFlow和PyTorch的效能。目前一些程式碼的執行速度提高了四到五倍,只需在JAX中重新實現,而不需要進行任何真正的最佳化工作。

考慮到JAX是在NumPy層級工作的,JAX程式碼是在比TensorFlow/Keras(甚至是PyTorch)低得多的層級上編寫的。令人高興的是,有一個小型但不斷增長的生態系統,圍繞著JAX進行了一些擴展。你想要使用神經網路庫嗎?當然可以。其中有來自Google的Flax,還有來自DeepMind(也包括Google)的Haiku。此外,Optax可滿足你的所有優化器需求,PIX可用於影像處理,此外還有更多功能。一旦你使用Flax之類的東西,建立神經網路就變得相對容易掌握。請注意,仍然有一些略讓人糾結的問題。例如,經驗豐富的人經常談到JAX處理隨機數的方式與許多其他框架不同。

那麼,你是否應該將所有內容轉換為JAX並利用這項尖端技術?這個問題因人而異。如果你深入研究需要大量資源來訓練的大規模模型的話,建議採用這種方法。此外,如果你關注JAX在確定性訓練,以及其他需要數千個TPU Pod的項目,那麼,也值得一試。

小結

因此,結論是什麼呢?你應該使用哪種深度學習框架?可惜的是,這題並沒有唯一的答案,完全取決於你正在處理的問題類型、規劃部署模型以處理的規模,甚至還依賴你所面對的運算平台。

不過,如果你從事的是文字和圖像領域,並且正在進行中小型研究,以期在生產中部署這些模型,那麼PyTorch可能是目前最好的選擇。從最近的版本來看,它正好針對這類應用空間的最佳點。

如果你需要從低運算設備中取得所有效能,那麼建議你使用TensorFlow以及極為堅固的TensorFlow Lite軟體包。最後,如果你正在研究數百億、數千億或更多參數的訓練模型,並且你主要是為了研究目的而訓練它們,那麼也許是時候試一試JAX了。

原文連結:https://www.infoworld.com/article/3670114/tensorflow-pytorch-and-jax-choosing -a-deep-learning-framework.html

譯者介紹

朱先忠,51CTO社群編輯,51CTO專家部落格、講師,濰坊一所高校電腦教師,自由程式設計界老兵一枚。

以上是TensorFlow、PyTorch和JAX:哪一款深度學習架構比較適合你?的詳細內容。更多資訊請關注PHP中文網其他相關文章!

來源:51cto.com
本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn
最新問題
熱門教學
更多>
最新下載
更多>
網站特效
網站源碼
網站素材
前端模板