Heim > Technologie-Peripheriegeräte > KI > Leitfaden zum blitzschnellen Jax

Leitfaden zum blitzschnellen Jax

Jennifer Aniston
Freigeben: 2025-03-19 11:21:11
Original
636 Leute haben es durchsucht

Hey da, Python -Kollegen! Haben Sie sich jemals gewünscht, dass Ihr Numpy -Code mit Überschallgeschwindigkeit ausgeführt wird? Treffen Sie Jax!. Ihr neuer bester Freund in Ihrem maschinellen Lernen, tiefem Lernen und numerischen Computerreise. Betrachten Sie es als Numpy mit Supermächten. Es kann automatisch Gradienten verarbeiten, Ihren Code so kompilieren, dass sie mit JIT schnell ausgeführt werden, und sogar mit GPU und TPU ausgeführt werden, ohne einen Schweiß zu brechen. Egal, ob Sie neuronale Netzwerke aufbauen, wissenschaftliche Daten knirschen, Transformatormodelle optimieren oder nur versuchen, Ihre Berechnungen zu beschleunigen, Jax hat Ihren Rücken. Lassen Sie uns eintauchen und sehen, was Jax so besonders macht.

Dieser Leitfaden bietet eine detaillierte Einführung in Jax und sein Ökosystem.

Lernziele

  • Erklären Sie Jax 'Kernprinzipien und wie sie sich von Numpy unterscheiden.
  • Wenden Sie die drei wichtigsten Transformationen von JAX an, um den Python -Code zu optimieren. Umwandeln Sie Numpy -Operationen in eine effiziente JAX -Implementierung.
  • Identifizieren und beheben Sie gemeinsame Leistungsengpässe im JAX -Code. Implementieren Sie die JIT -Kompilierung korrekt und vermeiden Sie typische Fallstricke.
  • Bauen und trainieren Sie ein neuronales Netzwerk mit JAX von Grund auf neu. Implementieren Sie gemeinsame maschinelle Lernvorgänge mit dem funktionalen Ansatz von JAX.
  • Lösen Sie die Optimierungsprobleme mithilfe der automatischen Differenzierung von JAX. Führen Sie effiziente Matrixoperationen und numerische Berechnungen durch.
  • Wenden Sie effektive Debugging-Strategien für JAX-spezifische Themen an. Implementieren Sie speichereffiziente Muster für groß angelegte Berechnungen.

Dieser Artikel wurde als Teil des Data Science -Blogathons veröffentlicht.

Inhaltsverzeichnis

  • Was ist Jax?
  • Warum fällt Jax auf?
  • Erste Schritte mit Jax
  • Warum Jax lernen?
  • Essentielle JAX -Transformationen
  • Aufbau neuronaler Netzwerke mit Jax
  • Best Practice und Tipps
  • Leistungsoptimierung
  • Debugging -Strategien
  • Gemeinsame Muster und Redewendungen in Jax
  • Was kommt als nächstes?
  • Abschluss
  • Häufig gestellte Fragen

Was ist Jax?

Nach der offiziellen Dokumentation ist JAX eine Python-Bibliothek zur Beschleunigungsorientierten Array-Berechnung und -programmtransformation, die für leistungsstarke numerische Computing und maschinelles Lernen in großem Maßstab entwickelt wurde. Jax ist also im Wesentlichen Numpy in Steroiden, es kombiniert bekannte Vorgänge im Numme-Stil mit automatischer Differenzierung und Hardwarebeschleunigung. Stellen Sie sich vor, Sie holen das Beste aus drei Welten.

  • Numpys elegante Syntax- und Array -Operation
  • Pytorchähnliche automatische Differenzierungsfähigkeit
  • XLAs (beschleunigte lineare Algebra) für Hardware -Beschleunigung und Kompilierungsvorteile.

Warum fällt Jax auf?

Was Jax auszeichnet, sind seine Transformationen. Dies sind leistungsstarke Funktionen, die Ihren Python -Code ändern können:

  • JIT : Just-in-Time-Zusammenstellung für eine schnellere Ausführung
  • Grad : Automatische Differenzierung für Computergradienten
  • VMAP : automatisch Vektorisierung für die Stapelverarbeitung

Hier ist ein kurzer Blick:

 Importieren Sie Jax.numpy als JNP
von Jax Import Grad, JIT
# Definieren Sie eine einfache Funktion
@jit # beschleunigen Sie es mit Zusammenstellung
Def square_sum (x):
return jnp.sum (jnp.square (x))
# Erhalten Sie seine Gradientenfunktion automatisch
gradient_fn = grad (square_sum)
# Probieren Sie es aus
x = jnp.Array ([1.0, 2.0, 3.0])
print (f "gradienten: {gradient_fn (x)}"))
Nach dem Login kopieren

Ausgabe:

 Gradient: [2. 4. 6.]
Nach dem Login kopieren

Erste Schritte mit Jax

Im Folgenden werden wir einige Schritte befolgen, um mit Jax zu beginnen.

STEP1: Installation

Das Einrichten von JAX ist für die Nutzung von CPU-Nutzung einfach. Sie können die JAX -Dokumentation für weitere Informationen verwenden.

STEP2: Umgebung für das Projekt erstellen

Erstellen Sie eine Conda -Umgebung für Ihr Projekt

 # Erstellen Sie eine Conda Env für Jax
$ conda create -name jaxdev python = 3.11

#aktiv die env
$ conda aktivieren jaxdev

# Erstellen Sie einen Projekt -Dir -Namen JAX101
$ mkdir jax101

# Gehen Sie in die Dire
$ CD JAX101
Nach dem Login kopieren

Schritt 3: Installieren von JAX

Installieren von JAX in der neu erstellten Umgebung

 # Nur für CPU
PIP -Installation -Upgrade PIP
PIP -Installation -Upgrade "Jax"

# für GPU
PIP -Installation -Upgrade PIP
PIP -Installation -Upgrade "JAX [CUDA12]"
Nach dem Login kopieren

Jetzt sind Sie bereit, in echte Dinge einzutauchen. Bevor Sie Ihre Hände auf praktische Codierung schmutzig machen, lernen wir einige neue Konzepte. Ich werde zuerst die Konzepte erklären und dann werden wir zusammen codieren, um den praktischen Standpunkt zu verstehen.

Holen Sie sich zunächst übrigens etwas Motivation, warum lernen wir wieder eine neue Bibliothek? Ich werde diese Frage in diesem Leitfaden Schritt für Schritt so einfach wie möglich beantworten.

Warum Jax lernen?

Stellen Sie sich Jax als Elektrowerkzeug vor. Während Numpy wie eine zuverlässige Handsäge ist, ist Jax wie eine moderne elektrische Säge. Es erfordert ein bisschen mehr Schritte und Wissen, aber die Leistungsvorteile sind für intensive Berechnungsaufgaben wert.

  • Leistung : Jax Code kann erheblich schneller ausgeführt als Python- oder Numpy -Code, insbesondere bei GPU und TPUs
  • Flexibilität : Es ist nicht nur für maschinelles Lernen- Jax Excels in wissenschaftlichem Computer, Optimierung und Simulation.
  • Moderner Ansatz: JAX fördert funktionale Programmiermuster, die zu einem saubereren, wartbaren Code führen.

Im nächsten Abschnitt tauchen wir tief in Jax 'Transformation ein und beginnend mit der JIT -Zusammenstellung. Diese Transformationen geben Jax seine Supermächte, und das Verständnis ist der Schlüssel zur effektiven Nutzung von JAX.

