Traducteur | Zhu Xianzhong
Reviewer | Ink
L'apprentissage profond affecte nos vies sous diverses formes chaque jour. Qu’il s’agisse de Siri, d’Alexa, d’applications de traduction en temps réel sur votre téléphone basées sur les commandes vocales de l’utilisateur ou de la technologie de vision par ordinateur qui alimente les tracteurs intelligents, les robots d’entrepôt et les voitures autonomes, chaque mois semble apporter de nouvelles avancées. Presque toutes ces applications de deep learning sont écrites dans ces trois frameworks : TensorFlow, PyTorch ou JAX.
Alors, quels frameworks de deep learning devriez-vous utiliser ? Dans cet article, nous effectuerons une comparaison de haut niveau entre TensorFlow, PyTorch et JAX. Notre objectif est de vous donner une idée des types d'applications qui exploitent leurs atouts, tout en prenant bien sûr en compte des facteurs tels que le soutien de la communauté et la facilité d'utilisation.
"Personne n'a jamais été viré pour avoir acheté IBM" était un slogan dans le monde informatique des années 1970 et 1980. Au début de ce siècle, il en allait de même pour l’apprentissage profond utilisant TensorFlow. Mais comme nous le savons tous, au début des années 1990, IBM avait été « mis en veilleuse ». Alors, TensorFlow est-il toujours compétitif aujourd'hui, 7 ans après sa sortie initiale en 2015, et dans la nouvelle décennie à venir ?
Bien sûr. TensorFlow n'est pas toujours resté immobile. Premièrement, TensorFlow 1.x crée des graphiques statiques d'une manière non pythonique, mais dans TensorFlow 2.x, vous pouvez également créer des modèles en utilisant le mode impatient pour évaluer les opérations immédiatement, ce qui donne l'impression que cela ressemble plus à PyTorch. Au niveau supérieur, TensorFlow fournit Keras pour faciliter le développement ; au niveau inférieur, il fournit le compilateur d'optimisation XLA (Accelerated Linear Algebra, algèbre linéaire accélérée) pour augmenter la vitesse. XLA joue un rôle magique dans l'amélioration des performances du GPU. Il s'agit de la principale méthode permettant d'exploiter la puissance du TPU (Tensor Processing Units) de Google, offrant des performances inégalées pour la formation de modèles à grande échelle.
Deuxièmement, TensorFlow s'est efforcé au fil des années d'être aussi bon que possible dans tout. Par exemple, souhaitez-vous servir des modèles de manière bien définie et reproductible sur une plateforme mature ? TensorFlow est prêt à être utilisé. Souhaitez-vous déplacer le déploiement de modèles vers le Web, les ordinateurs à faible consommation tels que les smartphones ou les appareils aux ressources limitées tels que l'Internet des objets ? À ce stade, TensorFlow.js et TensorFlow Lite sont tous deux très matures.
Évidemment, étant donné que Google utilise toujours TensorFlow à 100 % pour exécuter ses déploiements de production, vous pouvez être sûr que TensorFlow sera en mesure de répondre aux besoins d'évolution des utilisateurs.
Cependant, il y a effectivement certains facteurs dans les projets récents qui ne peuvent être ignorés. Bref, mettre à niveau un projet de TensorFlow 1.x vers TensorFlow 2.x est en réalité très cruel. Certaines entreprises décident simplement de porter le code sur le framework PyTorch, compte tenu de l'effort requis pour mettre à jour le code pour qu'il fonctionne correctement sur la nouvelle version. En outre, TensorFlow a également perdu de son élan dans le domaine de la recherche scientifique, qui a commencé à préférer la flexibilité offerte par PyTorch il y a quelques années, ce qui a conduit au déclin continu de l'utilisation de TensorFlow dans les articles de recherche.
De plus, "l'incident Keras" n'a joué aucun rôle. Keras est devenu partie intégrante de la distribution TensorFlow il y a deux ans, mais a récemment été réintégré dans une bibliothèque distincte avec son propre plan de publication. Bien sûr, l'exclusion de Keras n'affectera pas la vie quotidienne des développeurs, mais un changement aussi radical dans une petite version mise à jour du framework n'inspire pas confiance aux programmeurs pour utiliser le framework TensorFlow.
Cela dit, TensorFlow est en effet un framework fiable. Il dispose d'un vaste écosystème d'apprentissage en profondeur et les utilisateurs peuvent créer des applications et des modèles de toutes tailles sur TensorFlow. Si nous faisons cela, nous aurons de nombreuses bonnes entreprises avec lesquelles travailler. Mais aujourd’hui, TensorFlow n’est peut-être pas le premier choix.
PyTorch n'est plus un « parvenu » après TensorFlow, mais constitue aujourd'hui une force majeure dans l'apprentissage profond, peut-être principalement pour la recherche, mais de plus en plus pour les applications de production. Le mode impatient devenant l'approche par défaut pour le développement dans TensorFlow et PyTorch, l'approche plus pythonique fournie par l'autograd de PyTorch semble gagner la guerre contre les graphes statiques.
Contrairement à TensorFlow, le code principal de PyTorch n'a connu aucune panne majeure depuis que l'API de variable est devenue obsolète dans la version 0.4. Auparavant, les variables nécessitaient des tenseurs générés automatiquement, mais désormais, tout est un tenseur. Mais cela ne veut pas dire qu’il n’y a pas d’erreurs partout. Par exemple, si vous utilisez PyTorch pour vous entraîner sur plusieurs GPU, vous avez peut-être rencontré des différences entre DataParallel et le plus récent DistributedDataParaller. Vous devez toujours utiliser DistributedDataParallel, mais il n'y a vraiment rien contre l'utilisation de DataParaller.
Bien que PyTorch ait toujours été à la traîne de TensorFlow et JAX en termes de support XLA/TPU, depuis 2022, la situation s'est beaucoup améliorée. PyTorch prend désormais en charge l'accès aux machines virtuelles TPU, la prise en charge des nœuds TPU hérités et la prise en charge du déploiement simple en ligne de commande de code exécuté sur CPU, GPU ou TPU sans nécessiter de modifications de code. Si vous ne souhaitez pas gérer une partie du code passe-partout que PyTorch vous fait souvent écrire, vous pouvez vous tourner vers des extensions de niveau supérieur comme Pytorche Lightning, qui vous permettent de vous concentrer sur le travail réel au lieu de réécrire les boucles de formation. En revanche, même si les travaux sur PyTorch Mobile se poursuivent, il est bien moins mature que TensorFlow Lite.
Du côté de la production, PyTorch peut désormais s'intégrer à des plates-formes indépendantes du framework comme Kubeflow, et le projet TorchServe gère les détails de déploiement tels que la mise à l'échelle, les métriques et l'inférence par lots, le tout disponible dans un petit package géré par les développeurs PyTorch eux-mêmes. Avantages de MLOps. D’un autre côté, PyTorch prend-il en charge la mise à l’échelle ? Aucun problème! Meta utilise PyTorch en production depuis des années ; donc quiconque vous dit que PyTorch ne peut pas gérer des charges de travail à grande échelle ment. Néanmoins, il existe une situation dans laquelle PyTorch n'est peut-être pas aussi convivial que JAX, en particulier lorsqu'il s'agit d'une formation très lourde nécessitant un grand nombre de GPU ou de TPU.
Enfin, il reste encore un problème épineux que les gens ne veulent pas mentionner : la popularité de PyTorch ces dernières années est presque indissociable du succès de la bibliothèque Transformers de Hugging Face. Oui, Transformers prend désormais également en charge TensorFlow et JAX, mais il s'agissait à l'origine d'un projet PyTorch et est toujours étroitement intégré au framework. Avec l'essor de l'architecture Transformer, la flexibilité de PyTorch en matière de recherche et la possibilité d'introduire autant de nouveaux modèles dans les jours ou heures suivant la sortie via le centre de modèles de Hugging Face, il est facile de comprendre pourquoi PyTorch est si populaire dans ces domaines.
Si TensorFlow ne vous intéresse pas, Google peut proposer d'autres services pour vous. JAX est un framework d'apprentissage profond construit, maintenu et utilisé par Google, mais ce n'est pas un produit Google officiel. Cependant, si vous prêtez attention aux articles et aux versions de produits Google/DeepMind au cours de la dernière année, vous remarquerez qu'une grande partie des recherches de Google ont été déplacées vers JAX. Ainsi, même si JAX n'est pas un produit Google « officiel », c'est quelque chose que les chercheurs de Google utilisent pour repousser les limites.
Qu'est-ce que JAX exactement ? Une façon simple de penser à JAX est : imaginez une version accélérée GPU/TPU de NumPy qui peut comme par magie vectoriser les fonctions Python avec « une baguette magique » et gérer le calcul des dérivées de toutes ces fonctions. Enfin, il fournit un composant juste à temps (JIT : Just-In-Time) pour récupérer le code et l'optimiser pour le compilateur XLA (Accelerated Linear Algebra), améliorant ainsi considérablement les performances de TensorFlow et PyTorch. Certains codes s'exécutent actuellement quatre à cinq fois plus vite simplement en les réimplémentant dans JAX sans réel travail d'optimisation.
Considérant que JAX fonctionne au niveau NumPy, le code JAX est écrit à un niveau bien inférieur à TensorFlow/Keras (ou même PyTorch). Heureusement, il existe un écosystème petit mais en pleine croissance autour de JAX avec une certaine expansion. Voulez-vous utiliser la bibliothèque de réseaux neuronaux ? bien sûr. Parmi eux figurent Flax de Google et Haiku de DeepMind (également Google). De plus, Optax est disponible pour tous vos besoins d'optimisation, PIX est disponible pour le traitement d'images et bien plus encore. Une fois que vous utilisez quelque chose comme Flax, la création de réseaux de neurones devient relativement facile à maîtriser. Notez qu’il reste encore quelques problèmes troublants. Par exemple, les personnes expérimentées parlent souvent de la manière dont JAX gère les nombres aléatoires différemment de nombreux autres frameworks.
Alors, faut-il tout convertir en JAX et profiter de cette technologie de pointe ? Cette question varie d'une personne à l'autre. Cette approche est recommandée si vous explorez des modèles à grande échelle qui nécessitent beaucoup de ressources pour s'entraîner. De plus, si vous êtes intéressé par JAX pour la formation déterministe et d'autres projets nécessitant des milliers de pods TPU, cela vaut la peine d'essayer.
Alors, quelle est la conclusion ? Quel framework d’apprentissage profond devriez-vous utiliser ? Malheureusement, il n'y a pas de réponse unique à cette question, tout dépend du type de problème sur lequel vous travaillez, de l'échelle à laquelle vous envisagez de déployer le modèle à gérer et même de la plate-forme informatique à laquelle vous avez affaire.
Cependant, si vous travaillez dans le domaine du texte et des images et effectuez des recherches de petite à moyenne envergure en vue de déployer ces modèles en production, alors PyTorch est probablement le meilleur choix pour le moment. À en juger par la version récente, il correspond au point idéal de ce type d'espace d'application.
Si vous avez besoin d'obtenir toutes les performances d'un appareil informatique peu puissant, il est recommandé d'utiliser TensorFlow et le package TensorFlow Lite extrêmement robuste. Enfin, si vous envisagez d'entraîner des modèles avec des dizaines, des centaines de milliards ou plus de paramètres et que vous les entraînez principalement à des fins de recherche, il est peut-être temps d'essayer JAX.
Lien original : https://www.infoworld.com/article/3670114/tensorflow-pytorch-and-jax-choosing-a-deep-learning-framework.html
Zhu Xianzhong, rédacteur en chef de la communauté 51CTO, blogueur expert 51CTO, conférencier, professeur d'informatique dans une université de Weifang et vétéran de l'industrie de la programmation indépendante.
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!