稲妻のガイド

Jennifer Aniston
リリース: 2025-03-19 11:21:11
オリジナル
626 人が閲覧しました

ねえ、仲間のPython愛好家! Supersonic SpeedでNumpyコードが実行されたいと思ったことはありますか? Jaxに会いましょう!あなたの機械学習、ディープラーニング、数値コンピューティングの旅におけるあなたの新しい親友。それを超大国を持つnumpyと考えてください。グラデーションを自動的に処理し、JITを使用して高速に実行するためにコードをコンパイルし、汗をかくことなくGPUとTPUで実行することもできます。ニューラルネットワークを構築したり、科学データの削り取り、トランスモデルを微調整したり、計算をスピードアップしようとしている場合でも、Jaxは背中を持っています。飛び込んで、Jaxの特別な理由を見てみましょう。

このガイドは、JAXとそのエコシステムの詳細な紹介を提供します。

学習目標

  • Jaxの核となる原則と、それらがNumpyとどのように異なるかを説明します。
  • JAXの3つの重要な変換を適用して、Pythonコードを最適化します。 numpy操作を効率的なJAX実装に変換します。
  • JAXコードで共通のパフォーマンスボトルネックを特定して修正します。典型的な落とし穴を避けながら、JITコンパイルを正しく実装します。
  • Jaxを使用して、ニューラルネットワークをゼロから構築およびトレーニングします。 JAXの機能的アプローチを使用して、一般的な機械学習操作を実装します。
  • Jaxの自動差別化を使用して最適化の問題を解決します。効率的なマトリックス操作と数値計算を実行します。
  • JAX固有の問題に効果的なデバッグ戦略を適用します。大規模な計算にメモリ効率の高いパターンを実装します。

この記事は、データサイエンスブログソンの一部として公開されました

目次

  • Jaxとは何ですか?
  • なぜJaxが際立っているのですか?
  • Jaxを始めましょう
  • なぜJaxを学ぶのですか?
  • 必須のJAX変換
  • JAXを使用したニューラルネットワークの構築
  • ベストプラクティスとヒント
  • パフォーマンスの最適化
  • デバッグ戦略
  • JAXの一般的なパターンとイディオム
  • 次は何ですか?
  • 結論
  • よくある質問

Jaxとは何ですか?

公式文書によると、Jaxは、高性能の数値コンピューティングと大規模な機械学習用に設計された、加速指向の配列計算とプログラム変換のためのPythonライブラリです。したがって、Jaxは本質的にステロイドのNumpyであり、おなじみのNumpyスタイルの操作と自動分化とハードウェアアクセラレーションを組み合わせています。 3つの世界の最大限を得ると考えてください。

  • Numpyのエレガントな構文と配列操作
  • 自動分化機能のようなPytorch
  • ハードウェアアクセラレーションとコンパイルの利点のためのXLA (加速線形代数)。

なぜJaxが際立っているのですか?

Jaxを際立たせるのは、その変換です。これらは、Pythonコードを変更できる強力な機能です。

  • JIT :迅速な実行のためのジャストインタイムコンパイル
  • GRAD :コンピューティンググラデーションの自動差別化
  • VMAP :バッチ処理用の自動ベクトル化

これが簡単な見方です:

 JAX.numpyをJNPとしてインポートします
JAXインポートグレード、JITから
#単純な関数を定義します
@jit#コンピレーションでスピードアップします
def square_sum(x):
jnp.sum(jnp.square(x))を返します
#グラデーション関数を自動的に取得します
gradient_fn = grad(square_sum)
#それを試してみてください
x = jnp.array([1.0、2.0、3.0])
print(f "gradient:{gradient_fn(x)}")
ログイン後にコピー

出力:

勾配:[2。 4。6。]
ログイン後にコピー

Jaxを始めましょう

以下に、Jaxを開始するためのいくつかの手順に従います。

ステップ1:インストール

JAXのセットアップは、CPUのみの使用において簡単です。詳細については、JAXドキュメントを使用できます。

ステップ2:プロジェクトの環境の作成

プロジェクトのコンドマ環境を作成します

#ジャックスのコンドラ環境を作成します
$ conda create  - 名前Jaxdev python = 3.11

#env
$ condaはJaxdevをアクティブにします

#プロジェクトDir Name Jax101を作成します
$ mkdir jax101

#dirに移動します
$ CD JAX101
ログイン後にコピー

ステップ3:Jaxのインストール

新しく作成された環境にJAXをインストールします

#CPUのみ
ピップインストール - アップグレードPIP
ピップインストール - アップグレード「Jax」

#GPUの場合
ピップインストール - アップグレードPIP
PIPインストール - アップグレード "Jax [cuda12]"
ログイン後にコピー

今、あなたは本物に飛び込む準備ができています。実用的なコーディングで手を汚す前に、いくつかの新しい概念を学びましょう。私は最初に概念を説明し、それから一緒にコーディングして実際の視点を理解します。

まず、動機を得て、ちなみに、なぜ私たちは再び新しいライブラリを学ぶのですか?このガイド全体で、できるだけ簡単に段階的にその質問に答えます。

なぜJaxを学ぶのですか?

Jaxを電動工具と考えてください。 Numpyは信頼できる手鋸のようなものですが、Jaxはモダンな電動鋸のようなものです。もう少しステップと知識が必要ですが、集中的な計算タスクにはパフォーマンスの利点が価値があります。

  • パフォーマンス:JAXコードは、特にGPUとTPUで、純粋なPythonまたはNumpyコードよりも大幅に速く実行できます
  • 柔軟性:機械学習だけではありません。JAXは科学的なコンピューティング、最適化、シミュレーションに優れています。
  • モダンなアプローチ: Jaxは、よりクリーンで保守可能なコードにつながる機能的なプログラミングパターンを奨励しています。

次のセクションでは、JITコンピレーションから始めて、Jaxの変換に深く掘り下げます。これらの変換は、Jaxに超大国に与えるものであり、それらを理解することは、JAXを効果的に活用するための鍵です。

必須のJAX変換

Jaxの変換は、NumpyやScipyなどの数値計算ライブラリとは際立っているものです。それぞれを探索し、コードをどのように充電できるかを見てみましょう。

