TensorFlow、PyTorch和JAX:哪一款深度学习框架更适合你?

WBOY
发布: 2023-04-09 22:01:04
转载
1497 人浏览过

TensorFlow、PyTorch和JAX:哪一款深度学习框架更适合你?

译者 | 朱先忠

审校 | 墨色

深度学习每天都在以各种形式影响着我们的生活。无论是基于用户语音命令的Siri、Alexa、手机上的实时翻译应用程序,还是支持智能拖拉机、仓库机器人和自动驾驶汽车的计算机视觉技术,每个月似乎都会迎来新的进展。几乎所有这些深度学习应用程序的编写都来自于这三种框架:TensorFlow、PyTorch或者JAX。

那么,你到底应该使用哪些深度学习框架呢?在本文中,我们将对TensorFlow、PyTorch和JAX进行高级比较。我们的目标是让你了解发挥其优势的应用程序类型,当然还要考虑社区支持和易用性等因素。

你应该使用TensorFlow吗?

“从来没有人因为购买IBM而被解雇”是20世纪70年代和80年代计算机界的口号。在本世纪初,使用TensorFlow进行深度学习也是如此。但众所周知,进入20世纪90年代时,IBM就已被“搁置一旁”。那么,TensorFlow在2015年首次发布后7年的今天以及未来新的十年中仍然具有竞争力吗?

当然。TensorFlow并不是一直都在原地踏步。首先,TensorFlow 1.x是以一种非Python的方式构建静态图的,但是在TensorFlow 2.x中,还可以使用动态图模式(eager mode)构建模型,以便立即评估操作,这让人感觉它更像PyTorch。在高层,TensorFlow提供了Keras以便于开发;在底层,它提供了XLA(Accelerated Linear Algebra,加速线性代数)优化编译器以提高速度。XLA在提高GPU性能方面发挥了神奇作用,它是利用谷歌TPU(Tensor Processing Units,张量处理单元)能力的主要方法,为大规模模型训练提供了无与伦比的性能。

其次,多年来TensorFlow一直努力尽可能在所有方面做得很好。例如,你是否想要在成熟的平台上以定义良好且可重复的方式为模型提供服务?TensorFlow随时可以提供服务。你是否想要将模型部署重新定位到web、智能手机等低功耗计算或物联网等资源受限设备?在这一点上,TensorFlow.js和TensorFlow Lite都已经非常成熟。

显然,考虑到Google仍然在百分之百地使用TensorFlow运行其生产部署,就可以确信TensorFlow一定能够应用户的规模需求。

但是,近来确实有一些项目中的因素不容忽视。简而言之,把项目从TensorFlow 1.x升级到TensorFlow 2.x其实是非常残酷的。一些公司考虑到更新代码后在新的版本上正常工作所需的努力,干脆决定将代码移植到PyTorch框架下。此外,TensorFlow在科研领域也失去了动力,几年前已开始倾向于PyTorch提供的灵活性,这导致TensorFlow在研究论文中的使用不断减少。

此外,“Keras事件”也没有起到任何作用。Keras在两年前成为TensorFlow发行版的一个集成部分,但最近又被拉回到一个单独的库中,并确定了自己的发行计划。当然,排除Keras不会影响开发人员的日常生活,但在框架的一个小更新版本中出现如此引人注目的变化,并不会激发程序员使用TensorFlow框架的信心。

话虽如此,TensorFlow的确还是一个可靠的框架,它拥有广泛的深度学习生态系统,使用者可以在TensorFlow上构建适用于所有规模的应用程序和模型。如果真的这样做,将会有很多不错的合作公司。但如今,TensorFlow可能还不是首选。

你应该使用PyTorch吗?

PyTorch不再是紧跟TensorFlow之后的“新贵”,而是当今深度学习领域的主要力量,可能主要用于研究,但也越来越多地用于生产应用。随着动态图模式(eager mode)成为TensorFlow和PyTorch中开发的默认方法,PyTorch的自动微分(autograd)提供的更具Python风格的方法似乎赢得了与静态图的战争。

与TensorFlow不同的是,自0.4版本中不推荐使用变量API以来,PyTorch的核心代码没有经历过任何重大的中断。以前,变量需要使用自动生成张量,而现在,一切都是张量。但这并不是说无论在哪儿都不存在错误。例如,如果你一直在使用PyTorch跨多个GPU进行训练,可能会遇到DataParallel和较新的DistributedDataParaller之间的差异。你应该经常使用DistributedDataParallel,但实际上并不反对使用DataParaller。

虽然PyTorch在XLA/TPU支持方面一直落后于TensorFlow和JAX,但截至2022年,情况已经有了很大改善。PyTorch现在支持访问TPU虚拟机,支持老式TPU节点支持,以及支持在CPU、GPU或TPU上运行代码的简单命令行部署,而无需更改代码。如果你不想处理PyTorch经常让你编写的一些样板代码,那么你可以求助于Pytorche Lightning这样更高级别的扩展程序,它让你专注于实际工作,而不是重写训练循环。而另一方面,虽然PyTorch Mobile的工作仍在继续,但它远不如TensorFlow Lite那么成熟。

