首頁 > 科技週邊 > 人工智慧 > 閃電般的JAX指南

閃電般的JAX指南

Jennifer Aniston
發布: 2025-03-19 11:21:11
原創
630 人瀏覽過

嘿,Python愛好者!您是否希望以超音速速度運行Numpy代碼?認識JAX!。您在機器學習,深度學習和數字計算過程中的新最好的朋友。將其視為具有超能力的Numpy。它可以自動處理梯度,編譯代碼以使用JIT快速運行,甚至可以在GPU和TPU上運行而不會破壞汗水。無論您是構建神經網絡,處理科學數據,調整變壓器模型,還是只是試圖加快計算速度,JAX都會支持您。讓我們深入研究,看看是什麼使Jax如此特別。

本指南詳細介紹了JAX及其生態系統。

學習目標

  • 解釋JAX的核心原理及其與Numpy的不同。
  • 應用JAX的三個關鍵轉換來優化Python代碼。將Numpy操作轉換為有效的JAX實現。
  • 在JAX代碼中識別並修復常見的性能瓶頸。在避免典型的陷阱的同時,正確實施JIT編譯。
  • 使用JAX從頭開始構建和訓練神經網絡。使用JAX的功能方法實施常見的機器學習操作。
  • 使用JAX的自動分化解決優化問題。執行有效的矩陣操作和數值計算。
  • 將有效的調試策略應用於特定問題。實施用於大規模計算的內存效率模式。

本文作為數據科學博客馬拉鬆的一部分發表

目錄

  • 什麼是JAX?
  • Jax為什麼脫穎而出?
  • Jax入門
  • 為什麼要學習JAX?
  • 基本的JAX轉換
  • 用JAX構建神經網絡
  • 最佳實踐和技巧
  • 性能優化
  • 調試策略
  • jax中的常見模式和成語
  • 接下來是什麼?
  • 結論
  • 常見問題

什麼是JAX?

根據官方文檔,JAX是用於加速陣列計算和程序轉換的Python庫,專為高性能數值計算和大規模機器學習而設計。因此,JAX本質上是類固醇上的Numpy,它將熟悉的Numpy風格操作與自動分化和硬件加速相結合。可以將其視為獲得三個世界中最好的。

  • Numpy的優雅語法和陣列操作
  • Pytorch喜歡自動分化能力
  • XLA的(加速線性代數)用於硬件加速和彙編優點。

Jax為什麼脫穎而出?

設定JAX的是其轉變。這些是可以修改您的Python代碼的強大功能:

  • JIT :快速執行的及時彙編
  • 畢業生:計算梯度的自動差異化
  • VMAP :自動進行批處理處理

這是一個快速的外觀:

導入jax.numpy作為jnp
來自Jax Import Grad,Jit
#定義一個簡單的功能
@Jit#用編譯加快速度
def square_sum(x):
返回JNP.Sum(JNP.Square(x))
#自動獲取其梯度功能
gradient_fn = grad(square_sum)
#嘗試一下
x = jnp.Array([1.0,2.0,3.0])
打印(f“漸變:{gradient_fn(x)}”)
登入後複製

輸出:

漸變:[2。 4。6。]]
登入後複製

Jax入門

在下面,我們將遵循一些步驟以開始使用JAX。

步驟1:安裝

設置jax非常適合僅使用CPU。您可以使用JAX文檔以獲取更多信息。

步驟2:為項目創造環境

為您的項目創建CONDA環境

#為JAX創建Conda Env
$ conda create -name jaxdev python = 3.11

#激活Env
$ conda激活jaxdev

#創建一個項目dir name jax101
$ MKDIR JAX101

#進入DIR
$ CD JAX101
登入後複製

步驟3:安裝JAX

在新創建的環境中安裝JAX

 #僅適用於CPU
PIP安裝 - 升級PIP
PIP安裝 - 升級“ JAX”

#對於GPU
PIP安裝 - 升級PIP
PIP安裝 - 升級“ JAX [CUDA12]”
登入後複製

現在,您準備深入研究真實的事物。在實用編碼上弄髒您的手之前,讓我們學習一些新概念。我將首先解釋這些概念,然後我們將共同編碼以了解實際的觀點。

首先,順其自然,為什麼我們再次學習新圖書館?我將在本指南中以盡可能簡單的方式回答這個問題。

為什麼要學習JAX?

將JAX視為電動工具。儘管Numpy就像是可靠的手鋸,但Jax就像現代的電鋸。它需要更多的步驟和知識,但是對於密集的計算任務而言,性能好處是值得的。

  • 性能:JAX代碼的運行速度明顯比Pure Python或Numpy代碼快得多,尤其是在GPU和TPU上
  • 靈活性:不僅用於機器學習 - JAX在科學計算,優化和仿真方面表現出色。
  • 現代方法: JAX鼓勵功能編程模式,從而導致更清潔,更可維護的代碼。

在下一節中,我們將深入研究Jax的轉換,從JIT彙編開始。這些轉變是賦予其超級大國的Jax的原因,而理解它們是有效利用JAX的關鍵。

基本的JAX轉換

JAX的轉換是真正將其與數值計算庫(例如Numpy或Scipy)區分開來的。讓我們探索每個人,看看它們如何增強您的代碼。

JIT或即時編譯

Just-Amper Ampilation通過在運行時(而不是提前編制程序)來優化代碼執行。

JAX如何工作?

在JAX中,JAX.JIT將Python函數轉換為JIT編譯版本。用 @jax.jit裝飾功能可捕獲其執行圖,優化它並使用XLA對其進行編譯。然後,編譯的版本執行,提供了重大的加速,尤其是對於重複的功能調用。

這是您可以嘗試的方法。