Essentielle JAX -Transformationen

Die Transformationen von JAX unterscheiden es wirklich von den numerischen Berechnungsbibliotheken wie Numpy oder Scipy. Lassen Sie uns jeden untersuchen und sehen, wie sie Ihren Code aufladen können.

JIT- oder Just-in-Zeit-Zusammenstellung

Die Just-in-Time-Kompilierung optimiert die Codeausführung, indem Teile eines Programms zur Laufzeit und nicht vorab zusammengestellt werden.

Wie funktioniert JIT in Jax?

In Jax verwandelt Jax.jit eine Python-Funktion in eine JIT-kompilierte Version. Dekorieren einer Funktion mit @jax.jit erfasst das Ausführungsdiagramm, optimiert sie und kompiliert sie mit XLA. Die kompilierte Version führt dann aus und liefert erhebliche Beschleunigungen, insbesondere für wiederholte Funktionsaufrufe.

So können Sie es versuchen.

 Importieren Sie Jax.numpy als JNP
von Jax Import JIT
Importzeit


# Eine rechenintensive Funktion
Def Slow_function (x):
    für _ im Bereich (1000):
        x = jnp.sin (x) jnp.cos (x)
    Rückkehr x


# Die gleiche Funktion mit JIT
@jit
def fast_function (x):
    für _ im Bereich (1000):
        x = jnp.sin (x) jnp.cos (x)
    Rückkehr x
Nach dem Login kopieren

Hier ist die gleiche Funktion, einer ist nur ein einfacher Python -Kompilierungsprozess und der andere wird als JAX -Zusammenstellungsprozess verwendet. Es berechnet die 1000 Datenpunkte Summe von Sinus- und Cosinus -Funktionen. Wir werden die Leistung mit der Zeit vergleichen.

 # Performance vergleichen
x = jnp.arange (1000)

# Aufwärmjit
fast_function (x) # Erster Anruf kompiliert die Funktion

# Zeitvergleich
start = time.time ()
Slow_result = Slow_function (x)
print (f "ohne JIT: {time.time () - start: .4f} Sekunden"))

start = time.time ()
fast_result = fast_function (x)
print (f "mit jit: {time.time () - start: .4f} Sekunden"))
Nach dem Login kopieren

Das Ergebnis wird Sie erstaunen. Die JIT -Zusammenstellung ist 333 -mal schneller als die normale Zusammenstellung. Es ist, als würde man ein Fahrrad mit einem Buggati Chiron vergleichen.

Ausgabe:

 Ohne JIT: 0,0330 Sekunden
Mit JIT: 0,0010 Sekunden
Nach dem Login kopieren

JIT kann Ihnen einen superschnellen Ausführungsschub geben, aber Sie müssen ihn richtig verwenden, sonst ist es wie das Fahren von Bugatti auf einer matschigen Dorfstraße, die keine Supercar -Einrichtung bietet.

Häufige JIT -Fallstricke

JIT funktioniert am besten mit statischen Formen und Typen. Vermeiden Sie es, Python -Schleifen und -bedingungen zu verwenden, die von Array -Werten abhängen. JIT funktioniert nicht mit den dynamischen Arrays.

 # Schlecht - verwendet Python Control Flow
@jit
Def bad_function (x):
    Wenn x [0]> 0: # Dies funktioniert mit JIT nicht gut
        Rückkehr x
    return -x


# print (bad_function (Jnp.Array ([1, 2, 3])))


# Gut - verwendet Jax Control Flow
@jit
Def Good_function (x):
    Return Jnp.where (x [0]> 0, x, -x) # JAX -native Bedingung


print (Good_function (Jnp.Array ([1, 2, 3])))
Nach dem Login kopieren

Ausgabe:

Leitfaden zum blitzschnellen Jax

Das bedeutet, dass Bad_function schlecht ist, da sich JIT während der Berechnung nicht im Wert von x befand.

Ausgabe:

 [1 2 3]
Nach dem Login kopieren

Einschränkungen und Überlegungen

  • Kompilierungsaufwand: Wenn eine JIT-kompilierte Funktion zum ersten Mal ausgeführt wird, gibt es aufgrund der Zusammenstellung etwas Overhead. Die Zusammenstellungskosten können die Leistungsvorteile für kleine Funktionen oder die nur einmal genannten Funktionen überwiegen.
  • Dynamische Python -Funktionen: Jax 'JIT erfordert, dass Funktionen „statisch“ sind. Der dynamische Steuerfluss wie das Ändern von Formen oder Werten, die auf Python -Schleifen basieren, wird im kompilierten Code nicht unterstützt. JAX stellte Alternativen wie `jax.lax.cond` und` jax.lax.scan` zur Verfügung, um den dynamischen Steuerfluss zu behandeln.

Automatische Differenzierung

Die automatische Differenzierung oder Autodiff ist eine Berechnungstechnik zur Berechnung der Ableitung von Funktionen genau und effektiv. Es spielt eine entscheidende Rolle bei der Optimierung von Modellen für maschinelles Lernen, insbesondere bei der Schulung neuronaler Netzwerke, in denen Gradienten zur Aktualisierung von Modellparametern verwendet werden.

Leitfaden zum blitzschnellen Jax

Wie funktioniert die automatische Differenzierung in JAX?

Autodiff arbeitet mit der Anwendung der Kettenregel des Kalküls, um Komplexfunktionen in einfachere zu zersetzen, die Ableitung dieser Unterfunktionen zu berechnen und dann die Ergebnisse zu kombinieren. Es werden jeden Vorgang während der Funktionsausführung aufgezeichnet, um ein Rechendiagramm zu konstruieren, das dann zur automatischen Berechnung der Derivate verwendet wird.

Es gibt zwei Hauptmodi des Auto-DIFF:

  • Vorwärtsmodus: Berechnet Derivate in einem einzigen Vorwärtsgabereiten durch das Rechendiagramm, effizient für Funktionen mit einer geringen Anzahl von Parametern.
  • Reverse -Modus: Berechnet Derivate in einem einzigen Rückwärts -Durchgang durch das Rechendiagramm, das für Funktionen mit einer großen Anzahl von Parametern effizient ist.

Leitfaden zum blitzschnellen Jax

Schlüsselmerkmale bei der automatischen Differenzierung von JAX

  • Gradientenberechnung (jax.grad): `jax.grad` berechnet die Ableitung einer Scaler-Output-Funktion für ihre Eingabe. Für Funktionen mit mehreren Eingängen kann ein teilweise Derivat erhalten werden.
  • Derivat höherer Ordnung (jax.jacobian, jax.hessian): JAX unterstützt die Berechnung von Derivaten höherer Ordnung wie Jacobians und Hessains, wodurch es für die fortschrittliche Optimierung und Physiksimulation geeignet ist.
  • Komposition mit anderer JAX -Transformation: Autodiff in JAX integriert nahtlos in andere Transformationen wie `jax.jit` und` jax.vmap`, die eine effiziente und skalierbare Berechnung ermöglichen.
  • Reverse-Mode-Differenzierung (Backpropagation): Das Auto-DIFF von JAX verwendet die Reverse-Mode-Differenzierung für Scaler-Output-Funktionen, was für Deep-Lern-Aufgaben hochwirksam ist.
 Importieren Sie Jax.numpy als JNP
von JAX Import Grad, Value_and_grad


# Definieren Sie eine einfache neuronale Netzwerkschicht
Def Layer (Params, x):
    Gewicht, Bias = Parames
    Return Jnp.dot (x, Gewicht) Vorspannung


