jax の vmap
が複数のパラメータに適用されると期待どおりに動作しないことがわかりました。たとえば、次の関数について考えてみましょう:
x = jnp.arange(7), y = jnp.arange(5), z = jnp.arange(3)
の場合、この関数の出力形状は (7, 5、3)
。ただし、次の vmap バージョンの場合:
次のエラーが出力されます:
リーリー誰かがこのエラーの背後にある理由を説明できますか?
vmap
のセマンティクスは、1つ以上の配列に対して単一のバッチ操作を実行することです。 in_axes=(none, 0, 0)
を指定すると、「y
と z
の両方の先頭の次元に沿ってマップ」という意味になります。このエラーは、y
と y
の先頭のディメンションのサイズが異なるため、バッチ互換性がないことを示しています。
関数 f1
は基本的にブロードキャストを使用して 3 つのバッチ操作をエンコードしているため、vmap
を使用してそのロジックを複製するには、vmap
の 3 つのアプリケーションが必要になります。次のように表現できます:
以上がJAX `vmap` が複数のパラメーターを使用した場合の予期しない動作の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。