導入jax.numpy作為jnp
來自JAX Import Jit
進口時間


#計算密集型功能
def slow_function(x):
    對於_範圍(1000):
        x = jnp.sin(x)jnp.cos(x)
    返回x


#與JIT相同的功能
@Jit
def fast_function(x):
    對於_範圍(1000):
        x = jnp.sin(x)jnp.cos(x)
    返回x
登入後複製

這是相同的功能,一個只是一個普通的python彙編過程,另一個函數用作JAX的JIT彙編過程。它將計算正弦和余弦函數的1000個數據點總和。我們將使用時間比較性能。

 #比較性能
X = JNP.Arange(1000)

#熱身吉特
fast_function(x)#第一個調用編譯功能

#時間比較
start = time.time()
slow_result = slow_function(x)
打印(f“沒有jit:{time.time() - 開始:.4f}秒”)

start = time.time()
fast_result = fast_function(x)
打印(f with jit:{time.time() - 開始:.4f}秒”)
登入後複製

結果將使您驚訝。 JIT彙編比正常彙編快333倍。這就像將自行車與Buggati Chiron進行比較。

輸出:

沒有JIT:0.0330秒
與JIT:0.0010秒
登入後複製

JIT可以為您提供超快速的執行力,但您必須正確使用它,否則就像在沒有提供超級跑車設施的泥濘鄉村道路上駕駛布加迪一樣。

常見的jit陷阱

JIT在靜態形狀和類型中最有效。避免使用取決於數組值的python循環和條件。 JIT不適用於動態陣列。

 #不好 - 使用Python控制流
@Jit
def bad_function(x):
    如果x [0]> 0:#這與JIT無法正常工作
        返回x
    返回-x


#print(bad_function(jnp.array([1,2,3])))


#好 - 使用jax控制流
@Jit
def good_function(x):
    返回jnp.Where(x [0]> 0,x,-x)#jax -native條件


打印(good_function(JNP.Array([1,2,3]))))))
登入後複製

輸出:

閃電般的JAX指南

這意味著bad_function是不好的,因為JIT在計算過程中不在X的值中。

輸出:

 [1 2 3]
登入後複製