# Definieren Sie eine skalarwerte Verlustfunktion
Def LUST_FN (Params, x):
    Ausgabe = Schicht (Params, x)
    Return Jnp.sum (Ausgabe) # Reduziert auf einen Skalar


# Holen Sie sich sowohl den Ausgang als auch den Gradienten
Layer_grad = Grad (LUST_FN, ARGNUMS = 0) # Gradient in Bezug auf Params
layer_value_and_grad = value_and_grad (LUST_FN, ARGNUMS = 0) # Beide Wert und Gradienten

# Beispielnutzung
key = jax.random.prngkey (0)
x = jax.random.normal (Schlüssel, (3, 4))
Gewicht = jax.random.normal (Schlüssel, (4, 2))
bias = jax.random.normal (Schlüssel, (2,))

# Fassungsgradienten berechnen
Grads = Layer_grad ((Gewicht, Voreingenommenheit), x)
Ausgabe, Grads = layer_value_and_grad ((Gewicht, Vorspannung), x)

# Mehrere Derivate sind einfach
toppy_grad = grad (Grad (jnp.sin))
x = jnp.Array (2.0)
print (f "zweite Ableitung der Sünde bei x = 2: {toppy_grad (x)}")
Nach dem Login kopieren

Ausgabe:

 Zweite Derivate der Sünde bei x = 2: -0.9092974066734314
Nach dem Login kopieren

Effektivität in Jax

  • Effizienz: Die automatische Differenzierung von JAX ist aufgrund seiner Integration mit XLA hocheffizient und ermöglicht eine Optimierung auf Maschinencodeebene.
  • Komposition : Die Fähigkeit, verschiedene Transformationen zu kombinieren, macht JAX zu einem leistungsstarken Werkzeug zum Aufbau komplexer Pipelines und neuronalen Netzwerke wie CNN, RNN und Transformers.
  • Benutzerfreundlichkeit: Die JAX -Syntax für Autodiff ist einfach und intuitiv, sodass Benutzer den Gradienten berechnen können, ohne sich mit den Details von XLA und komplexen Bibliotheks -APIs zu befassen.

JAX Vectorize Mapping

In JAX ist "VMAP" eine leistungsstarke Funktion, die die Berechnungen automatisch vektorisiert, sodass Sie eine Funktion über Datenstapel anwenden können, ohne manuell Schleifen zu schreiben. Es ordnet eine Funktion über eine Array -Achse (oder mehrere Achsen) ab und bewertet sie parallel, was zu erheblichen Leistungsverbesserungen führen kann.

Wie funktioniert VMAP in Jax?

Die VMAP -Funktion automatisiert den Prozess der Anwendung einer Funktion auf jedes Element entlang einer angegebenen Achse eines Eingangsarrays, während die Effizienz der Berechnung erhalten bleibt. Es transformiert die angegebene Funktion, um angegebene Eingänge zu akzeptieren und die Berechnung auf vektorisierte Weise auszuführen.

Anstatt explizite Schleifen zu verwenden, ermöglicht VMAP, dass Operationen parallel durch Vektorisierung über eine Eingangsachse durchgeführt werden. Dies nutzt die Fähigkeit der Hardware, SIMD-Operationen (einzelne Anweisungen, mehrere Daten) durchzuführen, was zu erheblichen Beschleunigungen führen kann.

Schlüsselmerkmale von VMAP

  • Automatische Vektorisierung: VAMP automatisiert die Batching von Berechnungen, wodurch es einfach ist, den Parallelcode über Stapelabmessungen zu besitzen, ohne die ursprüngliche Funktionslogik zu ändern.
  • Kompositionsfähigkeit mit anderen Transformationen: Es funktioniert nahtlos mit anderen JAX-Transformationen wie JAX.grad für Differenzierung und Jax.jit für die Just-in-Time-Kompilierung, die einen hochoptimierten und flexiblen Code ermöglicht.
  • Handhabung mehrerer Batch-Abmessungen: VMAP unterstützt die Zuordnung über mehrere Eingangsarrays oder Achsen, wodurch es für verschiedene Anwendungsfälle wie die Verarbeitung mehrdimensionaler Daten oder mehrere Variablen gleichzeitig vielseitig verarbeitet wird.
 Importieren Sie Jax.numpy als JNP
von JAX Import VMAP


# Eine Funktion, die bei einzelnen Eingängen funktioniert
Def Single_input_fn (x):
    return jnp.sin (x) jnp.cos (x)


# Vectorisieren Sie es, um an Chargen zu arbeiten
batch_fn = vmap (Single_input_fn)

# Performance vergleichen
x = jnp.arange (1000)

# Ohne VMAP (unter Verwendung eines Listenverständnisses)
result1 = jnp.Array ([Single_input_fn (xi) für xi in x])

# Mit VMAP
result2 = batch_fn (x) # viel schneller!


# Mehrere Argumente vektorisieren
Def Two_input_fn (x, y):
    return x * jnp.sin (y)


# Über beide Eingaben vektorisieren
vectorized_fn = vmap (two_input_fn, in_axes = (0, 0))

# Oder vectorisieren Sie nur den ersten Eingang
teilweise_Vectorized_fn = vmap (Two_input_fn, in_axes = (0, keine))


# Druck
print (result1.shape)
drucken (result2.shape)
print (teilweise_Vectorized_fn (x, y) .shape)
Nach dem Login kopieren

Ausgabe:

 (1000,)
(1000,)
(1000,3)
Nach dem Login kopieren

Wirksamkeit von VMAP in JAX

  • Leistungsverbesserungen: Durch vektorisierende Berechnungen kann VMAP die Ausführung erheblich beschleunigen, indem die parallelen Verarbeitungsfunktionen moderner Hardware wie GPUs und TPUs (Tensor -Verarbeitungseinheiten) eingesetzt werden.
  • Cleaner -Code: Er ermöglicht einen prägnanteren und lesbaren Code, indem die Bedürfnisse manueller Schleifen beseitigt werden.
  • Die Kompatibilität mit JAX und Autodiff: VMAP kann mit automatischer Differenzierung (JAX.grad) kombiniert werden, wodurch die effiziente Berechnung von Derivaten über Datenstapel ermöglicht werden kann.

Wann zu jeder Transformation verwendet werden

Verwenden Sie @jit wann:

  • Ihre Funktion wird mehrmals mit ähnlichen Eingangsformen aufgerufen.
  • Die Funktion enthält schwere numerische Berechnungen.

Verwenden Sie Grad, wann:

  • Sie benötigen Ableitungen zur Optimierung.
  • Implementierung von Algorithmen für maschinelles Lernen
  • Differentialgleichungen für Simulationen lösen

Verwenden Sie VMAP, wenn:

  • Verarbeitungsstapel von Daten mit.
  • Parallelenberechnungen
  • Vermeiden explizite Schleifen

Matrixoperationen und lineare Algebra mit JAX

JAX bietet umfassende Unterstützung für Matrixoperationen und lineare Algebra, wodurch es für wissenschaftliche Computer-, maschinelles Lernen und numerische Optimierungsaufgaben geeignet ist. Die linearen Algebra-Funktionen von JAX ähneln denen in Bibliotheken wie Numpy, aber mit zusätzlichen Funktionen wie automatischer Differenzierung und Just-in-Time-Zusammenstellung für eine optimierte Leistung.

Matrixaddition und Subtraktion

