JAX `vmap` unexpected behavior with multiple parameters

王林
Release: 2024-02-09 09:21:07


Question

I discovered that `vmap` in JAX does not behave as I expected when it is applied to a function with multiple parameters. For example, consider the following function:

```python
def f1(x, y, z):
    # Broadcast x, y, z along separate axes -> shape (len(x), len(y), len(z))
    f = x[:, None, None] * z[None, None, :] + y[None, :, None]
    return f
```

For `x = jnp.arange(7)`, `y = jnp.arange(5)`, and `z = jnp.arange(3)`, the output shape of this function is `(7, 5, 3)`. However, consider the following `vmap` version:

```python
from functools import partial
from jax import vmap

@partial(vmap, in_axes=(None, 0, 0), out_axes=(1, 2))
def f2(x, y, z):
    f = x*z + y
    return f
```

It raises this error:

```
ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * one axis had size 5: axis 0 of argument y of type int32[5];
  * one axis had size 3: axis 0 of argument z of type int32[3]
```

Can someone explain the reason behind this error?


Correct answer


The semantics of `vmap` are that it performs a single batched operation over one or more arrays. When you specify `in_axes=(None, 0, 0)`, it means "map along the leading dimensions of both `y` and `z` at the same time." The error you see is telling you that the leading dimensions of `y` and `z` have different sizes (5 and 3), so they cannot be batched together.
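A minimal sketch of this rule (the function name `g` is hypothetical): with `in_axes=(None, 0, 0)`, `y` and `z` are mapped together in lockstep, much like Python's `zip`, so the call succeeds when both have the same length and raises `ValueError` when they do not.

```python
from functools import partial

import jax.numpy as jnp
from jax import vmap


@partial(vmap, in_axes=(None, 0, 0))
def g(x, y, z):
    # x is passed unchanged; y[i] and z[i] are paired for each batch index i
    return x * z + y


x = jnp.arange(7)

# Equal leading sizes: one batch of 3, each element of shape (7,)
out = g(x, jnp.arange(3), jnp.arange(3))
print(out.shape)  # (3, 7)

# Mismatched leading sizes (5 vs 3): vmap cannot pair them up
try:
    g(x, jnp.arange(5), jnp.arange(3))
except ValueError as err:
    print("ValueError:", err)
```

In other words, a single `vmap` defines one batch dimension, and every mapped argument must share its size.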

Your function `f1` essentially uses broadcasting to encode three independent batch operations, so to replicate that logic with `vmap` you need three applications of `vmap`. You can express it like this:

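```python
from functools import partial
from jax import vmap

@partial(vmap, in_axes=(0, None, None))  # maps x -> output axis 0
@partial(vmap, in_axes=(None, 0, None))  # maps y -> output axis 1
@partial(vmap, in_axes=(None, None, 0))  # maps z -> output axis 2
def f2(x, y, z):
    f = x*z + y
    return f
```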
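To check that the nested version reproduces the broadcasting in `f1`, i.e. `f[i, j, k] = x[i] * z[k] + y[j]`, here is a sketch that spells out the same three `vmap` applications without decorators and compares the result with `f1`. The names `scalar_f` and `f3` are illustrative only. Each `vmap` layer adds one leading output axis, so the outermost layer (mapping `x`) ends up as axis 0, giving the same `(7, 5, 3)` layout as `f1`.

```python
import jax.numpy as jnp
from jax import vmap


def scalar_f(x, y, z):
    # Operates on scalars; every batch dimension is added by vmap
    return x * z + y


def f1(x, y, z):
    # Broadcasting version from the question
    return x[:, None, None] * z[None, None, :] + y[None, :, None]


# Innermost vmap maps z, then y, then x (outermost -> output axis 0)
f3 = vmap(
    vmap(
        vmap(scalar_f, in_axes=(None, None, 0)),
        in_axes=(None, 0, None),
    ),
    in_axes=(0, None, None),
)

x, y, z = jnp.arange(7), jnp.arange(5), jnp.arange(3)
print(f3(x, y, z).shape)  # (7, 5, 3)
print(bool(jnp.array_equal(f3(x, y, z), f1(x, y, z))))  # True
```

The decorator stack in the answer is equivalent: decorators apply bottom-up, so the one closest to the function is the innermost `vmap`.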


Source: stackoverflow.com