ホームページ > バックエンド開発 > Python チュートリアル > JAX `vmap` が複数のパラメーターを使用した場合の予期しない動作

JAX `vmap` が複数のパラメーターを使用した場合の予期しない動作

王林
リリース: 2024-02-09 09:21:07
転載
1132 人が閲覧しました

JAX `vmap` 对于多个参数的意外行为

質問内容

jax の vmap が複数のパラメータに適用されると期待どおりに動作しないことがわかりました。たとえば、次の関数について考えてみましょう:

リーリー

x = jnp.arange(7), y = jnp.arange(5), z = jnp.arange(3) の場合、この関数の出力形状は (7, 5、3)。ただし、次の vmap バージョンの場合:

リーリー

次のエラーが出力されます:

リーリー

誰かがこのエラーの背後にある理由を説明できますか?


正解


vmapのセマンティクスは、1つ以上の配列に対して単一のバッチ操作を実行することです。 in_axes=(none, 0, 0) を指定すると、「yz の両方の先頭の次元に沿ってマップ」という意味になります。このエラーは、yy の先頭のディメンションのサイズが異なるため、バッチ互換性がないことを示しています。

関数 f1 は基本的にブロードキャストを使用して 3 つのバッチ操作をエンコードしているため、vmap を使用してそのロジックを複製するには、vmap の 3 つのアプリケーションが必要になります。次のように表現できます:

リーリー

以上がJAX `vmap` が複数のパラメーターを使用した場合の予期しない動作の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

ソース:stackoverflow.com
このウェブサイトの声明
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。
人気のチュートリアル
詳細>
最新のダウンロード
詳細>
ウェブエフェクト
公式サイト
サイト素材
フロントエンドテンプレート