使用 (N-1) 维数组访问 N 维数组
给定一个 N 维数组 a 和一个 (N- 1)维数组idx,一个常见的任务是访问idx中索引指定的元素。这对于执行查找最大值或检索特定值等操作非常有用。
使用高级索引的优雅解决方案
一个优雅的解决方案涉及使用 NumPy 的 ogrid 函数进行高级索引:
<code class="python">m, n = a.shape[1:] I, J = np.ogrid[:m, :n] a_max_values = a[idx, I, J] b_max_values = b[idx, I, J]</code>
这将创建一个索引网格,并使用它来索引 a 和 b,从而生成包含相应值的数组。
函数的一般情况
对于适用于任何指定轴的更通用的解决方案,我们可以定义一个函数:
<code class="python">def argmax_to_max(arr, argmax, axis): new_shape = list(arr.shape) del new_shape[axis] grid = np.ogrid[tuple(map(slice, new_shape))] grid.insert(axis, argmax) return arr[tuple(grid)]</code>
此函数接受一个数组、其沿指定轴的 argmax 以及轴本身。然后它构造一个网格并使用它来提取相应的元素。
使用自定义函数简化索引
为了进一步简化索引过程,我们可以创建一个辅助函数生成索引网格:
<code class="python">def all_idx(idx, axis): grid = np.ogrid[tuple(map(slice, idx.shape))] grid.insert(axis, idx) return tuple(grid)</code>
此函数返回一个索引元组,可直接用于索引输入数组:
<code class="python">axis = 0 a_max_values = a[all_idx(idx, axis=axis)] b_max_values = b[all_idx(idx, axis=axis)]</code>
以上是如何用(N-1)维数组高效访问N维数组?的详细内容。更多信息请关注PHP中文网其他相关文章!