Maison > développement back-end > Tutoriel Python > Comportement inattendu de JAX `vmap` avec plusieurs paramètres

Comportement inattendu de JAX `vmap` avec plusieurs paramètres

王林
Libérer: 2024-02-09 09:21:07
avant
1105 Les gens l'ont consulté

JAX `vmap` 对于多个参数的意外行为

Contenu de la question

J'ai découvert que vmap dans jax ne se comporte pas comme prévu lorsqu'il est appliqué à plusieurs paramètres. Par exemple, considérons la fonction suivante :

def f1(x, y, z):
    f = x[:, none, none] * z[none, none, :] + y[none, :, none]
    return f
Copier après la connexion

pour x = jnp.arange(7), y = jnp.arange(5), z = jnp.arange(3),该函数的输出形状为 (7, 5, 3). Cependant, pour les versions de vmap suivantes :

@partial(vmap, in_axes=(none, 0, 0), out_axes=(1, 2))
def f2(x, y, z):
    f = x*z + y
    return f
Copier après la connexion

Il affiche cette erreur :

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * one axis had size 5: axis 0 of argument y of type int32[5];
  * one axis had size 3: axis 0 of argument z of type int32[3]
Copier après la connexion

Quelqu'un peut-il expliquer la raison de cette erreur ? La sémantique de


Correct Answer


vmap 的语义是它对一个或多个数组执行单个批处理操作。当您指定 in_axes=(none, 0, 0) 时,含义是“同时沿 yz 的前导维度映射”:您看到的错误告诉您 yy est qu'elle effectue une opération par lots unique sur un ou plusieurs tableaux. Lorsque vous spécifiez in_axes=(none, 0, 0), la signification est « mapper le long des dimensions principales de y et z » : L'erreur que vous voyez vous indique que les dimensions principales de y et y ont des tailles différentes, elles ne sont donc pas compatibles par lots.

Votre fonction f1 utilise essentiellement la diffusion pour coder trois opérations par lots, donc pour reproduire cette logique en utilisant f1 本质上使用广播来编码三个批处理操作,因此要使用 vmap 复制该逻辑,您将需要 vmap vous auriez besoin de trois applications de

. Vous pouvez l'exprimer ainsi : 🎜
@partial(vmap, in_axes=(0, None, None))
@partial(vmap, in_axes=(None, 0, None))
@partial(vmap, in_axes=(None, None, 0))
def f2(x, y, z):
    f = x*z + y
    return f
Copier après la connexion

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

source:stackoverflow.com
Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn
Tutoriels populaires
Plus>
Derniers téléchargements
Plus>
effets Web
Code source du site Web
Matériel du site Web
Modèle frontal