首頁 > 後端開發 > Python教學 > JAX `vmap` 對於多個參數的意外行為

JAX `vmap` 對於多個參數的意外行為

王林
發布: 2024-02-09 09:21:07
轉載
1132 人瀏覽過

JAX `vmap` 对于多个参数的意外行为

問題內容

我發現 jax 中的 vmap 在應用於多個參數時不會如預期執行。例如,考慮下面的函數:

def f1(x, y, z):
    f = x[:, none, none] * z[none, none, :] + y[none, :, none]
    return f
登入後複製

對於x = jnp.arange(7), y = jnp.arange(5), z = jnp.arange(3),此函數的輸出形狀為(7, 5 , 3)。但是,對於以下 vmap 版本:

@partial(vmap, in_axes=(none, 0, 0), out_axes=(1, 2))
def f2(x, y, z):
    f = x*z + y
    return f
登入後複製

它輸出此錯誤:

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * one axis had size 5: axis 0 of argument y of type int32[5];
  * one axis had size 3: axis 0 of argument z of type int32[3]
登入後複製

有人可以解釋一下這個錯誤背後的原因嗎?


正確答案


vmap 的語意是它對一個或多個陣列執行單一批次運算。當您指定in_axes=(none, 0, 0) 時,含義是「同時沿著yz 的前導維度映射」:您看到的錯誤告訴您yy 的前導維度具有不同的大小,因此它們不相容於批次。

您的函數f1 本質上使用廣播來編碼三個批次操作,因此要使用vmap 複製該邏輯,您將需要vmap 的三個應用程式。您可以這樣表達:

@partial(vmap, in_axes=(0, None, None))
@partial(vmap, in_axes=(None, 0, None))
@partial(vmap, in_axes=(None, None, 0))
def f2(x, y, z):
    f = x*z + y
    return f
登入後複製

以上是JAX `vmap` 對於多個參數的意外行為的詳細內容。更多資訊請關注PHP中文網其他相關文章!

來源:stackoverflow.com
本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn
熱門教學
更多>
最新下載
更多>
網站特效
網站源碼
網站素材
前端模板