JITまたはジャストインタイムコンピレーション

Just-in-Timeコンピレーションは、前もってではなく実行時にプログラムの一部をコンパイルすることにより、コードの実行を最適化します。

JITはJaxでどのように機能しますか?

Jaxでは、jax.jitはPython関数をJITコンパイルバージョンに変換します。 @jax.jitで関数を飾ることは、実行グラフをキャプチャし、最適化し、XLAを使用してコンパイルします。コンパイルされたバージョンは、特に繰り返される関数呼び出しのために、重要なスピードアップを実行し、実行します。

これがあなたがそれを試す方法です。

 JAX.numpyをJNPとしてインポートします
JAXインポートJITから
インポート時間


#計算集中関数
def slow_function(x):
    _ in range(1000):
        x = jnp.sin(x)jnp.cos(x)
    xを返します


#JITを使用した同じ関数
@jit
def fast_function(x):
    _ in range(1000):
        x = jnp.sin(x)jnp.cos(x)
    xを返します
ログイン後にコピー

同じ関数があります。1つは単なるPythonコンパイルプロセスであり、もう1つはJaxのJITコンピレーションプロセスとして使用されます。 1000のデータポイントのサイン関数とコサイン関数の合計を計算します。時間を使用してパフォーマンスを比較します。

 #パフォーマンスを比較します
x = jnp.arange(1000)

#ウォームアップJIT
fast_function(x)#最初の呼び出しは関数をコンパイルします

#時間比較
start = time.time()
slow_result = slow_function(x)
print(f "without jit:{time.time() -  start:.4f}秒")

start = time.time()
fast_result = fast_function(x)
print(f "with jit:{time.time() -  start:.4f}秒")
ログイン後にコピー

結果はあなたを驚かせるでしょう。 JITコンピレーションは、通常のコンピレーションよりも333倍高速です。自転車とブガティのシロンを比較するようなものです。

出力:

 JITなし:0.0330秒
JITで:0.0010秒
ログイン後にコピー

JITは超高速実行ブーストを提供できますが、適切に使用する必要があります。

一般的なJIT落とし穴

JITは、静的な形状とタイプで最適に機能します。配列値に依存するPythonループと条件を使用しないでください。 JITは動的配列では動作しません。

 #bad -pythonコントロールフローを使用します
@jit
def bad_function(x):
    x [0]> 0の場合:#これはJITでうまく機能しません
        xを返します
    return -x


#print(bad_function(jnp.array([1、2、3])))


#グッド -  Jax Control Flowを使用します
@jit
def good_function(x):
    jnp.where(x [0]> 0、x、-x)#jax -native条件を返します


print(good_function(jnp.array([1、2、3])))
ログイン後にコピー

出力:

稲妻のガイド

つまり、jitは計算中にxの値に位置していなかったため、bad_functionが悪いことを意味します。

出力:

 [1 2 3]
ログイン後にコピー

制限と考慮事項

  • コンパイルオーバーヘッド: JITコンパイルされた関数が初めて実行されたとき、コンパイルのためにオーバーヘッドがあります。コンピレーションコストは、小さな機能または1回だけ呼ばれる機能のパフォーマンスの利点を上回る可能性があります。
  • 動的なPython機能: JaxのJITでは、関数が「静的」になる必要があります。 Pythonループに基づいた形状や値の変更など、動的制御フローは、コンパイルされたコードではサポートされていません。 Jaxは、動的制御フローを処理するために、「Jax.lax.cond」や `jax.lax.scan`などの代替案を提供しました。

自動分化

自動分化、または自動装置は、機能の導関数を正確かつ効果的に計算するための計算手法です。これは、特にモデルパラメーターを更新するために勾配を使用するニューラルネットワークのトレーニングにおいて、機械学習モデルを最適化する上で重要な役割を果たします。

稲妻のガイド

JAXで自動分化はどのように機能しますか?

AutoDiffは、複雑な関数を単純な機能に分解し、これらのサブファンクションの導関数を計算し、結果を組み合わせるように複雑な関数を分解して、計算の鎖規則を適用することにより機能します。関数実行中に各操作を記録して計算グラフを構築し、デリバティブを自動的に計算するために使用されます。

オートディフには2つの主要なモードがあります。

  • フォワードモード:少数のパラメーターを使用して関数に効率的に、計算グラフを通過する単一のフォワードパスでデリバティブを計算します。
  • リバースモード:計算グラフを通る単一の後方パスで導関数を計算します。これは、多数のパラメーターを使用して関数に効率的です。

稲妻のガイド

JAX自動分化の重要な機能

  • 勾配計算(jax.grad): `jax.grad`は、その入力のスカラー出力関数の導関数を計算します。複数の入力を持つ関数の場合、部分微分を取得できます。
  • 高次派生(jax.jacobian、jax.hessian): Jaxは、JacobiansやHessainsなどの高次誘導体の計算をサポートしているため、高度な最適化と物理シミュレーションに適しています。
  • 他のJAX変換との複合性: JaxのAutodiffは、「jax.jit」や「jax.vmap」などの他の変換とシームレスに統合し、効率的でスケーラブルな計算を可能にします。
  • リバースモード差別化(BackPropagation): JaxのAuto-Diffは、スカラー出力関数にリバースモード差別化を使用します。これは、深い学習タスクに非常に効果的です。
 JAX.numpyをJNPとしてインポートします
Jax Import Grad、value_and_gradから


#単純なニューラルネットワークレイヤーを定義します
defレイヤー(params、x):
    重量、バイアス=パラメーション
    jnp.dot(x、weight)バイアスを返します


#スカラー値の損失関数を定義します
def loss_fn(params、x):
    output = layer(params、x)
    jnp.sum(output)を返します#スカラーへの削減


#出力と勾配の両方を取得します
layer_grad = grad(loss_fn、argnums = 0)#パラメーションに関するグラデーション
layer_value_and_grad = value_and_grad(loss_fn、argnums = 0)#値とグラデーションの両方

#例の使用
key = jax.random.prngkey(0)
x = jax.random.normal(key、(3、4))
weight = jax.random.normal(key、(4、2))
バイアス= jax.random.normal(key、(2、))

