Maison > interface Web > js tutoriel > le corps du texte

Exécution d'un programme JAX à partir de Dart à l'aide de C FFI

Barbara Streisand
Libérer: 2024-11-23 13:50:10
original
959 Les gens l'ont consulté

? Pourquoi combiner Dart et JAX pour l'apprentissage automatique ?

Lors de la création d'applications, la sélection des bons outils est cruciale. Vous souhaitez des performances élevées, un développement facile et un déploiement multiplateforme transparent. Les frameworks populaires proposent des compromis :

  • C fournit de la vitesse mais peut ralentir le développement.
  • Dart (avec Flutter) est plus lent mais simplifie la gestion de la mémoire et le développement multiplateforme.

Mais voici le problème : la plupart des frameworks ne disposent pas d'un support robuste pour l'apprentissage automatique (ML) natif. Cet écart existe parce que ces cadres sont antérieurs au boom de l’IA. La question est :

Comment pouvons-nous intégrer efficacement le ML dans les applications ?

Les solutions courantes telles que ONNX Runtime permettent d'exporter des modèles ML pour l'intégration d'applications, mais elles ne sont pas optimisées pour les processeurs ni suffisamment flexibles pour les algorithmes généralisés.

Entrez JAX, une bibliothèque Python qui :

  • Permet d'écrire des algorithmes de ML optimisés et à usage général.
  • Offre une exécution indépendante de la plate-forme sur les processeurs, les GPU et les TPU.
  • Prend en charge des fonctionnalités de pointe telles que autograd et compilation JIT.

Dans cet article, nous allons vous montrer comment :

  1. Écrivez des programmes JAX en Python.
  2. Générer des spécifications XLA.
  3. Déployez du code JAX optimisé dans Dart à l'aide de C FFI.

? Qu’est-ce que JAX ?

JAX est comme NumPy sous stéroïdes. Développée par Google, il s'agit d'une bibliothèque de bas niveau et hautes performances qui rend le ML accessible mais puissant.

  • Agnostique de la plate-forme : le code s'exécute sur les CPU, les GPU et les TPU sans modification.
  • Vitesse : Propulsé par le compilateur XLA, JAX optimise et accélère l'exécution.
  • Flexibilité : parfait pour les modèles ML et les algorithmes généraux.

Voici un exemple comparant NumPy et JAX :

# NumPy version
import numpy as np  
def assign_numpy():  
  a = np.empty(1000000)  
  a[:] = 1  
  return a  

# JAX version
import jax.numpy as jnp  
import jax  

@jax.jit  
def assign_jax():  
  a = jnp.empty(1000000)  
  return a.at[:].set(1)  
Copier après la connexion
Copier après la connexion

L'analyse comparative dans Google Colab révèle l'avantage en termes de performances de JAX :

  • CPU & GPU : JAX est plus rapide que NumPy.
  • TPU : les accélérations deviennent perceptibles pour les grands modèles en raison des coûts de transfert de données.

Cette flexibilité et cette rapidité rendent JAX idéal pour les environnements de production où les performances sont essentielles.


Running a JAX Program from Dart Using C   FFI


?️ Mettre JAX en production

Microservices cloud et déploiement local

  • Cloud : les microservices Python conteneurisés sont parfaits pour le calcul basé sur le cloud.
  • Local : l'envoi d'un interpréteur Python n'est pas idéal pour les applications locales.

Solution : exploiter la compilation XLA de JAX

JAX traduit le code Python en spécifications HLO (High-Level Optimizer), qui peuvent être compilées et exécutées à l'aide de bibliothèques C XLA. Cela permet :

  1. Écriture d'algorithmes en Python.
  2. Les exécuter nativement via une bibliothèque C.
  3. Intégration avec Dart via FFI (Foreign Function Interface).

✍️ Intégration étape par étape

1. Générer un proto HLO

Écrivez votre fonction JAX et exportez sa représentation HLO. Par exemple :