Diese Operationen werden elementzielle Matrizen derselben Form durchgeführt.

 # 1 Matrix -Addition und Subtraktion:

Importieren Sie Jax.numpy als JNP

A = jnp.array ([[1, 2], [3, 4]])
B = Jnp.Array ([5, 6], [7, 8]])

# MATRIX -Addition
C = ab
# Matrix -Subtraktion
D = a - b

print (f "matrix a: \ n {a}"))
print ("=========================="))
print (f "matrix b: \ n {b}"))
print ("=========================="))
print (f "Matrix -Ehemann von AB: \ n {c}")
print ("=========================="))
print (f "Matrix -Substraktion von AB: \ n {d}")
Nach dem Login kopieren

Ausgabe:

Leitfaden zum blitzschnellen Jax

Matrixmultiplikation

JAX unterstützt sowohl die elemente multiplikation als auch die produktbasierte Matrixmultiplikation von DOR.

 # Elementweise Multiplikation
E = a * b

# Matrix -Multiplikation (Punktprodukt)
F = Jnp.dot (a, b)

print (f "matrix a: \ n {a}"))
print ("=========================="))
print (f "matrix b: \ n {b}"))
print ("=========================="))
print (f "Elementwise Multiplikation von a*b: \ n {e}"))
print ("=========================="))
print (f "matrix multiplication von a*b: \ n {f}"))
Nach dem Login kopieren

Ausgabe:

Leitfaden zum blitzschnellen Jax

Matrixentransponung

Die Transponierung einer Matrix kann unter Verwendung von `Jnp.transpsis ()` erhalten werden

 # Matric Transponieren
G = jnp.transpsis (a)

print (f "matrix a: \ n {a}"))
print ("=========================="))
print (f "matrix transponieren von a: \ n {g}")
Nach dem Login kopieren

Ausgabe:

Leitfaden zum blitzschnellen Jax

Matrix inverse

JAX bietet Funktion für die Matrixinversion mit `jnp.linalg.inv ()`

 # Matric -Inversion
H = jnp.linalg.inv (a)

print (f "matrix a: \ n {a}"))
print ("=========================="))
print (f "Matrixinversion von a: \ n {h}")
Nach dem Login kopieren

Ausgabe:

Leitfaden zum blitzschnellen Jax

Matrixdeterminante

Die Determinante einer Matrix kann mit `jnp.linalg.det ()` berechnet werden.

 # Matrix -Determinante
det_a = jnp.linalg.det (a)

print (f "matrix a: \ n {a}"))
print ("=========================="))
print (f "matrix determinante von a: \ n {det_a}"))
Nach dem Login kopieren

Ausgabe:

Leitfaden zum blitzschnellen Jax

Matrix -Eigenwerte und Eigenvektoren

Sie können die Eigenwerte und Eigenvektoren einer Matrix unter Verwendung von `jnp.linalg.Egh ()` berechnen

 # Eigenwerte und Eigenvektoren
Importieren Sie Jax.numpy als JNP

A = jnp.array ([[1, 2], [3, 4]])
Eigenwerte, Eigenvektoren = jnp.linalg.Egh (a)

print (f "matrix a: \ n {a}"))
print ("=========================="))
print (f "Eigenwerte von a: \ n {Eigenvalues}")
print ("=========================="))
print (f "Eigenvektoren von a: \ n {Eigenvektoren}")
Nach dem Login kopieren

Ausgabe:

Leitfaden zum blitzschnellen Jax

Matrix Singularwertabbauung

SVD wird über `jnp.linalg.svd` unterstützt, die bei der Reduktion und der Matrixfaktorisierung von Dimensionalität nützlich sind.

 # Singular Value Decomposition (SVD)

Importieren Sie Jax.numpy als JNP

A = jnp.array ([[1, 2], [3, 4]])
U, s, v = jnp.linalg.svd (a)

print (f "matrix a: \ n {a}"))
print ("=========================="))
print (f "matrix u: \ n {u}"))
print ("=========================="))
print (f "matrix s: \ n {s}"))
print ("=========================="))
print (f "matrix v: \ n {v}"))
Nach dem Login kopieren

Ausgabe:

Leitfaden zum blitzschnellen Jax

Lösungssystem linearer Gleichungen lösen

Um ein System der linearen Gleichung AX = B zu lösen, verwenden wir `Jnp.Linalg.Solve ()`, wobei a eine quadratische Matrix ist und B ein Vektor oder eine Matrix derselben Anzahl von Zeilen ist.

 # Lösungssystem linearer Gleichungen lösen
Importieren Sie Jax.numpy als JNP

A = jnp.Array ([[2.0, 1.0], [1.0, 3.0]])
B = Jnp.Array ([5.0, 6.0])
x = jnp.linalg.Solve (a, b)

print (f "Wert von x: {x}")
Nach dem Login kopieren

Ausgabe:

 Wert von x: [1,8 1,4]
Nach dem Login kopieren

Berechnung des Gradienten einer Matrixfunktion

Mit der automatischen Differenzierung von JAX können Sie den Gradienten einer Skalarfunktion in Bezug auf eine Matrix berechnen.
Wir werden den Gradienten der folgenden Funktion und die Werte von x berechnen

Funktion

Leitfaden zum blitzschnellen Jax

 # Berechnung des Gradienten einer Matrixfunktion
Jax importieren
Importieren Sie Jax.numpy als JNP


Def matrix_function (x):
    return jnp.sum (jnp.sin (x) x ** 2)


# Berechnen Sie den Abschluss der Funktion
Grad_f = jax.grad (matrix_function)

X = jnp.Array ([[1.0, 2.0], [3.0, 4.0]])
Gradient = Grad_f (x)

print (f "matrix x: \ n {x}"))
print ("=========================="))
print (f "Gradient von matrix_function: \ n {Gradient}"))
Nach dem Login kopieren

Ausgabe:

Leitfaden zum blitzschnellen Jax

Diese nützlichste Funktion von JAX, die in numerischen Computing, maschinellem Lernen und Physikberechnung verwendet werden. Es gibt noch viele weitere, die Sie erkunden können.

Wissenschaftliches Computing mit Jax

JAX 'leistungsstarke Bibliotheken für wissenschaftliches Computing, JAX, eignet sich am besten für wissenschaftliches Computing für seine Vorabfunktionen wie JIT-Kompilierung, automatische Differenzierung, Vektorisierung, Parallelisierung und GPU-TPU-Beschleunigung. Die Fähigkeit von JAX, Hochleistungs -Computing zu unterstützen, sorgt für eine breite Palette wissenschaftlicher Anwendungen, einschließlich Physiksimulationen, maschinelles Lernen, Optimierung und numerische Analyse.

Wir werden in diesem Abschnitt ein Optimierungsproblem untersuchen.

Optimierungsprobleme

Lassen Sie uns die folgenden Schritte der Optimierungsprobleme durchlaufen:

SCHRITT1: Definieren Sie die Funktion, um (oder das Problem) zu minimieren (oder das Problem)

 # Definieren Sie eine Funktion zum Minimieren (z. B. Rosenbrock -Funktion)

@jit

Def Rosenbrock (x):

Rückgabesumme (100,0 * (x [1:] - x [: - 1] ** 2.0) ** 2.0 (1 - x [: - 1]) ** 2.0)
Nach dem Login kopieren