#勾配を計算します
grads = layer_grad((weight、bias)、x)
出力、grads = layer_value_and_grad((weight、bias)、x)

#複数のデリバティブは簡単です
Twile_grad = grad(grad(jnp.sin))
x = jnp.array(2.0)
print(f "x = 2:{twol_grad(x)}"での罪の2番目の派生物)
ログイン後にコピー

出力:

 x = 2:-0.909297406734314でのsinの2番目の誘導体
ログイン後にコピー

Jaxの有効性

  • 効率: JAXの自動分化は、XLAとの統合により非常に効率的であり、マシンコードレベルで最適化できます。
  • 複合性:異なる変換を組み合わせる機能により、JAXは、CNN、RNN、トランスなどの複雑な機械学習パイプラインとニューラルネットワークアーキテクチャを構築するための強力なツールになります。
  • 使いやすさ: AutoDiffのJaxの構文はシンプルで直感的であり、XLAおよび複雑なライブラリAPIの詳細を掘り下げることなく勾配を計算できるようになります。

Jax Vectorizeマッピング

JAXでは、「VMAP」は計算を自動的に自動的に自動的に補う強力な関数であり、ループを手動で書き込むことなくデータのバッチに関数を適用できます。アレイ軸(または複数の軸)を介して関数をマッピングし、並行して効率的に評価し、パフォーマンスの大幅な改善につながる可能性があります。

VMAPはJaxでどのように機能しますか?

VMAP関数は、計算の効率を維持しながら、入力配列の指定された軸に沿って各要素に関数を適用するプロセスを自動化します。与えられた関数を変換して、バッチ入力を受け入れ、ベクトル化された方法で計算を実行します。

明示的なループを使用する代わりに、VMAPを使用すると、入力軸上でベクトル化することにより、操作を並行して実行できます。これにより、SIMD(単一命令、複数のデータ)操作を実行するハードウェアの機能が活用され、実質的なスピードアップになる可能性があります。

VMAPの重要な機能

  • 自動ベクトル化: VAMPは計算のバッチを自動化し、元の関数ロジックを変更せずにバッチ寸法上のコードを簡単に並列にするようにします。
  • 他の変換との複合性: jax.grad for DifteriationやJax.jitなどの他のJax変換やJust-in-timeコンピレーションなど、非常に最適化された柔軟なコードを可能にします。
  • 複数のバッチ寸法の処理: VMAPは、複数の入力配列または軸上のマッピングをサポートしているため、多次元データまたは複数の変数を同時に処理するなど、さまざまなユースケースに汎用性があります。
 JAX.numpyをJNPとしてインポートします
JAXインポートVMAPから


#単一の入力で動作する関数
def single_input_fn(x):
    jnp.sin(x)jnp.cos(x)を返します


#バッチで動作するようにvectorizeします
batch_fn = vmap(single_input_fn)

#パフォーマンスを比較します
x = jnp.arange(1000)

#VMAPなし(リスト理解を使用)
result1 = jnp.array([single_input_fn(xi)for xi in x])

#vmapで
result2 = batch_fn(x)#はるかに高速!


#複数の引数を獲得します
def two_input_fn(x、y):
    x * jnp.sin(y)を返す


#両方の入力を介してvectorizeします
vectorized_fn = vmap(two_input_fn、in_axes =(0、0)))

#または最初の入力のみでベクトル化します
partivally_vectorized_fn = vmap(two_input_fn、in_axes =(0、none))


#印刷
print(result1.shape)
print(result2.shape)
print(partivally_vectorized_fn(x、y).shape)
ログイン後にコピー

出力:

 (1000、)
(1000、)
(1000,3)
ログイン後にコピー

JAXのVMAPの有効性

  • パフォーマンスの改善:計算をベクトル化することにより、VMAPは、GPUやTPU(テンソル処理ユニット)などの最新のハードウェアの並列処理機能を活用することにより、実行を大幅に高速化できます。
  • クリーナーコード:手動ループの必要性を排除することにより、より簡潔で読みやすいコードが可能になります。
  • JAXおよびAutoDiffとの互換性: VMAPは、自動分化(jax.grad)と組み合わせることができ、データのバッチ上の導関数の効率的な計算を可能にします。

各変換をいつ使用するか

@jitを使用するとき:

  • あなたの関数は、同様の入力形状で複数回と呼ばれます。
  • この関数には、重い数値計算が含まれています。

段階を使用する場合:

  • 最適化にはデリバティブが必要です。
  • 機械学習アルゴリズムの実装
  • シミュレーションの微分方程式の解決

vmapを使用して:

  • データのバッチを処理します。
  • 並列化計算
  • 明示的なループを避けます

JAXを使用したマトリックス操作と線形代数

JAXは、マトリックス操作と線形代数を包括的にサポートし、科学的コンピューティング、機械学習、数値最適化タスクに適しています。 Jaxの線形代数機能は、Numpyなどのライブラリに見られるものと似ていますが、自動分化や最適化されたパフォーマンスのためのジャストインタイムコンパイルなどの追加機能があります。

マトリックスの添加と減算

これらの操作は、同じ形状の要素ごとのマトリックスを実行します。

 #1マトリックスの追加と減算:

JAX.numpyをJNPとしてインポートします

a = jnp.array([[1、2]、[3、4]])
b = jnp.array([[5、6]、[7、8]])

#マトリックスの追加
c = ab
#マトリックスの減算
d = a -b

印刷(f "マトリックスA:\ n {a}")
print( "========================")
印刷(f "マトリックスB:\ n {b}")
print( "========================")
print(f "abのマトリックスアディション:\ n {c}")
print( "========================")
print(f "abのマトリックスサブラクション:\ n {d}")
ログイン後にコピー

出力:

稲妻のガイド

マトリックス乗算

JAXは、要素ごとの乗算とDOR製品ベースのマトリックス増殖の両方をサポートしています。

 #要素ごとの乗算
E = a * b

#マトリックス乗算(DOT製品)
f = jnp.dot(a、b)

印刷(f "マトリックスA:\ n {a}")
print( "========================")
印刷(f "マトリックスB:\ n {b}")
print( "========================")
印刷(f "a*b:\ n {e}"の要素ごとの乗算)
print( "========================")
印刷(f "a*b:\ n {f}"のマトリックス乗算)
ログイン後にコピー

