I discovered that vmap in JAX does not behave as expected when applied to multiple parameters. For example, consider the following function:
    def f1(x, y, z):
        f = x[:, None, None] * z[None, None, :] + y[None, :, None]
        return f
For x = jnp.arange(7), y = jnp.arange(5), z = jnp.arange(3), the output shape of this function is (7, 5, 3). However, for the following vmap version:
    from functools import partial
    from jax import vmap

    @partial(vmap, in_axes=(None, 0, 0), out_axes=(1, 2))
    def f2(x, y, z):
        f = x*z + y
        return f
It outputs this error:
    ValueError: vmap got inconsistent sizes for array axes to be mapped:
      * one axis had size 5: axis 0 of argument y of type int32[5];
      * one axis had size 3: axis 0 of argument z of type int32[3]
Can someone explain the reason behind this error?
The semantics of vmap are that each application performs a single batch operation over one or more arrays. When you specify in_axes=(None, 0, 0), the meaning is "map along the leading dimensions of both y and z together". The error you are seeing tells you that the leading dimensions of y and z have different sizes (5 and 3), so they are not batch compatible.
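To make the "single batch operation" point concrete, here is a minimal sketch. The helper g and the arrays z_same and z_diff are hypothetical names introduced only for this illustration. It shows that one vmap pairs element i of y with element i of z, which works when the two mapped axes have the same length and raises the error from the question when they do not.

```python
import jax.numpy as jnp
from jax import vmap

def g(x, y, z):
    # Inside vmap, y and z are scalars; x is the full, unmapped array.
    return x * z + y

x = jnp.arange(7)
y = jnp.arange(5)
z_same = jnp.arange(5)  # hypothetical: same leading size as y
z_diff = jnp.arange(3)  # different leading size, as in the question

# One vmap means one batch axis: y[i] is paired with z_same[i].
zipped = vmap(g, in_axes=(None, 0, 0))(x, y, z_same)
print(zipped.shape)  # (5, 7)

# With mismatched sizes there is no consistent batch axis to map over.
mismatch_error = None
try:
    vmap(g, in_axes=(None, 0, 0))(x, y, z_diff)
except ValueError as err:
    mismatch_error = err
    print(type(err).__name__)  # ValueError
```

Note that the zipped result has only one batch axis of length 5. It is not an outer product over y and z.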
Your function f1 essentially uses broadcasting to encode three independent batch operations, one per argument, so to replicate that logic using vmap you will need three applications of vmap. You can express it like this:
    @partial(vmap, in_axes=(0, None, None))
    @partial(vmap, in_axes=(None, 0, None))
    @partial(vmap, in_axes=(None, None, 0))
    def f2(x, y, z):
        f = x*z + y
        return f
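As a quick sanity check, the following sketch compares the broadcasting version with the nested-vmap version on the inputs from the question. The variable names out1 and out2 are introduced here for illustration only. The comments on the decorators reflect that the topmost decorator is applied last, so it becomes the outermost map and contributes the leading output axis.

```python
from functools import partial

import jax.numpy as jnp
from jax import vmap

def f1(x, y, z):
    # Broadcasting version: x -> axis 0, y -> axis 1, z -> axis 2.
    return x[:, None, None] * z[None, None, :] + y[None, :, None]

@partial(vmap, in_axes=(0, None, None))  # outermost: maps x, output axis 0
@partial(vmap, in_axes=(None, 0, None))  # middle: maps y, output axis 1
@partial(vmap, in_axes=(None, None, 0))  # innermost: maps z, output axis 2
def f2(x, y, z):
    # Here x, y and z are all scalars.
    return x * z + y

x = jnp.arange(7)
y = jnp.arange(5)
z = jnp.arange(3)

out1 = f1(x, y, z)
out2 = f2(x, y, z)
print(out1.shape, out2.shape)            # (7, 5, 3) (7, 5, 3)
print(bool(jnp.array_equal(out1, out2)))  # True
```

Reordering the decorators would permute the output axes, so the order above is chosen to match the (7, 5, 3) layout of f1.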