Hier ist die Rosenbrock -Funktion definiert, was ein häufiges Testerproblem bei der Optimierung ist. Die Funktion nimmt ein Array X als Eingabe an und berechnet eine Valie, die darstellt, wie weit X vom globalen Minimum der Funktion liegt. Der @Jit-Dekorateur wird verwendet, um die Zusammenstellung der Jut-in-Time zu aktivieren, die die Berechnung beschleunigt, indem die Funktion zum effizienten Ausführen auf CPUs und GPUs ausgeführt wird.

STEP2: Implementierung von Gradientenabfällen Schritt

 # Optimierung des Verlaufsabstiegs

@jit

Def Gradient_Descent_Step (x, Learning_rate):

Return X - Learning_Rate * Grad (Rosenbrock) (x)
Nach dem Login kopieren

Diese Funktion führt einen einzelnen Schritt der Gradientenabstiegsoptimierung durch. Der Gradient der Rosenbrock -Funktion wird unter Verwendung von Grad (Rosenbrock) (x) berechnet, was das Ableitungsbereich in Bezug auf x liefert. Der neue Wert von X wird durch Subtraktion der von einem Learning_Rate skalierten Gradienten aktualisiert. Der @jit tut das gleiche wie zuvor.

Schritt 3: Ausführen der Optimierungsschleife

 # Optimieren
X = Jnp.Array ([0,0, 0,0]) # Startpunkt

Learning_rate = 0,001

für i in Range (2000):

x = gradient_descent_step (x, lern_rate)

Wenn ich % 100 == 0:

print (f "Schritt {i}, Wert: {Rosenbrock (x): 4f}")
Nach dem Login kopieren

Die Optimierungsschleife initialisiert den Startpunkt X und führt 1000 Iterationen von Gradientenabstiegungen durch. In jeder Iteration aktualisiert die Funktion gradient_descent_step basierend auf dem aktuellen Gradienten. Alle 100 Schritte werden die aktuelle Schrittnummer und der Wert der Rosenbrock -Funktion bei x gedruckt, was den Fortschritt der Optimierung darstellt.

Ausgabe:

Leitfaden zum blitzschnellen Jax

Lösen Sie das reale Physikproblem mit JAX

Wir werden ein physikalisches System der Bewegung eines gedämpften harmonischen Oszillators simulieren, der Dinge wie ein Massenbestandsystem mit Reibung, Stoßdämpfer in Fahrzeugen oder Schwingung in elektrischen Schaltungen modelliert. Ist es nicht schön? Lass es uns tun.

STEP1: Parameter Definition

 Jax importieren
Importieren Sie Jax.numpy als JNP


# Parameter definieren
Masse = 1,0 # Masse des Objekts (kg)
Dämpfung = 0,1 # Dämpfungskoeffizient (kg/s)
Spring_Constant = 1.0 # Federkonstante (N/M)

# Zeitschritt und Gesamtzeit definieren
dt = 0,01 # Zeitschritt (en)
num_steps = 3000 # Anzahl der Schritte
Nach dem Login kopieren

Die Masse, der Dämpfungskoeffizient und die Federkonstante sind definiert. Diese bestimmen die physikalischen Eigenschaften des gedämpften Harmonischen Oszillators.

Schritt 2: ODE -Definition

 # Definieren Sie das System der ODES
Def damped_harmonic_ocillator (Zustand, T):
    "" "Berechnen Sie die Derivate für einen gedämpften harmonischen Oszillator.

    Zustand: Array mit Position und Geschwindigkeit [x, v]
    T: Zeit (in diesem autonomen System nicht verwendet)
    "" "
    x, v = Zustand
    dxdt = v
    dvdt = -damping / mass * v - spring_constant / mass * x
    return jnp.array ([dxdt, dvdt])
Nach dem Login kopieren

Die gedämpfte harmonische Oszillatorfunktion definiert die Derivate der Position und Geschwindigkeit des Oszillators, die das dynamische System darstellen.

Schritt 3: Eulers Methode

 # Lösen Sie die ODE mithilfe der Eulers Methode
def euler_step (Zustand, t, dt):
    "" "Einen Schritt von Eulers Methode durchführen." "" "
    Derivate = dampiert_harmonic_ocillator (Zustand, t)
    Return State Derivate * dt
Nach dem Login kopieren

Eine einfache numerische Methode wird verwendet, um die ODE zu lösen. Auf der Grundlage des aktuellen Zustands und der Derivat nähert es den Staat zum nächsten Zeitschritt.

Schritt 4: Zeitentwicklungsschleifen

 # Anfangszustand: [Position, Geschwindigkeit]
initial_state = jnp.Array ([1.0, 0,0]) # Beginnen Sie mit der Masse bei x = 1, v = 0

# Zeitentwicklung
States = [initial_state]
Zeit = 0,0
Für Schritt in Reichweite (num_steps):
    Next_State = Euler_Step (Zustände [-1], Zeit, DT)
    States.Append (Next_state)
    Zeit = dt

# Konvertieren Sie die Liste der Zustände in ein JAX -Array zur Analyse
Zustände = jnp.stack (Zustände)
Nach dem Login kopieren

Die Schleifen durchlaufen die angegebenen Zeitschritte und aktualisiert den Status bei jedem Schritt mit der Euler -Methode.

Ausgabe:

Leitfaden zum blitzschnellen Jax

Schritt 5: Darstellung der Ergebnisse

Schließlich können wir die Ergebnisse zeichnen, um das Verhalten des gedämpften harmonischen Oszillators zu visualisieren.

 # Die Ergebnisse aufnehmen
matplotlib.pyplot als pLT importieren

plt.style.use ("ggplot")

Positionen = Zustände [: 0]
Geschwindigkeiten = Zustände [:, 1]
time_points = jnp.arange (0, (num_steps 1) * dt, dt)

Plt.Figure (AbbSize = (12, 6))
PLT.SUBPLOT (2, 1, 1)
PLT.PLOT (TIME_POINTS, Positionen, label = "Position")
Plt.xlabel ("Zeit (s)")
PLT.YLABEL ("Position (m)")
Plt.Legend ()

PLT.SUBPLOT (2, 1, 2)
PLT.PLOT (TIME_POINTS, Geschwindigkeiten, Label = "Velocity", Color = "Orange")
Plt.xlabel ("Zeit (s)")
PLT.YLABEL ("Geschwindigkeit (m/s)")
Plt.Legend ()

Plt.TIGHT_LAYOUT ()
Plt.Show ()
Nach dem Login kopieren

Ausgabe:

Leitfaden zum blitzschnellen Jax

Ich weiß, dass Sie gespannt sind, wie das neuronale Netzwerk mit Jax aufgebaut werden kann. Lassen Sie uns also tief hineintauchen.

Hier können Sie sehen, dass die Werte allmählich minimiert wurden.

Aufbau neuronaler Netzwerke mit Jax

Jax ist eine leistungsstarke Bibliothek, die das numerische Computing mit hohem Performance mit der Verwendung von Numpy-ähnlicher Syntax kombiniert. In diesem Abschnitt führt Sie durch den Prozess der Erstellung eines neuronalen Netzwerks mit JAX und nutzt seine erweiterten Funktionen für die automatische Differenzierung und die Just-in-Time-Zusammenstellung, um die Leistung zu optimieren.

STEP1: Bibliotheken importieren

Bevor wir in das Aufbau unseres neuronalen Netzwerks eintauchen, müssen wir die erforderlichen Bibliotheken importieren. JAX bietet eine Reihe von Tools zum Erstellen effizienter numerische Berechnungen, während zusätzliche Bibliotheken die Optimierung und Visualisierung unserer Ergebnisse unterstützen.

 Jax importieren