出力:

稲妻のガイド

マトリックスは転置します

マトリックスの転置は、 `jnp.transpose()`を使用して取得できます。

 #MATRIC Transpose
g = jnp.Transpose(a)

印刷(f "マトリックスA:\ n {a}")
print( "========================")
印刷(f "マトリックスa:\ n {g}"の転置 ")
ログイン後にコピー

出力:

稲妻のガイド

マトリックスの逆

jaxは、 `jnp.linalg.inv()`を使用してマトリックス反転の関数を提供します

#マトリック反転
h = jnp.linalg.inv(a)

印刷(f "マトリックスA:\ n {a}")
print( "========================")
印刷(f "a:\ n {h}のマトリックス反転")
ログイン後にコピー

出力:

稲妻のガイド

マトリックス決定因子

マトリックスの決定要因は、 `jnp.linalg.det()`を使用して計算できます。

 #マトリックス決定要因
det_a = jnp.linalg.det(a)

印刷(f "マトリックスA:\ n {a}")
print( "========================")
印刷(f "a:\ n {det_a}のマトリックス決定要因")
ログイン後にコピー

出力:

稲妻のガイド

マトリックス固有値と固有ベクトル

`jnp.linalg.eigh()`を使用して、マトリックスの固有値と固有ベクトルを計算できます。

 #固有値と固有ベクトル
JAX.numpyをJNPとしてインポートします

a = jnp.array([[1、2]、[3、4]])
固有値、固有ベクトル= jnp.linalg.eigh(a)

印刷(f "マトリックスA:\ n {a}")
print( "========================")
印刷(f "a:\ n {eigenvalues}の固有値")
print( "========================")
印刷(f "a:\ n {eigenvectors}のeigenvectors")
ログイン後にコピー

出力:

稲妻のガイド

マトリックス特異値分解

SVDは、次元削減とマトリックス因数分解に役立つ「JNP.LinalG.SVD」を介してサポートされています。

 #単一値分解(SVD)

JAX.numpyをJNPとしてインポートします

a = jnp.array([[1、2]、[3、4]])
u、s、v = jnp.linalg.svd(a)

印刷(f "マトリックスA:\ n {a}")
print( "========================")
印刷(f "マトリックスu:\ n {u}")
print( "========================")
印刷(f "マトリックスs:\ n {s}")
print( "========================")
印刷(f "マトリックスV:\ n {v}")
ログイン後にコピー

出力:

稲妻のガイド

線形方程式の解決システム

線形方程式ax = bのシステムを解くには、 `jnp.linalg.solve()`を使用します。ここで、aは正方行列であり、bは同じ数の行のベクトルまたはマトリックスです。

 #線形方程式の解決システム
JAX.numpyをJNPとしてインポートします

a = jnp.array([[2.0、1.0]、[1.0、3.0]])
b = jnp.array([5.0、6.0])
x = jnp.linalg.solve(a、b)

印刷(f "xの値:{x}")
ログイン後にコピー

出力:

 Xの値:[1.8 1.4]
ログイン後にコピー

マトリックス関数の勾配を計算します

Jaxの自動分化を使用して、マトリックスに関してスカラー関数の勾配を計算できます。
以下の関数とxの値の勾配を計算します

関数

稲妻のガイド

 #マトリックス関数の勾配を計算します
Jaxをインポートします
JAX.numpyをJNPとしてインポートします


def matrix_function(x):
    jnp.sum(jnp.sin(x)x ** 2)を返す


#関数の卒業を計算します
grad_f = jax.grad(matrix_function)

x = jnp.array([[1.0、2.0]、[3.0、4.0]])
gradient = grad_f(x)

印刷(f "マトリックスx:\ n {x}")
print( "========================")
印刷(f "gradient of matrix_function:\ n {gradient}")
ログイン後にコピー

出力:

稲妻のガイド

数値コンピューティング、機械学習、物理学の計算で使用されるJAXのこれらの最も有用な機能。あなたが探検するためにもっとたくさんの残りがあります。

JAXを使用した科学的コンピューティング

Jaxの科学コンピューティングのための強力なライブラリであるJaxは、JITコンパイル、自動分化、ベクトル化、並列化、GPU-TPU加速などの事前機能に最適です。 JAXの高性能コンピューティングをサポートする能力により、物理シミュレーション、機械学習、最適化、数値分析など、幅広い科学的アプリケーションに適しています。

このセクションでは、最適化の問題を調べます。

最適化の問題

以下の最適化の問題を確認しましょう。

ステップ1:最小化する(または問題)関数を定義する

#最小化する関数を定義します(例えば、Rosenbrock関数)

@jit

DEF ROSENBROCK(X):

return sum(100.0 *(x [1:] -x [:-1] ** 2.0)** 2.0(1 -x [: -  1])** 2.0)
ログイン後にコピー

ここでは、ローゼンブロック関数が定義されています。これは、最適化の一般的なテストの問題です。この関数は、アレイxを入力として取得し、関数のグローバル最小値からxがどれだけ遠いかを表すバリエを計算します。 @jitデコレーターは、Jut-in-timeコンピレーションを有効にするために使用されます。これにより、CPUとGPUで効率的に実行するように関数をコンパイルすることで計算が高速化されます。

ステップ2:勾配降下ステップの実装

#勾配降下最適化

@jit

def gradient_descent_step(x、Learning_rate):

RETURN X -Learning_rate * grad(rosenbrock)(x)
ログイン後にコピー

この関数は、勾配降下最適化の単一のステップを実行します。 Rosenbrock関数の勾配は、Grad(Rosenbrock)(x)を使用して計算されます。これは、xを点在させる微分を提供します。 Xの新しい値は、earning_rateによってスケーリングされた勾配を減算によって更新されます。@jitは以前と同じことをしています。

ステップ3:最適化ループの実行

# 最適化する
x = jnp.array([0.0、0.0])#開始点

Learning_rate = 0.001

範囲のIの場合(2000):

x = gradient_descent_step(x、Learning_rate)

Iの場合、100 == 0の場合:

print(f "step {i}、value:{rosenbrock(x):。4f}")
ログイン後にコピー

最適化ループは、開始点Xを初期化し、勾配降下の1000回の反復を実行します。各反復では、現在の勾配に基づいてGradient_Descent_Step関数が更新されます。 100ステップごとに、Xでのローゼンブロック関数の現在のステップ数と値が印刷されているため、最適化の進行が得られます。

出力:

稲妻のガイド

JAXで実際の物理学の問題を解決します

物理システムに、摩擦摩擦、車両の衝撃吸収体、または電気回路での振動を伴う質量噴射システムのようなものをモデル化する減衰高調波発振器の動きをシミュレートします。いいじゃないですか?やりましょう。

ステップ1:パラメーター定義

Jaxをインポートします
JAX.numpyをJNPとしてインポートします


#パラメーターを定義します
質量= 1.0#オブジェクトの質量(kg)
減衰= 0.1#減衰係数(kg/s)
spring_constant = 1.0#spring constant(n/m)

#時間ステップと合計時間を定義します
DT = 0.01#時間ステップ
num_steps = 3000#ステップ数
ログイン後にコピー

質量、減衰係数、およびスプリング定数が定義されています。これらは、減衰高調波発振器の物理的特性を決定します。

ステップ2:ODE定義

#ODESシステムを定義します
def damped_harmonic_oscillator(状態、t):
    "" "減衰した高調波発振器の誘導体を計算します。

    状態:位置と速度を含む配列[x、v]
    T:時間(この自律システムでは使用されていません)
    "" "
    x、v =状態
    dxdt = v
    dvdt = -damping / mass * v -spring_constant / mass * x
    jnp.arrayを返す([dxdt、dvdt])
ログイン後にコピー

減衰した高調波発振器関数は、動的システムを表す発振器の位置と速度の導関数を定義します。

ステップ3:オイラーの方法

#オイラーの方法を使用してオードを解決します
def euler_step(状態、t、dt):
    "" "オイラーの方法の1つのステップを実行します。" ""
    derivatives = damped_harmonic_oscillator(state、t)
    状態派生物 * dtを返します
ログイン後にコピー

ODEを解くために単純な数値方法が使用されます。現在の状態と微分に基づいて、次回のステップで状態に近似します。

ステップ4:時間進化ループ

#初期状態:[位置、速度]
initial_state = jnp.array([1.0、0.0])#x = 1、v = 0の質量で開始

#時間の進化
states = [initial_state]
時間= 0.0
ステップインレンジ(num_steps):
    next_state = euler_step(stathes [-1]、time、dt)
    states.append(next_state)
    時間= dt

#分析のために状態のリストをJAXアレイに変換します
States = JNP.Stack(STATES)
ログイン後にコピー

ループは、指定された時間ステップ数を繰り返し、Eulerの方法を使用して各ステップで状態を更新します。

出力:

稲妻のガイド

ステップ5:結果をプロットします

最後に、結果をプロットして、減衰高調波発振器の挙動を視覚化できます。

 #結果のプロット
pltとしてmatplotlib.pyplotをインポートします

plt.style.use( "ggplot")

位置= states [:, 0]
velocities = states [:, 1]
time_points = jnp.arange(0、(num_steps 1) * dt、dt)

plt.figure(figsize =(12、6))
plt.subplot(2、1、1)
plt.plot(time_points、positions、label = "position")
plt.xlabel( "time(s)")
plt.ylabel( "position(m)")
plt.legend()

plt.subplot(2、1、2)
plt.plot(time_points、velocities、label = "velocity"、color = "orange")
plt.xlabel( "time(s)")
plt.ylabel( "velocity(m/s)")
plt.legend()

plt.tight_layout()
plt.show()
ログイン後にコピー

出力:

稲妻のガイド

私はあなたがNeural NetworkをJAXでどのように構築できるかを見たいと思っていることを知っています。それで、それに深く飛び込みましょう。

ここでは、値が徐々に最小化されていることがわかります。

JAXを使用したニューラルネットワークの構築

Jaxは、高性能の数値コンピューティングとNumpyのような構文の使用を容易にする強力なライブラリです。このセクションでは、JAXを使用してニューラルネットワークを構築するプロセスをガイドし、自動分化とジャストインタイムコンパイルの高度な機能を活用してパフォーマンスを最適化します。

ステップ1:ライブラリのインポート

ニューラルネットワークの構築に飛び込む前に、必要なライブラリをインポートする必要があります。 Jaxは、効率的な数値計算を作成するための一連のツールを提供しますが、追加のライブラリは結果の最適化と視覚化を支援します。

 Jaxをインポートします
JAX.numpyをJNPとしてインポートします
JAXインポートグレード、JITから
Jax.random Import prngkeyから、通常
Optax#Jaxの最適化ライブラリをインポートします
pltとしてmatplotlib.pyplotをインポートします
ログイン後にコピー

ステップ2:モデルレイヤーの作成

効果的なモデルレイヤーの作成は、ニューラルネットワークのアーキテクチャを定義する上で重要です。このステップでは、高密度の層のパラメーターを初期化し、モデルが明確に定義された重みとバイアスで効果的な学習のために始まるようにします。

 def init_layer_params(key、n_in、n_out):
    "" "単一の密度層のパラメーターを初期化" ""
    key_w、key_b = jax.random.split(key)
    #彼は初期化
    w = normal(key_w、(n_in、n_out)) * jnp.sqrt(2.0 / n_in)  
    b =通常(key_b、(n_out、)) * 0.1
    return(w、b)
    
def relu(x):
    "" "Relu Activation Function" ""
    jnp.maximum(0、x)を返す
    
ログイン後にコピー
  • 初期化関数:init_layer_paramsは、重量の初期化とバイアスの小さな値を使用して、密な層の重み(w)とバイアス(b)を初期化します。彼または彼または彼の初期化は、Relu Activation Functionsのレイヤーに対してより良い動作をしています。シグモイドの活性化を伴う層に対してより良い動作をするXavier初期化など、他の一般的な初期化方法があります。
  • アクティベーション関数: relu関数は、ネガティブ値をゼロに設定する入力にreluアクティベーション関数を適用します。

ステップ3:フォワードパスの定義

フォワードパスは、ネットワークを介して入力データがどのように流れて出力を生成するかを決定するため、ニューラルネットワークの基礎です。ここでは、初期化されたレイヤーを介して入力データに変換を適用することにより、モデルの出力を計算する方法を定義します。

 DEFフォワード(PARAMS、X):
    "" "2層ニューラルネットワークのフォワードパス" "
    (W1、b1)、(w2、b2)= params
    #最初のレイヤー
    h1 = relu(jnp.dot(x、w1)b1)
    #出力層
    logits = jnp.dot(h1、w2)b2
    ロジットを返します
    
ログイン後にコピー
  • フォワードパス: 2層ニューラルネットワークを通過するフォワードパスを実行し、線形変換に続いてReluおよびその他の線形変換を適用することにより、出力(ロジット)を計算します。

S TEP4:損失関数の定義

明確に定義された損失関数は、モデルのトレーニングをガイドするために不可欠です。このステップでは、平均二乗誤差(MSE)損失関数を実装します。これは、予測される出力がターゲット値とどれだけうまく一致するかを測定し、モデルが効果的に学習できるようにします。

 def loss_fn(params、x、y):
    "" "平面誤差喪失" ""
    pred = forward(params、x)
    jnp.meanを返す((pred -y)** 2)
ログイン後にコピー
  • 損失関数: loss_fnは、予測されたロジットとターゲットラベル(y)の間の平均二乗誤差(MSE)損失を計算します。

ステップ5:モデルの初期化

モデルアーキテクチャと損失関数が定義されているため、モデルの初期化に頼ります。このステップでは、ニューラルネットワークのパラメーターを設定し、各レイヤーがランダムだが適切にスケーリングされた重みとバイアスでトレーニングプロセスを開始する準備ができていることを保証します。

 def init_model(rng_key、input_dim、hidden_​​dim、output_dim):
    key1、key2 = jax.random.split(rng_key)
    params = [
        init_layer_params(key1、input_dim、hidden_​​dim)、
        init_layer_params(key2、hidden_​​dim、output_dim)、
    ]
    パラメージを返します
    
ログイン後にコピー
  • モデルの初期化: init_modelニューラルネットワークの両方の層の重みとバイアスを初期化します。各レイヤーの2つの個別のランダムキーを使用します;のパラメーター初期化。

ステップ6:トレーニングステップ

ニューラルネットワークのトレーニングには、損失関数の計算された勾配に基づいて、そのパラメーターの反復更新が含まれます。このステップでは、これらの更新を効率的に適用するトレーニング関数を実装し、モデルが複数のエポック上のデータから学習できるようにします。

 @jit
def train_step(params、opt_state、x_batch、y_batch):
    loss、grads = jax.value_and_grad(loss_fn)(params、x_batch、y_batch)
    更新、opt_state = optimizer.update(grads、opt_state)
    params = optax.apply_updates(params、updates)
    PARAMS、OPT_STATE、損失を返します
ログイン後にコピー
  • トレーニングステップ: Train_Step関数は、単一の勾配降下更新を実行します。
  • value_and_gradを使用して損失と勾配を計算します。これは、関数値と他の勾配の両方を計算します。
  • オプティマイザーの更新が計算され、モデルパラメーターがそれに応じて更新されます。
  • パフォーマンスのためにJITコンパイルされています。

ステップ7:データとトレーニングループ

モデルを効果的にトレーニングするには、適切なデータを生成し、トレーニングループを実装する必要があります。このセクションでは、この例の合成データを作成する方法と、複数のバッチとエポックでトレーニングプロセスを管理する方法について説明します。

 #いくつかの例データを生成します
key = prngkey(0)
X_DATA = Normal(key、(1000、10))#1000サンプル、10の機能
y_data = jnp.sum(x_data ** 2、axis = 1、keepdims = true)#単純な非線形関数

#モデルとオプティマイザーを初期化します
params = init_model(key、input_dim = 10、hidden_​​dim = 32、output_dim = 1)
optimizer = optax.adam(Learning_rate = 0.001)
opt_state = optimizer.init(params)

#トレーニングループ
batch_size = 32
num_epochs = 100
num_batches = x_data.shape [0] // batch_size

#エポックと損失値を保存する配列
epoch_array = []
loss_array = []

範囲のエポックの場合(num_epochs):
    epoch_loss = 0.0
    範囲のバッチ(num_batches)の場合:
        idx = jax.random.permutation(key、batch_size)
        x_batch = x_data [idx]
        y_batch = y_data [idx]
        params、opt_state、loss = train_step(params、opt_state、x_batch、y_batch)
        epoch_loss = loss

    #エポックの平均損失を保存します
    avg_loss = epoch_loss / num_batches
    epoch_array.append(エポック)
    loss_array.append(avg_loss)

    エポック%10 == 0の場合:
        print(f "epoch {epoch}、loss:{avg_loss:.4f}")
ログイン後にコピー
  • データ生成:ランダムトレーニングデータ(X_DATA)および対応するターゲット(Y_DATA)値が作成されます。モデルとオプティマイザーの初期化:モデルパラメーターとオプティマイザー状態が初期化されます。
  • トレーニングループ:ネットワークは、ミニバッチ勾配降下を使用して、指定された数のエポックでトレーニングされます。
  • トレーニングループはバッチを繰り返し、TRAIN_STEP機能を使用してグラデーションアップデートを実行します。エポックあたりの平均損失は計算および保存されます。エポック数と平均損失を印刷します。

ステップ8:結果をプロットします

トレーニング結果を視覚化することは、ニューラルネットワークのパフォーマンスを理解するための鍵です。このステップでは、モデルがどれだけ学習しているかを観察し、トレーニングプロセスの潜在的な問題を特定するために、エポック上のトレーニング損失をプロットします。

 #結果をプロットします
plt.plot(epoch_array、loss_array、label = "トレーニング損失")
plt.xlabel( "epoch")
plt.ylabel( "loss")
plt.title(「エポック上のトレーニング損失」)
plt.legend()
plt.show()
ログイン後にコピー

これらの例は、JAXが高性能とクリーンで読みやすいコードをどのように組み合わせるかを示しています。 JAXによって奨励されている機能的なプログラミングスタイルにより、操作を簡単に作成し、変換を適用できます。

出力:

稲妻のガイド

プロット:

稲妻のガイド

これらの例は、JAXが高性能とクリーンで読みやすいコードをどのように組み合わせるかを示しています。 JAXによって奨励されている機能的なプログラミングスタイルにより、操作を簡単に作成し、変換を適用できます。

ベストプラクティスとヒント

ニューラルネットワークの構築において、ベストプラクティスを順守することで、パフォーマンスと保守性が大幅に向上する可能性があります。 This section will discuss various strategies and tips for optimizing your code and improving the overall efficiency of your JAX-based models.

パフォーマンスの最適化

Optimizing performance is essential when working with JAX, as it enables us to fully leverage its capabilities. Here, we will explore different techniques for improving the efficiency of our JAX functions, ensuring that our models run as quickly as possible without sacrificing readability.

JIT Compilation Best Practices

Just-In-Time (JIT) compilation is one of the standout features of JAX, enabling faster execution by compiling functions at runtime. This section will outline best practices for effectively using JIT compilation, helping you avoid common pitfalls and maximize the performance of your code.

Bad Function

 import jax
import jax.numpy as jnp
from jax import jit
from jax import lax


# BAD: Dynamic Python control flow inside JIT
@jit
def bad_function(x, n):
    for i in range(n): # Python loop - will be unrolled
        x = x 1
    return x
    
    
print("===========================")
# print(bad_function(1, 1000)) # does not work
    
ログイン後にコピー

This function uses a standard Python loop to iterate n times, incrementing the of x by 1 on each iteration. When compiled with jit, JAX unrolls the loop, which can be inefficient, especially for large n. This approach does not fully leverage JAX's capabilities for performance.

Good Function

 # GOOD: Use JAX-native operations
@jit
def good_function(x, n):
    return xn # Vectorized operation


print("===========================")
print(good_function(1, 1000))
ログイン後にコピー

This function does the same operation, but it uses a vectorized operation (xn) instead of a loop. This approach is much more efficient because JAX can better optimize the computation when expressed as a single vectorized operation.

Best Function

 # BETTER: Use scan for loops


@jit
def best_function(x, n):
    def body_fun(i, val):
        return val 1

    return lax.fori_loop(0, n, body_fun, x)


print("===========================")
print(best_function(1, 1000))
ログイン後にコピー

This approach uses `jax.lax.fori_loop`, which is a JAX-native way to implement loops efficiently. The `lax.fori_loop` performs the same increment operation as the previous function, but it does so using a compiled loop structure. The body_fn function defines the operation for each iteration, and `lax.fori_loop` executes it from o to n. This method is more efficient than unrolling loops and is especially suitable for cases where the number of iterations isn't known ahead of time.

Output :

 ===========================
===========================
1001
===========================
1001
ログイン後にコピー

The code demonstrates different approaches to handling loops and control flow within JAX's jit-complied functions.

Memory Management

Efficient memory management is crucial in any computational framework, especially when dealing with large datasets or complex models. This section will discuss common pitfalls in memory allocation and provide strategies for optimizing memory usage in JAX.

Inefficient Memory Management

 # BAD: Creating large temporary arrays
@jit
def inefficient_function(x):
    temp1 = jnp.power(x, 2) # Temporary array
    temp2 = jnp.sin(temp1) # Another temporary
    return jnp.sum(temp2)
ログイン後にコピー

inefficient_function(x): This function creates multiple intermediate arrays, temp1, temp1 and finally the sum of the elements in temp2. Creating these temporary arrays can be inefficient because each step allocates memory and incurs computational overhead, leading to slower execution and higher memory usage.

Efficient Memory Management

 # GOOD: Combining operations
@jit
def efficient_function(x):
    return jnp.sum(jnp.sin(jnp.power(x, 2))) # Single operation
ログイン後にコピー

This version combines all operations into a single line of code. It computes the sine of squared elements of x directly and sums the results. By combining the operation, it avoids creating intermediate arrays, reducing memory footprints and improving performance.

Test Code

 x = jnp.array([1, 2, 3])
print(x)
print(inefficient_function(x))
print(efficient_function(x))
ログイン後にコピー

出力:

 [1 2 3]
0.49678695
0.49678695
ログイン後にコピー

The efficient version leverages JAX's ability to optimize the computation graph, making the code faster and more memory-efficient by minimizing temporary array creation.

Debugging Strategies

Debugging is an essential part of the development process, especially in complex numerical computations. In this section, we will discuss effective debugging strategies specific to JAX, enabling you to identify and resolve issues quickly.

Using print inside JIT for Debugging

The code shows techniques for debugging within JAX, particularly when using JIT-compiled functions.

 import jax.numpy as jnp
from jax import debug


@jit
def debug_function(x):
    # Use debug.print instead of print inside JIT
    debug.print("Shape of x: {}", x.shape)
    y = jnp.sum(x)
    debug.print("Sum: {}", y)
    yを返します
ログイン後にコピー
# For more complex debugging, break out of JIT
def debug_values(x):
    print("Input:", x)
    result = debug_function(x)
    print("Output:", result)
    return result
    
ログイン後にコピー
  • debug_function(x): This function shows how to use debug.print() for debugging inside a jit compiled function. In JAX, regular Python print statements are not allowed inside JIT due to compilation restrictions, so debug.print() is used instead.
  • It prints the shape of the input array x using debug.print()
  • After computing the sum of the elements of x, it prints the resulting sum using debug.print()
  • Finally, the function returns the computed sum y.
  • debug_values(x) function serves as a higher-level debugging approach, breaking out of the JIT context for more complex debugging. It first prints the inputs x using regular print statement. Then calls debug_function(x) to compute the result and finally prints the output before returning the results.

出力:

 print("===========================")
print(debug_function(jnp.array([1, 2, 3])))
print("===========================")
print(debug_values(jnp.array([1, 2, 3])))
ログイン後にコピー

稲妻のガイド

This approach allows for a combination of in-JIT debugging with debug.print() and more detailed debugging outside of JIT using standard Python print statements.

Common Patterns and Idioms in JAX

Finally, we will explore common patterns and idioms in JAX that can help streamline your coding process and improve efficiency. Familiarizing yourself with these practices will aid in developing more robust and performant JAX applications.

Device Memory Management for Processing Large Datasets

 # 1. Device Memory Management
def process_large_data(data):
    # Process in chunks to manage memory
    chunk_size = 100
    results = []

    for i in range(0, len(data), chunk_size):
        chunk = data[i : i chunk_size]
        chunk_result = jit(process_chunk)(chunk)
        results.append(chunk_result)

    return jnp.concatenate(results)


def process_chunk(chunk):
    chunk_temp = jnp.sqrt(chunk)
    return chunk_temp
ログイン後にコピー

This function processes large datasets in chunks to avoid overwhelming device memory.

It sets chunk_size to 100 and iterates over the data increments of the chunk size, processing each chunk separately.

For each chunk, the function uses jit(process_chunk) to JIT-compile the processing operation, which improves performance by compiling it ahead of time.

The result of each chunk is concatenated into a single array using jnp.concatenated(result) to form a single list.

出力:

 print("===========================")
data = jnp.arange(10000)
print(data.shape)

print("===========================")
印刷(データ)

print("===========================")
print(process_large_data(data))
ログイン後にコピー

稲妻のガイド

Handling Random Seed for Reproducibility and Better Data Generation

The function create_traing_state() demonstrates managing random number generators (RNGs) in JAX, which is essential for reproducibility and consistent results.

 # 2. Handling Random Seeds
def create_training_state(rng):
    # Split RNG for different uses
    rng, init_rng = jax.random.split(rng)
    params = init_network(init_rng)

    return params, rng # Return new RNG for next use
    
ログイン後にコピー

It starts with an initial RNG (rng) and splits it into two new RNGs using jax.random.split(). Split RNGs perform different tasks: `init_rng` initializes network parameters, and the updated RNG returns for subsequent operations.

The function returns both the initialized network parameters and the new RNG for further use, ensuring proper handling of random states across different steps.

Now test the code using mock data

 def init_network(rng):
    # Initialize network parameters
    戻る {
        "w1": jax.random.normal(rng, (784, 256)),
        "b1": jax.random.normal(rng, (256,)),
        "w2": jax.random.normal(rng, (256, 10)),
        "b2": jax.random.normal(rng, (10,)),
    }


print("===========================")

key = jax.random.PRNGKey(0)
params, rng = create_training_state(key)


print(f"Random number generator: {rng}")

print(params.keys())

print("===========================")


print("===========================")
print(f"Network parameters shape: {params['w1'].shape}")

print("===========================")
print(f"Network parameters shape: {params['b1'].shape}")
print("===========================")
print(f"Network parameters shape: {params['w2'].shape}")

print("===========================")
print(f"Network parameters shape: {params['b2'].shape}")


print("===========================")
print(f"Network parameters: {params}")
ログイン後にコピー

出力:

稲妻のガイド

稲妻のガイド

Using Static Arguments in JIT

 def g(x, n):
    i = 0
    while i <p><strong>出力:</strong></p><pre class="brush:php;toolbar:false"> 30
ログイン後にコピー

You can use a static argument if JIT compiles the function with the same arguments each time. This can be useful for the performance optimization of JAX functions.

 from functools import partial


@partial(jax.jit, static_argnames=["n"])
def g_jit_decorated(x, n):
    i = 0
    while i <p>If You want to use static arguments in JIT as a decorator you can use jit inside of functools. partial() function.</p><p><strong>出力:</strong></p><pre class="brush:php;toolbar:false"> 30
ログイン後にコピー

Now, we have learned and dived deep into many exciting concepts and tricks in JAX and overall programming style.

次は何ですか?

  • Experiment with Examples: Try to modify the code examples to learn more about JAX. Build a small project for a better understanding of JAX's transformations and APIs. Implement classical Machine Learning algorithms with JAX such as Logistic Regression, Support Vector Machine, and more.
  • Explore Advanced Topics : Parallel computing with pmap, Custom JAX transformations, Integration with other frameworks

All code used in this article is here

結論

JAX is a powerful tool that provides a wide range of capabilities for machine learning, Deep Learning, and scientific computing. Start with basics, experimenting, and get help from JAX's beautiful documentation and community. There are so many things to learn and it will not be learned by just reading others' code you have to do it on your own. So, start creating a small project today in JAX. The key is to Keep Going, learn on the way.

キーテイクアウト

  • Familiar NumPY-like interface and APIs make learning JAX easy for beginners. Most NumPY code works with minimal modifications.
  • JAX encourages clean functional programming patterns that lead to cleaner, more maintainable code and upgradation. But If developers want JAX fully compatible with Object Oriented paradigm.
  • What makes JAX's features so powerful is automatic differentiation and JAX's JIT compilation, which makes it efficient for large-scale data processing.
  • JAX excels in scientific computing, optimization, neural networks, simulation, and machine learning which makes developer easy to use on their respective project.

よくある質問

Q1。 What makes JAX different from NumPY?

A. Although JAX feels like NumPy, it adds automatic differentiation, JIT compilation, and GPU/TPU support.

Q2。 Do I need a GPU to use JAX?

A. In a single word big NO, though having a GPU can significantly speed up computation for larger data.

Q3。 Is JAX a good alternative to NumPy?

A. Yes, You can use JAX as an alternative to NumPy, though JAX's APIs look familiar to NumPy JAX is more powerful if you use JAX's features well.

Q4。 Can I use my existing NumPy code with JAX?

A. Most NumPy code can be adapted to JAX with minimal changes. Usually just changing import numpy as np to import jax.numpy as jnp.

Q5。 Is JAX harder to learn than NumPy?

A. The basics are just as easy as NumPy! Tell me one thing, will you find it hard after reading the above article and hands-on? I answered it for you. YES hard. Every framework, language, libraries is hard not because it is hard by design but because we don't give much time to explore it. Give it time to get your hand dirty it will be easier day by day.

この記事に示されているメディアは、Analytics Vidhyaが所有しておらず、著者の裁量で使用されています。

以上が稲妻のガイドの詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

このウェブサイトの声明
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。
人気のチュートリアル
詳細>
最新のダウンロード
詳細>
ウェブエフェクト
公式サイト
サイト素材
フロントエンドテンプレート