使用(N-1) 維數組存取N 維數組
給定一個N 維數組a 和一個(N- 1)維數組idx,一個常見的任務是存取idx中索引指定的元素。這對於執行查找最大值或檢索特定值等操作非常有用。
使用高級索引的優雅解決方案
一個優雅的解決方案涉及使用NumPy 的ogrid 函數進行高級索引:
<code class="python">m, n = a.shape[1:] I, J = np.ogrid[:m, :n] a_max_values = a[idx, I, J] b_max_values = b[idx, I, J]</code>
這將建立一個索引網格,並使用它來索引a 和b,從而產生包含相應值的陣列。
函數的一般情況
對於適用於任何指定軸的更通用的解決方案,我們可以定義一個函數:
<code class="python">def argmax_to_max(arr, argmax, axis): new_shape = list(arr.shape) del new_shape[axis] grid = np.ogrid[tuple(map(slice, new_shape))] grid.insert(axis, argmax) return arr[tuple(grid)]</code>
此函數接受一個陣列、其沿指定軸的argmax 以及軸本身。然後它構造一個網格並使用它來提取相應的元素。
使用自訂函數簡化索引
為了進一步簡化索引過程,我們可以建立一個輔助函數來產生索引網格:
<code class="python">def all_idx(idx, axis): grid = np.ogrid[tuple(map(slice, idx.shape))] grid.insert(axis, idx) return tuple(grid)</code>
此函數傳回索引元組,可直接用於索引輸入數組:
<code class="python">axis = 0 a_max_values = a[all_idx(idx, axis=axis)] b_max_values = b[all_idx(idx, axis=axis)]</code>
以上是如何用(N-1)維數組高效率存取N維數組?的詳細內容。更多資訊請關注PHP中文網其他相關文章!