# NumPy version
import numpy as np  
def assign_numpy():  
  a = np.empty(1000000)  
  a[:] = 1  
  return a  

# JAX version
import jax.numpy as jnp  
import jax  

@jax.jit  
def assign_jax():  
  a = jnp.empty(1000000)  
  return a.at[:].set(1)  
Copier après la connexion
Copier après la connexion

Pour générer le HLO, utilisez le script jax_to_ir.py du référentiel JAX :

import jax.numpy as jnp  

def fn(x, y, z):  
  return jnp.dot(x, y) / z  
Copier après la connexion

Placez les fichiers résultants (fn_hlo.txt et fn_hlo.pb) dans le répertoire des ressources de votre application.


2. Créez une bibliothèque dynamique C

Modifier l'exemple de code C de JAX

Clonez le référentiel JAX et accédez à jax/examples/jax_cpp.

  • Ajouter un fichier d'en-tête main.h :
python jax_to_ir.py \
  --fn jax_example.prog.fn \
  --input_shapes '[("x", "f32[2,2]"), ("y", "f32[2,2")]' \
  --constants '{"z": 2.0}' \
  --ir_format HLO \
  --ir_human_dest /tmp/fn_hlo.txt \
  --ir_dest /tmp/fn_hlo.pb
Copier après la connexion
  • Mettez à jour le fichier BUILD pour créer une bibliothèque partagée :
#ifndef MAIN_H  
#define MAIN_H  

extern "C" {  
  int bar(int foo);  
}  

#endif  
Copier après la connexion

Compiler avec Bazel :

cc_shared_library(  
   name = "jax",  
   deps = [":main"],  
   visibility = ["//visibility:public"],  
)  
Copier après la connexion

Vous trouverez le libjax.dylib compilé dans le répertoire de sortie.


3. Connectez Dart avec C à l'aide de FFI

Utilisez le package FFI de Dart pour communiquer avec la bibliothèque C. Créez un fichier jax.dart :

bazel build examples/jax_cpp:jax  
Copier après la connexion

Incluez la bibliothèque dynamique dans le répertoire de votre projet. Testez-le avec :

import 'dart:ffi';  
import 'package:dynamic_library/dynamic_library.dart';  

typedef FooCFunc = Int32 Function(Int32 bar);  
typedef FooDartFunc = int Function(int bar);  

class JAX {  
  late final DynamicLibrary dylib;  

  JAX() {  
    dylib = loadDynamicLibrary(libraryName: 'jax');  
  }  

  Function get _bar => dylib.lookupFunction<FooCFunc, FooDartFunc>('bar');  

  int bar(int foo) {  
    return _bar(foo);  
  }  
}  
Copier après la connexion

Vous verrez la sortie de la bibliothèque C dans votre console.


? Prochaines étapes

Avec cette configuration, vous pouvez :

  • Optimisez les modèles ML avec JAX et XLA.
  • Exécutez des algorithmes puissants localement.

Les cas d'utilisation potentiels incluent :

  • Algorithmes de recherche (par exemple, A*).
  • Optimisation combinatoire (par exemple, planification).
  • Traitement d'image (par exemple, détection des contours).

JAX comble le fossé entre le développement basé sur Python et les performances au niveau de la production, permettant aux ingénieurs ML de se concentrer sur les algorithmes sans se soucier du code C de bas niveau.


Nous construisons une plate-forme d'IA de pointe avec des jetons de discussion illimités et une mémoire à long terme, garantissant des interactions transparentes et contextuelles qui évoluent au fil du temps.

C'est entièrement gratuit et vous pouvez également l'essayer dans votre IDE actuel.


Running a JAX Program from Dart Using C   FFI

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

source:dev.to
Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn
Derniers articles par auteur
Tutoriels populaires
Plus>
Derniers téléchargements
Plus>
effets Web
Code source du site Web
Matériel du site Web
Modèle frontal