在生产方面,PyTorch现在可以与Kubeflow等框架无关平台进行集成,而且TorchServe项目可以处理扩展、度量和批量推理等部署细节——在PyTorch开发人员自己维护的小软件包中能够提供所有MLOps优点。另一方面,PyTorch支持规模缩放吗?没有问题!Meta公司多年来一直在生产领域运行PyTorch;所以,任何人告诉你PyTorch无法处理大规模的工作负载其实都是谎言。尽管如此,有一种情况是,PyTorch可能不像JAX那样友好,特别是在需要大量GPU或TPU进行非常大量的训练方面。

最后,依然存在一个人们不愿提及的棘手问题——PyTorch在过去几年的受欢迎程度几乎离不开Hugging Face公司的Transformers库的成功。是的,Transformers现在也支持TensorFlow和JAX,但它最初是一个PyTorch项目,仍然与框架紧密结合。随着Transformer架构的兴起,PyTorch在研究方面的灵活性,以及通过Hugging Face的模型中心在发布后几天或几个小时内引入如此多的新模型的能力,很容易看出为什么PyTorch在这些领域如此流行。

你应该使用JAX吗?

如果你对TensorFlow不感兴趣,那么Google可能会为你提供其他服务。JAX是一个由Google构建、维护和使用的深度学习框架,但它不是官方的Google产品。然而,如果你留意过去一年左右Google/DeepMind的论文和产品发布,你就会注意到Google的许多研究已经转移到了JAX。因此,尽管JAX并不是谷歌的“官方”产品,但它是谷歌研究人员用来推动边界的东西。

到底什么是JAX呢?理解JAX的一个简单方法是:想象一个GPU/TPU加速版本的NumPy,它可以用“一根魔杖”神奇地将Python函数矢量化,并处理所有这些函数的导数计算。最后,它提供了一个即时(JIT:Just-In-Time)组件,用于获取代码并为XLA(Accelerated Linear Algebra,即加速线性代数)编译器进行优化,从而大幅提高TensorFlow和PyTorch的性能。目前一些代码的执行速度提高了四到五倍,只需在JAX中重新实现,而不需要进行任何真正的优化工作。

考虑到JAX是在NumPy级别工作的,JAX代码是在比TensorFlow/Keras(甚至是PyTorch)低得多的级别上编写的。令人高兴的是,有一个小型但不断增长的生态系统,围绕着JAX进行了一些扩展。你想要使用神经网络库吗?当然可以。其中有来自谷歌的Flax,还有来自DeepMind(也包括谷歌)的Haiku。此外,Optax可满足你的所有优化器需求,PIX可用于图像处理,此外还有更多功能。一旦你使用Flax之类的东西,构建神经网络就变得相对容易掌握。请注意,仍然有一些略让人纠结的问题。例如,经验丰富的人经常谈到JAX处理随机数的方式与许多其他框架不同。

那么,你是否应该将所有内容转换为JAX并利用这一前沿技术呢?这个问题因人而异。如果你深入研究需要大量资源来训练的大规模模型的话,建议采用这种方法。此外,如果你关注JAX在确定性训练,以及其他需要数千个TPU Pod的项目,那么,也值得一试。

小结

因此,结论是什么呢?你应该使用哪种深度学习框架?遗憾的是,这道题并没有唯一的答案,完全取决于你正在处理的问题类型、计划部署模型以处理的规模,甚至还依赖于你所面对的计算平台。

不过,如果你从事的是文本和图像领域,并且正在进行中小型研究,以期在生产中部署这些模型,那么PyTorch可能是目前最好的选择。从最近的版本看,它正好针对这类应用空间的最佳点。

如果你需要从低计算设备中获取所有性能,那么建议你使用TensorFlow以及极为坚固的TensorFlow Lite软件包。最后,如果你正在研究数百亿、数千亿或更多参数的训练模型,并且你主要是为了研究目的而训练它们,那么也许是时候试一试JAX了。

原文链接:https://www.infoworld.com/article/3670114/tensorflow-pytorch-and-jax-choosing-a-deep-learning-framework.html

译者介绍

朱先忠,51CTO社区编辑,51CTO专家博客、讲师,潍坊一所高校计算机教师,自由编程界老兵一枚。

以上是TensorFlow、PyTorch和JAX:哪一款深度学习框架更适合你?的详细内容。更多信息请关注PHP中文网其他相关文章!

来源:51cto.com
本站声明
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn
最新问题
热门教程
更多>
最新下载
更多>
网站特效
网站源码
网站素材
前端模板