JAX „vmap' unerwartetes Verhalten mit mehreren Parametern

王林
Freigeben: 2024-02-09 09:21:07
nach vorne
1031 Leute haben es durchsucht

JAX `vmap` 对于多个参数的意外行为

Frageninhalt

Ich habe festgestellt, dass sich vmap in Jax nicht wie erwartet verhält, wenn es auf mehrere Parameter angewendet wird. Betrachten Sie beispielsweise die folgende Funktion:

def f1(x, y, z):
    f = x[:, none, none] * z[none, none, :] + y[none, :, none]
    return f
Nach dem Login kopieren

für x = jnp.arange(7), y = jnp.arange(5), z = jnp.arange(3),该函数的输出形状为 (7, 5, 3). Allerdings für die folgenden vmap-Versionen:

@partial(vmap, in_axes=(none, 0, 0), out_axes=(1, 2))
def f2(x, y, z):
    f = x*z + y
    return f
Nach dem Login kopieren

Es wird dieser Fehler ausgegeben:

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * one axis had size 5: axis 0 of argument y of type int32[5];
  * one axis had size 3: axis 0 of argument z of type int32[3]
Nach dem Login kopieren

Kann jemand den Grund für diesen Fehler erklären? Die Semantik von


Richtige Antwort


vmap 的语义是它对一个或多个数组执行单个批处理操作。当您指定 in_axes=(none, 0, 0) 时,含义是“同时沿 yz 的前导维度映射”:您看到的错误告诉您 yy besteht darin, dass eine einzelne Stapeloperation für ein oder mehrere Arrays ausgeführt wird. Wenn Sie in_axes=(none, 0, 0) angeben, bedeutet dies „Abbildung entlang der führenden Dimensionen von y und z“: Sie Der angezeigte Fehler weist darauf hin, dass die führenden Dimensionen von y und y unterschiedliche Größen haben und daher nicht stapelkompatibel sind.

Ihre Funktion f1 verwendet im Wesentlichen Broadcasting, um drei Stapeloperationen zu kodieren. Um diese Logik mit f1 本质上使用广播来编码三个批处理操作,因此要使用 vmap 复制该逻辑,您将需要 vmap zu replizieren, benötigen Sie also drei Anwendungen von

. Du kannst es so ausdrücken: 🎜
@partial(vmap, in_axes=(0, None, None))
@partial(vmap, in_axes=(None, 0, None))
@partial(vmap, in_axes=(None, None, 0))
def f2(x, y, z):
    f = x*z + y
    return f
Nach dem Login kopieren

Das obige ist der detaillierte Inhalt vonJAX „vmap' unerwartetes Verhalten mit mehreren Parametern. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Quelle:stackoverflow.com
Erklärung dieser Website
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn
Beliebte Tutorials
Mehr>
Neueste Downloads
Mehr>
Web-Effekte
Quellcode der Website
Website-Materialien
Frontend-Vorlage
Über uns Haftungsausschluss Sitemap
Chinesische PHP-Website:Online-PHP-Schulung für das Gemeinwohl,Helfen Sie PHP-Lernenden, sich schnell weiterzuentwickeln!