局限性和考慮因素

  • 彙編開銷:第一次執行JIT編譯功能時,由於編譯而有一些開銷。彙編成本可能超過了小型功能的性能優勢,或者只有一次。
  • 動態python功能: JAX的JIT要求功能為“靜態” 。動態控制流,例如基於Python循環的更改形狀或值,在編譯代碼中不支持。 JAX提供了諸如`jax.lax.cond`和jax.lax.scan`處理動態控制流程的替代方案。

自動差異化

自動分化或Autodiff是一種計算技術,用於準確有效地計算功能的導數。它在優化機器學習模型中起著至關重要的作用,尤其是在訓練神經網絡中,該網絡用於更新模型參數。

閃電般的JAX指南

自動分化如何在JAX中起作用?

Autodiff通過將微積分的鏈規則應用於更簡單的功能,計算這些子功能的派生函數,然後結合結果。它在函數執行過程中記錄每個操作以構建計算圖,然後將其用於自動計算衍生物。

自動陷阱有兩種主要模式:

  • 正向模式:單個正向通過計算圖中計算衍生物,對於具有少數參數的函數有效。
  • 反向模式:計算單個向後通過計算圖的衍生物,對於具有大量參數的函數有效。

閃電般的JAX指南

JAX自動差異的主要功能

  • 梯度計算(jax.grad): `jax.grad`計算其輸入的縮放器輸出函數的導數。對於具有多個輸入的函數,可以獲得部分導數。
  • 高階導數(jax.jacobian,jax.hessian): JAX支持高階衍生物的計算,例如Jacobians和Hessains,使其適合於高級優化和物理模擬。
  • 與其他JAX轉換的合成性: JAX中的AutoDiff無縫集成與其他轉換,例如jax.jit`和jax.vmap`允許進行有效且可擴展的計算。
  • 反向模式分化(反向傳播): JAX的自動陷阱對縮放器輸出功能使用反向模式分化,這對於深度學習任務非常有效。
導入jax.numpy作為jnp
從jax進口畢業,value_and_grad


#定義一個簡單的神經網絡層
def層(params,x):
    重量,偏見=參數
    返回jnp.dot(x,重量)偏差


#定義標量值損耗函數
def loss_fn(params,x):
    輸出=圖層(參數,x)
    返回JNP.SUM(輸出)#還原為標量


#獲得輸出和梯度
layer_grad = grad(loss_fn,argnums = 0)#相對於參數的漸變
layer_value_and_grad = value_and_grad(loss_fn,argnums = 0)#值和漸變

#示例用法
key = jax.random.prngkey(0)
x = jax.random.normal(key,(3,4))
重量= jax.random.normal(key,(4,2))
bias = jax.random.normal(key,(2,))

#計算梯度
grads = layer_grad((重量,偏見),x)
輸出,grads = layer_value_and_grad(((重量,偏見),x)

#多個導數很容易
twice_grad = grad(grad(jnp.sin))
X = JNP.Array(2.0)
print(f“ sin的第二個衍生物在x = 2:{twice_grad(x)}”)
登入後複製

輸出:

 sin的第二個衍生物x = 2:-0.9092974066734314
登入後複製

JAX的有效性

  • 效率: JAX的自動差異由於與XLA的集成而高效,因此可以在機器代碼級別進行優化。
  • 合成性:結合不同變換的能力使JAX成為建立復雜的機器學習管道和神經網絡體系結構(例如CNN,RNN和Transformers)的強大工具。
  • 易用性: JAX的AutoDiff語法簡單而直觀,使用戶能夠計算漸變,而無需深入研究XLA和復雜庫API的詳細信息。

JAX矢量化映射

在JAX中,“ VMAP”是一個強大的函數,可以自動矢量化計算,從而可以在無需手動編寫循環的情況下將功能應用於批次的數據。它可以在陣列軸(或多個軸)上繪製函數,並並行評估它,從而可以顯著改善性能。

VMAP如何在JAX中起作用?

VMAP函數可自動化沿輸入陣列的指定軸將函數應用於每個元素的過程,同時保留計算的效率。它轉換給定功能以接受批處理輸入並以矢量化的方式執行計算。

VMAP不是使用顯式循環,而是通過在輸入軸上進行矢量進行並行執行操作。這利用了硬件執行SIMD(單個指令,多個數據)操作的功能,這可能會導致大幅加速。

VMAP的關鍵功能

  • 自動矢量化: VAMP自動化計算的批處理,使得在批處理維度上並行代碼在不更改原始功能邏輯的情況下簡單。
  • 與其他轉換的合成性:它可以與其他JAX轉換無縫地工作,例如Jax.grad用於分化和JAX.JIT,以進行即時彙編,從而可以進行高度優化和靈活的代碼。
  • 處理多個批次尺寸: VMAP支持在多個輸入陣列或軸上映射映射,使其用於各種用例,例如同時處理多維數據或多個變量。
導入jax.numpy作為jnp
來自JAX導入VMAP


#在單個輸入中起作用的功能
def single_input_fn(x):
    返回jnp.sin(x)jnp.cos(x)


#將其矢量化以在批處理
batch_fn = vmap(single_input_fn)

#比較性能
X = JNP.Arange(1000)

#沒有VMAP(使用列表理解)
result1 = jnp.Array(x In xi in xi])

#與vmap
結果2 = batch_fn(x)#快得多!


#矢量化多個參數
def兩_input_fn(x,y):
    返回x * jnp.sin(y)


#在兩個輸入上進行矢量化
vectorized_fn = vmap(tw_input_fn,in_axes =(0,0))

#或僅通過第一個輸入進行矢量化
partaly_vectorized_fn = vmap(tw_input_fn,in_axes =(0,none))


# 列印
打印(結果1.形)
打印(結果2.形狀)
打印(partaly_vectorized_fn(x,y).shape)
登入後複製

輸出:

 (1000,)
(1000,)
(1000,3)
登入後複製

VMAP在JAX中的有效性

  • 性能改進:通過對矢量化計算,VMAP可以通過利用現代硬件(例如GPU和TPU(張量處理單元))的並行處理能力來大大加快執行速度。
  • 清潔器代碼:它可以通過消除對手動循環的需求來更簡潔而可讀的代碼。
  • 與JAX和AUTODIFF:VMAP的兼容性可以與自動分化(JAX.Grad)結合使用,從而可以有效地計算衍生物而不是數據批次。

何時使用每個轉換

使用@Jit時:

  • 您的功能多次稱為具有相似輸入形狀的功能。
  • 該函數包含大量的數值計算。

使用畢業時:

  • 您需要衍生物進行優化。
  • 實施機器學習算法
  • 求解微分方程以進行模擬

使用vmap時:

  • 使用的數據批次。
  • 平行計算
  • 避免明確的循環

使用JAX的矩陣操作和線性代數

JAX為矩陣操作和線性代數提供了全面的支持,使其適合科學計算,機器學習和數值優化任務。 JAX的線性代數功能與諸如Numpy之類的庫中的功能相似,但具有其他功能,例如自動差異化和即時彙編,以進行優化的性能。

矩陣加法和減法

這些操作是相同形狀的元素矩陣進行的。

 #1矩陣加法和減法:

導入jax.numpy作為jnp

a = jnp.array([[[1,2],[3,4]])
b = jnp.Array([[[5,6],[7,8]])

#矩陣加法
C = AB
#矩陣減法
d = a -b

打印(f“矩陣A:\ n {a}”)
打印(“ ==========================
打印(f“矩陣B:\ n {b}”)
打印(“ ==========================
print(f“ ab:\ n {c}”的矩陣adtion”)
打印(“ ==========================
打印(f“ ab:\ n {d}的矩陣縮寫”)
登入後複製

輸出:

閃電般的JAX指南

矩陣乘法

JAX支持元素乘法和基於DOR產品的矩陣乘法。

 #元素乘法
e = a * b

#矩陣乘法(點產品)
f = jnp.dot(a,b)

打印(f“矩陣A:\ n {a}”)
打印(“ ==========================
打印(f“矩陣B:\ n {b}”)
打印(“ ==========================
print(f“*b:\ n {e}的元素乘法”)
打印(“ ==========================
print(f“ a*b:\ n {f}的矩陣乘法”)
登入後複製

輸出:

閃電般的JAX指南

基質轉置

可以使用`

 #矩陣
g = jnp.transpose(a)