Importieren Sie Jax.numpy als JNP
von Jax Import Grad, JIT
von jax.random import prngkey, normal
Importieren Sie die Optimierungsbibliothek von Optax # JAX
matplotlib.pyplot als pLT importieren
Nach dem Login kopieren

Schritt 2: Erstellen der Modellschichten

Die Erstellung effektiver Modellschichten ist entscheidend für die Definition der Architektur unseres neuronalen Netzwerks. In diesem Schritt werden wir die Parameter für unsere dichten Schichten initialisieren und sicherstellen, dass unser Modell mit gut definierten Gewichten und Verzerrungen für ein effektives Lernen beginnt.

 Def init_layer_params (Schlüssel, n_in, n_out):
    "" Initialisieren Sie Parameter für eine einzige dichte Schicht "" "
    KEY_W, KEY_B = JAX.RANDOM.SPLIT (Schlüssel)
    # Die Initialisierung
    W = Normal (KEY_W, (N_IN, N_OUT) * JNP.SQRT (2.0 / N_IN)  
    B = Normal (KEY_B, (n_out,)) * 0.1
    Rückkehr (W, B)
    
Def Relu (x):
    "" "Relu -Aktivierungsfunktion" "" "
    Return Jnp.maximum (0, x)
    
Nach dem Login kopieren
  • Initialisierungsfunktion : Init_Layer_params initialisiert Gewichte (W) und Verzerrungen (b) für dichte Schichten unter Verwendung der Initialisierung für Gewicht und einem geringen Wert für Verzerrungen. Er oder Kaiming -He -Initialisierung funktioniert besser für Ebenen mit Relu -Aktivierungsfunktionen. Es gibt andere populäre Initialisierungsmethoden wie die Xavier -Initialisierung, die für Schichten mit Sigmoidaktivierung besser funktionieren.
  • Aktivierungsfunktion: Die Relu -Funktion wendet die Relu -Aktivierungsfunktion auf die Eingänge an, die negative Werte auf Null setzen.

Schritt 3: Definieren des Vorwärtspasses

Der Vorwärtspass ist der Eckpfeiler eines neuronalen Netzwerks, da er bestimmt, wie Eingabedaten über das Netzwerk fließen, um eine Ausgabe zu erzeugen. Hier definieren wir eine Methode zur Berechnung der Ausgabe unseres Modells, indem wir Transformationen auf die Eingabedaten über die initialisierten Ebenen anwenden.

 Def Forward (Params, x):
    "" Vorwärtspass für ein zweischichtiges neuronales Netzwerk "" ""
    (W1, B1), (W2, B2) = Params
    # Erste Schicht
    H1 = Relu (jnp.dot (x, w1) b1)
    # Ausgangsschicht
    logits = jnp.dot (h1, w2) b2
    Logits zurückgeben
    
Nach dem Login kopieren
  • Vorwärtspass: Forward führt einen Vorwärtspass durch ein zweischichtiges neuronales Netzwerk durch, wobei die Ausgabe (Logits) durch Anwenden einer linearen Transformation gefolgt von Relu und anderen linearen Transformationen berechnet wird.

S TEP4: Definieren der Verlustfunktion

Eine genau definierte Verlustfunktion ist für die Leitung der Ausbildung unseres Modells unerlässlich. In diesem Schritt werden wir die MSE -Verlustfunktion (Mean Squared Fehler) implementieren, die misst, wie gut die vorhergesagten Ausgänge mit den Zielwerten übereinstimmen, sodass das Modell effektiv lernen kann.

 Def LUST_FN (Params, x, y):
    "" Mean quadratische Fehlerverlust "" "
    Pred = vorwärts (Params, x)
    return jnp.mean ((pred - y) ** 2)
Nach dem Login kopieren
  • Verlustfunktion: LUST_FN berechnet den MSE -Verlust des mittleren Quadratfehlers (MSE) zwischen den vorhergesagten Logits und den Zieletiketten (Y).

Schritt 5: Modellinitialisierung

Mit unserer definierten Modellarchitektur und Verlustfunktion wenden wir uns nun der Modellinitialisierung zu. In diesem Schritt werden die Parameter unseres neuronalen Netzwerks eingerichtet, um sicherzustellen, dass jede Schicht bereit ist, den Trainingsprozess mit zufälligen, aber angemessen skalierten Gewichten und Verzerrungen zu beginnen.

 Def init_model (rng_key, input_dim, hidden_dim, output_dim):
    key1, key2 = jax.random.split (rng_key)
    params = [
        init_layer_params (key1, input_dim, hidden_dim),
        init_layer_params (key2, hidden_dim, output_dim),
    ]
    Rückgabeparameter
    
Nach dem Login kopieren
  • Modellinitialisierung: Init_Model initialisiert die Gewichte und Verzerrungen für beide Schichten der neuronalen Netze. Es wird zwei separate Zufallsschlüssel für jede Schicht verwendet; Initialisierung von Parameter.

Schritt6: Trainingsschritt

Das Training eines neuronalen Netzwerks beinhaltet iterative Aktualisierungen an seinen Parametern basierend auf den berechneten Gradienten der Verlustfunktion. In diesem Schritt werden wir eine Trainingsfunktion implementieren, die diese Aktualisierungen effizient anwendet und unser Modell aus den Daten über mehrere Epochen lernen kann.

 @jit
Def Train_step (Params, opt_state, x_batch, y_batch):
    Verlust, Grads = jax.value_and_grad (LUST_FN) (Params, x_Batch, y_batch)
    Updates, opt_state = optimizer.update (Absolventen, opt_state)
    params = optax.apply_updates (params, updates)
    Rückgabeparameter, opt_state, Verlust
Nach dem Login kopieren
  • Trainingsschritt: Die Funktion "Train_step" führt ein einzelnes Update für Gradientenabsenkungen aus.
  • Es berechnet den Verlust und die Gradienten mit value_and_grad, wodurch sowohl die Funktionswerte als auch andere Gradienten berechnet werden.
  • Die Optimierer -Aktualisierungen werden berechnet und die Modellparameter entsprechend aktualisiert.
  • Das ist JIT-kompiliert für die Leistung.

Schritt 7: Daten- und Trainingsschleife

Um unser Modell effektiv zu trainieren, müssen wir geeignete Daten generieren und eine Trainingsschleife implementieren. In diesem Abschnitt wird das Erstellen von synthetischen Daten für unser Beispiel und die Verwaltung des Trainingsprozesses über mehrere Chargen und Epochen erstellt.

 # Generieren Sie einige Beispieldaten
key = prngkey (0)
x_data = normal (Schlüssel, (1000, 10)) # 1000 Proben, 10 Merkmale
y_data = jnp.sum (x_data ** 2, axis = 1, keepdims = true) # einfache nichtlineare Funktion

# Modell und Optimierer initialisieren
params = init_model (taste, input_dim = 10, hidden_dim = 32, output_dim = 1)
Optimizer = optax.adam (Learning_Rate = 0,001)
opt_state = optimizer.init (params)

# Trainingsschleife
batch_size = 32
num_epochs = 100
num_batches = x_data.shape [0] // batch_size

# Arrays zum Speichern von Epoche und Verlustwerten
epoch_array = []
LUST_Array = []

Für Epoche im Bereich (num_epochs):
    epoch_loss = 0,0
    Für Batch in Range (num_batches):
        idx = jax.random.permutation (Schlüssel, batch_size)
        x_batch = x_data [idx]
        y_batch = y_data [idx]
        Params, opt_state, loss = train_step (Params, opt_state, x_batch, y_batch)
        epoch_loss = Verlust

    # Speichern Sie den durchschnittlichen Verlust für die Epoche
    avg_loss = epoch_loss / num_batches
    epoch_array.append (Epoche)
    LUST_Array.Append (avg_loss)

    Wenn epoch % 10 == 0:
        print (f "epoch {epoch}, Verlust: {avg_loss: .4f}")
Nach dem Login kopieren
  • Datenerzeugung : Zufällige Trainingsdaten (x_data) und entsprechende Werte von Target (Y_Data) werden erstellt. Modell- und Optimierer -Initialisierung: Die Modellparameter und der Optimiererzustand werden initialisiert.
  • Trainingsschleife: Die Netzwerke werden unter Verwendung von Mini-Batch-Gradientenabstiegungen über eine bestimmte Anzahl von Epochen geschult.
  • Trainingsschleifen iterieren über Chargen und führen Gradientenaktualisierungen mithilfe der Funktion "train_step) durch. Der durchschnittliche Verlust pro Epoche wird berechnet und gespeichert. Es druckt die Epochenzahl und den durchschnittlichen Verlust.

Schritt8: Darstellung der Ergebnisse

Die Visualisierung der Trainingsergebnisse ist der Schlüssel zum Verständnis der Leistung unseres neuronalen Netzwerks. In diesem Schritt werden wir den Trainingsverlust über Epochen planen, um zu beobachten, wie gut das Modell lernt und potenzielle Probleme im Trainingsprozess identifiziert.

 # Die Ergebnisse zeichnen
PLT.PLOT (epoch_array, Loss_array, Label = "Trainingsverlust")
Plt.xlabel ("Epoche")
Plt.ylabel ("Verlust")
PLT.TITLE ("Trainingsverlust über Epochen")
Plt.Legend ()
Plt.Show ()
Nach dem Login kopieren

Diese Beispiele zeigen, wie JAX eine hohe Leistung mit sauberem, lesbarem Code kombiniert. Der von JAX geförderte funktionale Programmierstil erleichtert es einfach, Operationen zu komponieren und Transformationen anzuwenden.

Ausgabe:

Leitfaden zum blitzschnellen Jax

Handlung:

Leitfaden zum blitzschnellen Jax

Diese Beispiele zeigen, wie JAX eine hohe Leistung mit sauberem, lesbarem Code kombiniert. Der von JAX geförderte funktionale Programmierstil erleichtert es einfach, Operationen zu komponieren und Transformationen anzuwenden.

Best Practice und Tipps

In building neural networks, adhering to best practices can significantly enhance performance and maintainability. This section will discuss various strategies and tips for optimizing your code and improving the overall efficiency of your JAX-based models.

Performance Optimization

Optimizing performance is essential when working with JAX, as it enables us to fully leverage its capabilities. Here, we will explore different techniques for improving the efficiency of our JAX functions, ensuring that our models run as quickly as possible without sacrificing readability.

JIT Compilation Best Practices

Just-In-Time (JIT) compilation is one of the standout features of JAX, enabling faster execution by compiling functions at runtime. This section will outline best practices for effectively using JIT compilation, helping you avoid common pitfalls and maximize the performance of your code.

Bad Function

 import jax
import jax.numpy as jnp
from jax import jit
from jax import lax


# BAD: Dynamic Python control flow inside JIT
@jit
def bad_function(x, n):
    for i in range(n): # Python loop - will be unrolled
        x = x 1
    return x
    
    
print("===========================")
# print(bad_function(1, 1000)) # does not work
    
Nach dem Login kopieren

This function uses a standard Python loop to iterate n times, incrementing the of x by 1 on each iteration. When compiled with jit, JAX unrolls the loop, which can be inefficient, especially for large n. This approach does not fully leverage JAX's capabilities for performance.

Good Function

 # GOOD: Use JAX-native operations
@jit
def good_function(x, n):
    return xn # Vectorized operation


print("===========================")
print(good_function(1, 1000))
Nach dem Login kopieren

This function does the same operation, but it uses a vectorized operation (xn) instead of a loop. This approach is much more efficient because JAX can better optimize the computation when expressed as a single vectorized operation.

Best Function

 # BETTER: Use scan for loops


@jit
def best_function(x, n):
    def body_fun(i, val):
        return val 1

    return lax.fori_loop(0, n, body_fun, x)


print("===========================")
print(best_function(1, 1000))
Nach dem Login kopieren

This approach uses `jax.lax.fori_loop`, which is a JAX-native way to implement loops efficiently. The `lax.fori_loop` performs the same increment operation as the previous function, but it does so using a compiled loop structure. The body_fn function defines the operation for each iteration, and `lax.fori_loop` executes it from o to n. This method is more efficient than unrolling loops and is especially suitable for cases where the number of iterations isn't known ahead of time.

Ausgabe :

 ===========================
===========================
1001
===========================
1001
Nach dem Login kopieren

The code demonstrates different approaches to handling loops and control flow within JAX's jit-complied functions.

Speicherverwaltung

Efficient memory management is crucial in any computational framework, especially when dealing with large datasets or complex models. This section will discuss common pitfalls in memory allocation and provide strategies for optimizing memory usage in JAX.

Inefficient Memory Management

 # BAD: Creating large temporary arrays
@jit
def inefficient_function(x):
    temp1 = jnp.power(x, 2) # Temporary array
    temp2 = jnp.sin(temp1) # Another temporary
    return jnp.sum(temp2)
Nach dem Login kopieren

inefficient_function(x): This function creates multiple intermediate arrays, temp1, temp1 and finally the sum of the elements in temp2. Creating these temporary arrays can be inefficient because each step allocates memory and incurs computational overhead, leading to slower execution and higher memory usage.

Efficient Memory Management

 # GOOD: Combining operations
@jit
def efficient_function(x):
    return jnp.sum(jnp.sin(jnp.power(x, 2))) # Single operation
Nach dem Login kopieren

This version combines all operations into a single line of code. It computes the sine of squared elements of x directly and sums the results. By combining the operation, it avoids creating intermediate arrays, reducing memory footprints and improving performance.

Test Code

 x = jnp.array([1, 2, 3])
print(x)
print(inefficient_function(x))
print(efficient_function(x))
Nach dem Login kopieren

Ausgabe:

 [1 2 3]
0.49678695
0.49678695
Nach dem Login kopieren

The efficient version leverages JAX's ability to optimize the computation graph, making the code faster and more memory-efficient by minimizing temporary array creation.

Debugging Strategies

Debugging is an essential part of the development process, especially in complex numerical computations. In this section, we will discuss effective debugging strategies specific to JAX, enabling you to identify and resolve issues quickly.

Using print inside JIT for Debugging

The code shows techniques for debugging within JAX, particularly when using JIT-compiled functions.

 import jax.numpy as jnp
from jax import debug


@jit
def debug_function(x):
    # Use debug.print instead of print inside JIT
    debug.print("Shape of x: {}", x.shape)
    y = jnp.sum(x)
    debug.print("Sum: {}", y)
    return y
Nach dem Login kopieren
 # For more complex debugging, break out of JIT
def debug_values(x):
    print("Input:", x)
    result = debug_function(x)
    print("Output:", result)
    return result
    
Nach dem Login kopieren
  • debug_function(x): This function shows how to use debug.print() for debugging inside a jit compiled function. In JAX, regular Python print statements are not allowed inside JIT due to compilation restrictions, so debug.print() is used instead.
  • It prints the shape of the input array x using debug.print()
  • After computing the sum of the elements of x, it prints the resulting sum using debug.print()
  • Finally, the function returns the computed sum y.
  • debug_values(x) function serves as a higher-level debugging approach, breaking out of the JIT context for more complex debugging. It first prints the inputs x using regular print statement. Then calls debug_function(x) to compute the result and finally prints the output before returning the results.

Ausgabe:

 print("===========================")
print(debug_function(jnp.array([1, 2, 3])))
print("===========================")
print(debug_values(jnp.array([1, 2, 3])))
Nach dem Login kopieren

Leitfaden zum blitzschnellen Jax

This approach allows for a combination of in-JIT debugging with debug.print() and more detailed debugging outside of JIT using standard Python print statements.

Common Patterns and Idioms in JAX

Finally, we will explore common patterns and idioms in JAX that can help streamline your coding process and improve efficiency. Familiarizing yourself with these practices will aid in developing more robust and performant JAX applications.

Device Memory Management for Processing Large Datasets

 # 1. Device Memory Management
def process_large_data(data):
    # Process in chunks to manage memory
    chunk_size = 100
    results = []

    for i in range(0, len(data), chunk_size):
        chunk = data[i : i chunk_size]
        chunk_result = jit(process_chunk)(chunk)
        results.append(chunk_result)

    return jnp.concatenate(results)


def process_chunk(chunk):
    chunk_temp = jnp.sqrt(chunk)
    return chunk_temp
Nach dem Login kopieren

This function processes large datasets in chunks to avoid overwhelming device memory.

It sets chunk_size to 100 and iterates over the data increments of the chunk size, processing each chunk separately.

For each chunk, the function uses jit(process_chunk) to JIT-compile the processing operation, which improves performance by compiling it ahead of time.

The result of each chunk is concatenated into a single array using jnp.concatenated(result) to form a single list.

Ausgabe:

 print("===========================")
data = jnp.arange(10000)
print(data.shape)

print("===========================")
print(data)

print("===========================")
print(process_large_data(data))
Nach dem Login kopieren

Leitfaden zum blitzschnellen Jax

Handling Random Seed for Reproducibility and Better Data Generation

The function create_traing_state() demonstrates managing random number generators (RNGs) in JAX, which is essential for reproducibility and consistent results.

 # 2. Handling Random Seeds
def create_training_state(rng):
    # Split RNG for different uses
    rng, init_rng = jax.random.split(rng)
    params = init_network(init_rng)

    return params, rng # Return new RNG for next use
    
Nach dem Login kopieren

It starts with an initial RNG (rng) and splits it into two new RNGs using jax.random.split(). Split RNGs perform different tasks: `init_rng` initializes network parameters, and the updated RNG returns for subsequent operations.

The function returns both the initialized network parameters and the new RNG for further use, ensuring proper handling of random states across different steps.

Now test the code using mock data

 def init_network(rng):
    # Initialize network parameters
    zurückkehren {
        "w1": jax.random.normal(rng, (784, 256)),
        "b1": jax.random.normal(rng, (256,)),
        "w2": jax.random.normal(rng, (256, 10)),
        "b2": jax.random.normal(rng, (10,)),
    }


print("===========================")

key = jax.random.PRNGKey(0)
params, rng = create_training_state(key)


print(f"Random number generator: {rng}")

print(params.keys())

print("===========================")


print("===========================")
print(f"Network parameters shape: {params['w1'].shape}")

print("===========================")
print(f"Network parameters shape: {params['b1'].shape}")
print("===========================")
print(f"Network parameters shape: {params['w2'].shape}")

print("===========================")
print(f"Network parameters shape: {params['b2'].shape}")


print("===========================")
print(f"Network parameters: {params}")
Nach dem Login kopieren

Ausgabe:

Leitfaden zum blitzschnellen Jax

Leitfaden zum blitzschnellen Jax

Using Static Arguments in JIT

 def g(x, n):
    i = 0
    while i <p> <strong>Ausgabe:</strong></p><pre class="brush:php;toolbar:false"> 30
Nach dem Login kopieren

You can use a static argument if JIT compiles the function with the same arguments each time. This can be useful for the performance optimization of JAX functions.

 from functools import partial


@partial(jax.jit, static_argnames=["n"])
def g_jit_decorated(x, n):
    i = 0
    while i <p>If You want to use static arguments in JIT as a decorator you can use jit inside of functools. partial() function.</p><p> <strong>Ausgabe:</strong></p><pre class="brush:php;toolbar:false"> 30
Nach dem Login kopieren

Now, we have learned and dived deep into many exciting concepts and tricks in JAX and overall programming style.

Was kommt als nächstes?

  • Experiment with Examples: Try to modify the code examples to learn more about JAX. Build a small project for a better understanding of JAX's transformations and APIs. Implement classical Machine Learning algorithms with JAX such as Logistic Regression, Support Vector Machine, and more.
  • Explore Advanced Topics : Parallel computing with pmap, Custom JAX transformations, Integration with other frameworks

All code used in this article is here

Abschluss

JAX is a powerful tool that provides a wide range of capabilities for machine learning, Deep Learning, and scientific computing. Start with basics, experimenting, and get help from JAX's beautiful documentation and community. There are so many things to learn and it will not be learned by just reading others' code you have to do it on your own. So, start creating a small project today in JAX. The key is to Keep Going, learn on the way.

Key Takeaways

  • Familiar NumPY-like interface and APIs make learning JAX easy for beginners. Most NumPY code works with minimal modifications.
  • JAX encourages clean functional programming patterns that lead to cleaner, more maintainable code and upgradation. But If developers want JAX fully compatible with Object Oriented paradigm.
  • What makes JAX's features so powerful is automatic differentiation and JAX's JIT compilation, which makes it efficient for large-scale data processing.
  • JAX excels in scientific computing, optimization, neural networks, simulation, and machine learning which makes developer easy to use on their respective project.

Häufig gestellte Fragen

Q1. What makes JAX different from NumPY?

A. Although JAX feels like NumPy, it adds automatic differentiation, JIT compilation, and GPU/TPU support.

Q2. Do I need a GPU to use JAX?

A. In a single word big NO, though having a GPU can significantly speed up computation for larger data.

Q3. Is JAX a good alternative to NumPy?

A. Yes, You can use JAX as an alternative to NumPy, though JAX's APIs look familiar to NumPy JAX is more powerful if you use JAX's features well.

Q4. Can I use my existing NumPy code with JAX?

A. Most NumPy code can be adapted to JAX with minimal changes. Usually just changing import numpy as np to import jax.numpy as jnp.

Q5. Is JAX harder to learn than NumPy?

A. The basics are just as easy as NumPy! Tell me one thing, will you find it hard after reading the above article and hands-on? I answered it for you. YES hard. Every framework, language, libraries is hard not because it is hard by design but because we don't give much time to explore it. Give it time to get your hand dirty it will be easier day by day.

Die in diesem Artikel gezeigten Medien sind nicht im Besitz von Analytics Vidhya und werden nach Ermessen des Autors verwendet.

Das obige ist der detaillierte Inhalt vonLeitfaden zum blitzschnellen Jax. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Erklärung dieser Website
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn
Beliebte Tutorials
Mehr>
Neueste Downloads
Mehr>
Web-Effekte
Quellcode der Website
Website-Materialien
Frontend-Vorlage