打印(f“矩陣A:\ n {a}”)
打印(“ ==========================
print(f“ a:\ n {g}的矩陣轉置”)
登入後複製

輸出:

閃電般的JAX指南

矩陣逆

JAX使用jnp.linalg.inv()`提供矩陣反轉的功能

#矩陣倒置
h = jnp.linalg.inv(a)

打印(f“矩陣A:\ n {a}”)
打印(“ ==========================
print(f“ a:\ n {h}的矩陣反轉”)
登入後複製

輸出:

閃電般的JAX指南

矩陣決定因素

可以使用`jnp.linalg.det()``。

 #矩陣決定因素
det_a = jnp.linalg.det(a)

打印(f“矩陣A:\ n {a}”)
打印(“ ==========================
print(f“ a:\ n {det_a}”的矩陣決定因素”)
登入後複製

輸出:

閃電般的JAX指南

矩陣特徵值和特徵向量

您可以使用`jnp.linalg.eigh()計算矩陣的特徵值和特徵向量

#特徵值和特徵向量
導入jax.numpy作為jnp

a = jnp.array([[[1,2],[3,4]])
特徵值,特徵向量= jnp.linalg.eigh(a)

打印(f“矩陣A:\ n {a}”)
打印(“ ==========================
print(a:\ n {eigenvalues}的f“ egenvalues”)
打印(“ ==========================
print(a:\ n {eigenVectors}的f“ eigenVectors}”)
登入後複製

輸出:

閃電般的JAX指南

矩陣單數值分解

通過`jnp.linalg.svd`支持SVD,可用於降低維度和矩陣分解。

 #單數值分解(SVD)

導入jax.numpy作為jnp

a = jnp.array([[[1,2],[3,4]])
u,s,v = jnp.linalg.svd(a)

打印(f“矩陣A:\ n {a}”)
打印(“ ==========================
print(f“ matrix u:\ n {u}”)
打印(“ ==========================
打印(f“矩陣S:\ n {s}”)
打印(“ ==========================
打印(f“矩陣V:\ n {v}”)
登入後複製

輸出:

閃電般的JAX指南

線性方程的求解系統

為了求解線性方程式AX = B的系統,我們使用`jnp.linalg.solve()`,其中A是平方矩陣,B是相同數量的行的向量或矩陣。

 #線性方程的求解系統
導入jax.numpy作為jnp

a = jnp.array([[[2.0,1.0],[1.0,3.0]])
B = JNP.Array([[5.0,6.0])
x = jnp.linalg.solve(a,b)

打印(f“ x:{x}的值”)
登入後複製

輸出:

 x的值:[1.8 1.4]
登入後複製

計算矩陣函數的梯度

使用JAX的自動分化,您可以計算標量功能相對於矩陣的梯度。
我們將計算以下功能的梯度和x的值

功能

閃電般的JAX指南

 #計算矩陣函數的梯度
導入JAX
導入jax.numpy作為jnp


def matrix_function(x):
    返回JNP.SUM(JNP.SIN(X)X ** 2)


#計算功能的畢業
grad_f = jax.grad(matrix_function)

x = jnp.Array([[[1.0,2.0],[3.0,4.0]]))
漸變= grad_f(x)

打印(f“矩陣x:\ n {x}”)
打印(“ ==========================
打印(f“ matrix_function的梯度:\ n {漸變}”)
登入後複製

輸出:

閃電般的JAX指南

這些在數值計算,機器學習和物理計算中使用的JAX的最有用的功能。還有更多供您探索。

JAX的科學計算

JAX具有科學計算的強大庫,JAX最適合科學計算,用於其提前特徵,例如JIT彙編,自動分化,矢量化,並行化和GPU-TPU加速度。 JAX支持高性能計算的能力使其適用於廣泛的科學應用,包括物理模擬,機器學習,優化和數值分析。

我們將在本節中探討一個優化問題。

優化問題

讓我們瀏覽以下優化問題:

步驟1:定義最小化功能(或問題)

 #定義一個函數以最小化(例如,Rosenbrock函數)

@Jit

Def Rosenbrock(X):

返回sum(100.0 *(x [1:]  -  x [: -  1] ** 2.0)** 2.0(1 -x [: -  1])** 2.0)
登入後複製

在這裡,定義了Rosenbrock函數,這是優化中常見的測試問題。該函數將數組x作為輸入,併計算一個代表x距函數全局最小值的valie。 @JIT裝飾器用於啟用JUT-IN-IN時間彙編,該彙編通過編譯功能在CPU和GPU上有效運行來加快計算的速度。

步驟2:梯度下降步驟實現

#梯度下降優化

@Jit

def gradient_descent_step(x,Learning_rate):

返回X -Learning_rate * grad(Rosenbrock)(x)
登入後複製

此功能執行梯度下降優化的單一步驟。使用Grad(Rosenbrock)(X)計算Rosenbrock函數的梯度,該級提供了相對於X的導數。 X的新值通過減法更新,通過Learning_rate縮放梯度。@Jit的做法與以前相同。

步驟3:運行優化循環

# 最佳化
x = jnp.array([0.0,0.0])#起點

Learning_rate = 0.001

對於範圍的我(2000年):

x = gradient_descent_step(x,Learning_rate)

如果我%100 == 0:

print(f“步驟{i},值:{Rosenbrock(x):。4f}”)
登入後複製

優化循環初始化了起點X,並執行梯度下降的1000次迭代。在每次迭代中,gradient_descent_step函數基於當前梯度更新。每100個步驟,當前的步驟編號和X處的Rosenbrock函數的值,提供優化的進度。

輸出:

閃電般的JAX指南

解決現實世界的物理問題

我們將模擬一個物理系統的運動系統的運動,該運動的運動震盪振盪器的運動模型,該系統像帶有摩擦的質量彈簧系統,車輛中的減震器或電路中的振盪一樣建模。不是很好嗎?我們開始做吧。

步驟1:參數定義

導入JAX
導入jax.numpy作為jnp


#定義參數
質量= 1.0#對象的質量(kg)
阻尼= 0.1#阻尼係數(kg/s)
spring_constant = 1.0#彈簧常數(n/m)

#定義時間步驟和總時間
DT = 0.01#時間步長(S)
num_steps = 3000#步驟數
登入後複製

定義了質量,阻尼係數和彈簧常數。這些決定了阻尼的諧波振盪器的物理特性。

步驟2:ODE定義

#定義ODES系統
DEF DAMPED_HARMONIC_COSCILLATOR(狀態,T):
    “”“計算阻尼諧波振盪器的衍生物。

    狀態:包含位置和速度的數組[X,V]
    T:時間(在此自治系統中不使用)
    ”“”
    x,v =狀態
    dxdt = v
    dvdt = -Damping / Mass * V -Spring_constant / Mass * x
    返回JNP.Array([DXDT,DVDT])
登入後複製

阻尼的諧波振盪器函數定義了振盪器的位置和速度的衍生物,代表了動力學系統。

步驟3:Euler的方法

#使用Euler的方法解決ODE
def euler_step(狀態,t,dt):
    “”“執行Euler方法的一步。”“”
    衍生物= damped_harmonic_coscillator(狀態,t)
    返回狀態衍生工具 * DT
登入後複製

一種簡單的數值方法用於求解ode。它在下一個時間步驟近似於當前狀態和導數。

步驟4:時間演變循環

#初始狀態:[位置,速度]
oniration_state = jnp.Array([1.0,0.0])#從質量開始,x = 1,v = 0

#時間演變
狀態= [initial_state]
時間= 0.0
對於範圍(num_steps)的步驟:
    next_state = euler_step(狀態[-1],時間,dt)
    states.append(next_state)
    時間= DT

#將狀態列表轉換為JAX數組進行分析
狀態= jnp.stack(狀態)
登入後複製

循環通過指定的時間步驟迭代,使用Euler的方法在每個步驟更新狀態。

輸出:

閃電般的JAX指南

步驟5:繪製結果

最後,我們可以繪製結果以可視化阻尼的諧波振盪器的行為。

 #繪製結果
導入matplotlib.pyplot作為PLT

plt.Style.use(“ GGPLOT”)

位置=狀態[:,0]
速度=狀態[:,1]
time_points = jnp.arange(0,(num_steps 1) * dt,dt)

plt.figure(無花果=(12,6))
plt.subplot(2,1,1)
plt.plot(time_points,位置,label =“位置”)
plt.xlabel(“時間”)
plt.ylabel(“位置(M)”)
plt.legend()

plt.subplot(2,1,2)
plt.plot(time_points,速度,label =“速度”,color =“橙色”)
plt.xlabel(“時間”)
plt.ylabel(“速度(m/s)”)
plt.legend()

plt.tight_layout()
plt.show()
登入後複製

輸出:

閃電般的JAX指南

我知道您渴望看到如何使用JAX構建神經網絡。因此,讓我們深入研究它。

在這裡,您可以看到這些值逐漸最小化。

用JAX構建神經網絡

JAX是一個功能強大的庫,將高性能數值計算與使用Numpy樣語法的易用性結合在一起。本節將指導您使用JAX構建神經網絡的過程,並利用其高級功能進行自動差異化和即時彙編以優化性能。

步驟1:導入庫

在我們深入建立神經網絡之前,我們需要進口必要的庫。 JAX提供了一套用於創建有效數值計算的工具,而其他庫將有助於優化和可視化我們的結果。

導入JAX
導入jax.numpy作為jnp
來自Jax Import Grad,Jit
來自jax.random導入prngkey,正常
導入Optax#JAX的優化庫
導入matplotlib.pyplot作為PLT
登入後複製

步驟2:創建模型層

創建有效的模型層對於定義神經網絡的體系結構至關重要。在此步驟中,我們將初始化密集層的參數,以確保我們的模型從定義明確的權重和偏見開始,以進行有效學習。

 def init_layer_params(key,n_in,n_out):
    “”“單個密集層的初始化參數”“”
    key_w,key_b = jax.random.split(key)
    #初始化
    w = normal(key_w,(n_in,n_out)) * jnp.sqrt(2.0 / n_in)  
    b = normal(key_b,(n_out,)) * 0.1
    返回(w,b)
    
def relu(x):
    “”“ relu激活函數”“”
    返回jnp.maximum(0,x)
    
登入後複製
  • 初始化函數:使用HE初始化重量的初始化和偏差的小值,INIT_LAYER_PARAMS初始化了權重(W)和偏見(B)。他或Kaiming He初始化對於具有relu激活函數的層次,還有其他流行的初始化方法,例如Xavier初始化,它對具有乙狀結腸激活的層效果更好。
  • 激活函數: Relu函數將Relu激活函數應用於將負值設置為零的輸入。

步驟3:定義向前通行證

正向通行證是神經網絡的基石,因為它決定了輸入數據如何流過網絡以產生輸出。在這裡,我們將通過通過初始化層將轉換應用於輸入數據來定義一種計算模型輸出的方法。

 def向前(參數,x):
    “”“前向兩個層神經網絡”“”“”
    (W1,B1),(W2,B2)=參數
    #第一層
    h1 = relu(jnp.dot(x,w1)b1)
    #輸出層
    logits = jnp.dot(h1,w2)b2
    返回logits
    
登入後複製
  • 正向通行:正向通過兩層神經網絡執行前向通過,通過應用線性轉換,然後進行relu和其他線性變換來計算輸出(logits)。

S TEP4:定義損失功能

定義明確的損失功能對於指導我們模型的培訓至關重要。在此步驟中,我們將實施平均誤差(MSE)損耗函數,該函數衡量了預測輸出符合目標值的程度,從而使模型能夠有效學習。

 def loss_fn(params,x,y):
    “”“平均平方錯誤損失”“”
    pred =向前(params,x)
    返回jnp.mean(((pred -y)** 2)
登入後複製
  • 損耗函數: Loss_FN計算預測邏輯和目標標籤(Y)之間的平均平方誤差(MSE)損耗。

步驟5:模型初始化

通過定義了模型體系結構和損失函數,我們現在轉向模型初始化。此步驟涉及設置我們的神經網絡的參數,以確保每一層都準備以隨機但適當縮放的權重和偏見開始訓練過程。

 def init_model(rng_key,input_dim,hidden_​​dim,output_dim):
    key1,key2 = jax.random.split(rng_key)
    params = [
        init_layer_params(key1,input_dim,hidden_​​dim),
        init_layer_params(key2,hidden_​​dim,output_dim),
    這是給出的
    返回參數
    
登入後複製
  • 模型初始化: init_model初始化了神經網絡兩層的權重和偏差。它對每一層的參數初始化使用兩個獨立的隨機鍵。

步驟6:訓練步驟

訓練神經網絡涉及基於損耗函數的計算梯度對其參數的迭代更新。在此步驟中,我們將實施一個有效地應用這些更新的培訓功能,從而使我們的模型可以通過多個時期的數據學習。

 @Jit
def train_step(params,opt_state,x_batch,y_batch):
    損失,grads = jax.value_and_grad(loss_fn)(params,x_batch,y_batch)
    更新,opt_state =優化器。
    params = optax.apply_updates(參數,更新)
    返回參數,opt_state,損失
登入後複製
  • 培訓步驟: Train_Step功能執行單個梯度下降更新。
  • 它使用value_and_grad計算損失和梯度,該value_and_grad既可以計算函數值和其他梯度。
  • 計算優化器更新,並相應地更新模型參數。
  • IS jit編譯以進行性能。

步驟7:數據和培訓循環

為了有效地培訓我們的模型,我們需要生成合適的數據並實施培訓循環。本節將介紹如何為我們的示例創建合成數據,以及如何跨多個批次和時代管理培訓過程。

 #生成一些示例數據
key = prngkey(0)
x_data = normal(鍵,(1000,10))#1000樣本,10個功能
y_data = jnp.sum(x_data ** 2,axis = 1,keepdims = true)#簡單的非線性函數

#初始化模型和優化器
params = init_model(key,input_dim = 10,hidden_​​dim = 32,output_dim = 1)
優化器= optax.adam(Learning_rate = 0.001)
opt_state =優化器(params)

#訓練循環
batch_size = 32
num_epochs = 100
num_batches = x_data.shape [0] // batch_size

#存儲時期和損失值的數組
epoch_array = []
loss_array = []

對於範圍(num_epochs)的時代:
    epoch_loss = 0.0
    對於範圍(num_batches)的批次:
        idx = jax.random.permunt(鍵,batch_size)
        x_batch = x_data [idx]
        y_batch = y_data [idx]
        params,opt_state,loss = train_step(params,opt_state,x_batch,y_batch)
        epoch_loss =損失

    #存儲時代的平均損失
    avg_loss = epoch_loss / num_batches
    epoch_array.append(epoch)
    lose_array.append(avg_loss)

    如果epoch%10 == 0:
        print(f“ epoch {epoch},損失:{avg_loss:.4f}”)
登入後複製
  • 數據生成:創建隨機培訓數據(X_DATA)和相應的目標(Y_DATA)值。模型和優化器初始化:模型參數和優化器狀態是初始化的。
  • 訓練環:使用迷你批次梯度下降,對網絡進行了指定數量的時期訓練。
  • 訓練循環迭代批次,使用Train_Step功能執行梯度更新。計算和存儲每個時期的平均損失。它打印了時期的數字和平均損失。

步驟8:繪製結果

可視化訓練結果是了解我們神經網絡的性能的關鍵。在此步驟中,我們將繪製培訓損失而不是時期,以觀察模型的學習程度並確定培訓過程中的任何潛在問題。

 #繪製結果
plt.plot(epoch_array,loss_array,label =“訓練損失”)
plt.xlabel(“ Epoch”)
plt.ylabel(“損失”)
plt.title(“時代訓練損失”)
plt.legend()
plt.show()
登入後複製

這些示例演示了JAX如何將高性能與乾淨,可讀的代碼結合在一起。 JAX鼓勵的功能編程樣式使組成操作變得容易並應用轉換。

輸出:

閃電般的JAX指南

陰謀:

閃電般的JAX指南

這些示例演示了JAX如何將高性能與乾淨,可讀的代碼結合在一起。 JAX鼓勵的功能編程樣式使組成操作變得容易並應用轉換。

最佳實踐和技巧

在建立神經網絡時,遵守最佳實踐可以顯著提高性能和可維護性。本節將討論各種策略和技巧,以優化您的代碼並提高基於JAX的模型的整體效率。

性能優化

與JAX合作時,優化性能至關重要,因為它使我們能夠充分利用其功能。在這裡,我們將探索不同的技術來提高JAX功能的效率,以確保我們的模型在不犧牲可讀性的情況下盡快運行。

JIT彙編最佳實踐

Just-On-time(JIT)彙編是JAX的出色功能之一,可以通過在運行時編譯功能來更快地執行。本節將概述有效使用JIT編譯的最佳實踐,從而幫助您避免常見的陷阱並最大程度地提高代碼的性能。

不良功能

導入JAX
導入jax.numpy作為jnp
來自JAX Import Jit
來自JAX Import Lax


#不好:動態的python控制流
@Jit
def bad_function(x,n):
    對於範圍(n)的i:#python循環 - 將展開
        x = x 1
    返回x
    
    
打印(“ ==========================
#print(bad_function(1,1000))#不起作用
    
登入後複製

該函數使用標準的Python循環進行迭代n次,在每次迭代中將X的X遞增1。與JIT一起編譯時,JAX展開了循環,這可能是效率低下的,尤其是對於大型n。這種方法並不能完全利用JAX的功能進行性能。

好功能

#好:使用jax-native操作
@Jit
def good_function(x,n):
    返回xn#矢量化操作


打印(“ ==========================
打印(good_function(1,1000))
登入後複製

該函數執行相同的操作,但是它使用矢量化操作(XN)而不是循環。這種方法更有效,因為當以單個矢量化操作表示時,JAX可以更好地優化計算。

最佳功能

#更好:使用掃描進行循環


@Jit
def best_function(x,n):
    def body_fun(i,val):
        返回val 1

    返回lax.fori_loop(0,n,body_fun,x)


打印(“ ==========================
打印(best_function(1,1000))
登入後複製

此方法使用`jax.lax.fori_loop`,這是一種有效實現循環的JAX本地方法。 `lax.fori_loop`執行與上一個函數相同的增量操作,但是它使用編譯的循環結構進行操作。 Body_fn函數定義了每次迭代的操作,並且`lax.fori_loop`從o到n執行它。該方法比展開循環更有效,並且特別適用於未知迭代次數的情況。

輸出

 ===========================
===========================
1001
===========================
1001
登入後複製

該代碼展示了處理循環和控制流程中jax符合JIT符合功能的不同方法。

內存管理

在任何計算框架中,有效的內存管理至關重要,尤其是在處理大型數據集或複雜模型時。本節將討論內存分配中的常見陷阱,並提供在JAX中優化內存使用情況的策略。

效率低下的內存管理

#不好:創建大型臨時陣列
@Jit
def infelfficited_function(x):
    temp1 = jnp.power(x,2)#臨時數組
    temp2 = jnp.sin(temp1)#另一個臨時
    返回JNP.SUM(temp2)
登入後複製

defficited_function(x):此函數創建多個中間陣列Temp1,temp1,最後創建了temp2中元素的總和。創建這些臨時陣列可能是降低的,因為每個步驟都會分配內存並產生計算開銷,從而導致執行速度較慢和更高的內存使用情況。

有效的內存管理

#好:結合操作
@Jit
def效率_function(x):
    返回JNP.SUM(JNP.SIN(JNP.Power(x,2)))#單操作
登入後複製

此版本將所有操作結合到一行代碼中。它直接計算X平方元素的正弦,並總和結果。通過結合操作,它可以避免創建中間陣列,減少內存足跡並提高性能。

測試代碼

x = jnp.Array([1,2,3])
打印(x)
打印(inffelided_function(x))
打印(效率_function(x))
登入後複製

輸出:

 [1 2 3]
0.49678695
0.49678695
登入後複製

有效的版本利用JAX優化計算圖的能力,通過最大程度地減少臨時數組創建來使代碼更快,更快。

調試策略

調試是開發過程的重要組成部分,尤其是在復雜的數值計算中。在本節中,我們將討論特定於JAX的有效調試策略,使您能夠快速識別和解決問題。

在JIT內使用打印件進行調試

該代碼顯示了在JAX內調試的技術,尤其是在使用JIT編譯功能時。

導入jax.numpy作為jnp
來自JAX Import Debug


@Jit
def debug_function(x):
    #使用debug.print而不是在jit中打印
    debug.print(“ x的形狀:{}”,x.Shape)
    y = jnp.sum(x)
    debug.print(“ sum:{}”,y)
    返回y
登入後複製
 #要進行更複雜的調試,請突破JIT
def debug_values(x):
    打印(“輸入:”,x)
    結果= debug_function(x)
    打印(“輸出:”,結果)
    返回結果
    
登入後複製
  • debug_function(x):此功能顯示瞭如何使用debug.print()在jit編譯函數內進行調試。在JAX中,由於彙編限制,在JIT內不允許常規的Python打印語句,因此使用Debug.print()。
  • 它使用debug.print()打印輸入數組x的形狀
  • 在計算X的元素之和之後,它使用debug.print()打印產生的總和。
  • 最後,該函數返回計算的總和y。
  • debug_values(x)功能是一種高級調試方法,突破了JIT上下文以進行更複雜的調試。它首先使用常規打印語句打印輸入X。然後調用debug_function(x)計算結果,並在返回結果之前最終打印輸出。

輸出:

打印(“ ==========================
打印(debug_function(jnp.array([1,2,3]))))))
打印(“ ==========================
打印(debug_values(jnp.array([1,2,3])))))))
登入後複製

閃電般的JAX指南

這種方法允許使用標準的Python打印語句組合使用Debug.print()和更詳細的調試。

jax中的常見模式和成語

最後,我們將探索JAX中的常見模式和成語,可以幫助簡化您的編碼過程並提高效率。熟悉這些實踐將有助於開發更強大和表現的JAX應用程序。

用於處理大數據集的設備內存管理

#1。設備內存管理
def process_large_data(數據):
    #塊中的過程以管理內存
    chunk_size = 100
    結果= []

    對於i在範圍內(0,len(data),chunk_size):
        塊=數據[i:i Chunk_size]
        chunk_result = jit(process_chunk)(塊)
        結果。

    返回JNP.Concatenate(結果)


def Process_chunk(塊):
    chunk_temp = jnp.sqrt(塊)
    返回chunk_temp
登入後複製

此功能會在塊中處理大型數據集,以避免壓倒性的設備內存。

它將Chunk_size設置為100,並在塊大小的數據增量上進行迭代,並分別處理每個塊。

對於每個塊,該函數使用JIT(Process_chunk)來彙編處理操作,從而通過提前進行編譯來改善性能。

每個塊的結果都使用JNP.Concatenated(結果)將單個列表串成一個單個數組。

輸出:

打印(“ ==========================
data = jnp.Arange(10000)
打印(Data.Shape)

打印(“ ==========================
打印(數據)

打印(“ ==========================
打印(process_large_data(數據))
登入後複製

閃電般的JAX指南

處理隨機種子以獲得可重複性和更好的數據生成

函數create_traing_state()演示了JAX中的隨機數生成器(RNG)的管理,這對於可重複性和一致的結果至關重要。

 #2。處理隨機種子
def create_training_state(rng):
    #用於不同用途的拆分RNG
    rng,init_rng = jax.random.split(rng)
    params = init_network(init_rng)

    返回參數,rng#返回新的RNG供下一個使用
    
登入後複製

它從初始RNG(RNG)開始,然後使用jax.random.split()將其分為兩個新的RNG。 Split RNGS執行不同的任務:`init_rng“初始化網絡參數,以及更新的RNG返回以進行後續操作。

該函數返回初始化的網絡參數和新的RNG供進一步使用,以確保在不同步驟中正確處理隨機狀態。

現在使用模擬數據測試代碼

DEF INIT_NETWORK(RNG):
    #初始化網絡參數
    返回 {
        “ W1”:jax.random.normal(rng,(784,256)),
        “ b1”:jax.random.normal(rng,(256,)),
        “ W2”:jax.random.normal(rng,(256,10)),
        “ b2”:jax.random.normal(rng,(10,)),
    }


打印(“ ==========================

key = jax.random.prngkey(0)
參數,rng = create_training_state(鍵)


打印(f“隨機數生成器:{rng}”)

打印(params.keys())

打印(“ ==========================


打印(“ ==========================
打印(f“網絡參數形狀:{params ['w1']。形狀}”)

打印(“ ==========================
打印(f“網絡參數形狀:{params ['b1']。形狀}”)
打印(“ ==========================
打印(f“網絡參數形狀:{params ['w2']。形狀}”)

打印(“ ==========================
打印(f“網絡參數形狀:{params ['b2']。形狀}”)


打印(“ ==========================
打印(f“網絡參數:{params}”)
登入後複製

輸出:

閃電般的JAX指南

閃電般的JAX指南

在JIT中使用靜態論點

def G(x,n):
    i = 0
    當我<n i="1" g_jit_correct="jax.jit(g,static_argnames" n><p><strong>輸出:</strong></p>
<pre class="brush:php;toolbar:false"> 30
登入後複製

如果JIT每次使用相同的參數編譯函數,則可以使用靜態參數。這對於JAX函數的性能優化很有用。

從函數引入部分導入


@partial(jax.jit,static_argnames = [“ n”])
def g_jit_decorated(x,n):
    i = 0
    當我<n i="1"><p>如果您想將JIT中的靜態參數用作裝飾器,則可以在功能上使用JIT。 partial()函數。</p>
<p><strong>輸出:</strong></p>
<pre class="brush:php;toolbar:false"> 30
登入後複製

現在,我們已經學習並深入研究了許多令人興奮的概念和技巧以及整體編程風格。

接下來是什麼?

  • 嘗試示例:嘗試修改代碼示例以了解有關JAX的更多信息。建立一個小型項目,以更好地了解JAX的轉換和API。用JAX(例如邏輯回歸,支持向量機等)實現經典的機器學習算法。
  • 探索高級主題:使用PMAP的並行計算,自定義JAX轉換,與其他框架集成

本文中使用的所有代碼都在這裡

結論

JAX是一種強大的工具,可為機器學習,深度學習和科學計算提供廣泛的功能。從基礎知識開始,進行實驗,並從Jax美麗的文檔和社區獲得幫助。有很多東西要學習,只要閱讀他人的代碼,就不會學到這一點。因此,立即開始在JAX中創建一個小型項目。關鍵是繼續前進,學習途中。

關鍵要點

  • 熟悉的類似Numpy的界面和API使初學者的學習很容易。大多數Numpy代碼可用於最小修改。
  • JAX鼓勵乾淨的功能編程模式,從而導致更清潔,更可維護的代碼和升級。但是,如果開發人員希望JAX與面向對象的範式完全兼容。
  • 使JAX的功能如此強大的原因是自動分化和JAX的JIT編譯,這使得大規模數據處理效率。
  • JAX在科學計算,優化,神經網絡,模擬和機器學習方面表現出色,這使開發人員易於在其各自的項目中使用。

常見問題

Q1。是什麼使JAX與Numpy不同?

答:儘管JAX感覺就像Numpy,但它增加了自動分化,JIT彙編和GPU/TPU支持。

Q2。我需要GPU使用JAX嗎?

答:在一個單詞中,儘管擁有GPU可以顯著加快較大數據的計算。

Q3。 JAX是Numpy的好替代品嗎?

答:是的,您可以將JAX用作Numpy的替代方法,儘管如果您很好地使用JAX的功能,Jax的API看起來對Numpy Jax熟悉,則更強大。

Q4。我可以將現有的Numpy代碼與JAX一起使用嗎?

答:大多數Numpy代碼可以以最小的更改適應JAX。通常只是將導入numpy作為NP將其導入JAX.numpy作為JNP。

Q5。 Jax比Numpy更難學習嗎?

答:基礎知識和numpy一樣容易!告訴我一件事,閱讀上述文章和動手完成後,您會發現很難嗎?我為你回答。是的,很難。每個框架,語言,庫都很難,不是因為設計很難,而是因為我們沒有花太多時間來探索它。讓它有時間弄髒您的手,每天都會更容易。

本文所示的媒體不由Analytics Vidhya擁有,並由作者酌情使用。

以上是閃電般的JAX指南的詳細內容。更多資訊請關注PHP中文網其他相關文章!

本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn
最新問題
熱門教學
更多>
最新下載
更多>
網站特效
網站源碼
網站素材
前端模板