Access N-Dimensional Array with (N-1)-Dimensional Array
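Given an N-dimensional array a and an (N-1)-dimensional array idx whose entries are indices along one axis of a, a common task is to pick out, for every entry of idx, the element of a it points to. A typical source of idx is a.argmax(axis=...). Indexing a with it gives the maxima, and indexing a second array b of the same shape with it retrieves the values of b at the positions where a is maximal.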
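To pin down what "accessing a with idx" means, the sketch below spells it out as an explicit loop for the 3-D case. It assumes idx = a.argmax(axis=0); the array sizes and random data are made up for illustration. The vectorized solutions that follow compute the same result without the Python loop.

```python
import numpy as np

# Hypothetical example data: a 3-D array a of shape (k, m, n) = (4, 3, 5).
rng = np.random.default_rng(0)
a = rng.integers(0, 100, size=(4, 3, 5))

# idx is (N-1)-dimensional: one index along axis 0 for every (i, j) position.
idx = a.argmax(axis=0)  # shape (3, 5)

# Reference semantics: for every (i, j), pick a[idx[i, j], i, j].
result = np.empty(idx.shape, dtype=a.dtype)
for i in range(idx.shape[0]):
    for j in range(idx.shape[1]):
        result[i, j] = a[idx[i, j], i, j]

# Picking the values at the argmax positions reproduces the maxima.
assert np.array_equal(result, a.max(axis=0))
```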
Elegant Solution Using Advanced Indexing
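When a is a 3-D array of shape (k, m, n) and idx holds indices along axis 0 (for example, idx = a.argmax(axis=0), of shape (m, n)), advanced indexing combined with NumPy's np.ogrid solves the problem directly. Here b is a second array with the same shape as a: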
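<code class="python">import numpy as np

# a and b have shape (k, m, n); idx = a.argmax(axis=0) has shape (m, n)
m, n = a.shape[1:]
I, J = np.ogrid[:m, :n]  # open grid: I has shape (m, 1), J has shape (1, n)
a_max_values = a[idx, I, J]  # idx, I and J broadcast together to shape (m, n)
b_max_values = b[idx, I, J]</code>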
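np.ogrid returns an open (sparse) grid rather than a full meshgrid: I is a column of row indices with shape (m, 1) and J is a row of column indices with shape (1, n). During advanced indexing, idx, I and J broadcast against each other to shape (m, n), so a[idx, I, J][i, j] equals a[idx[i, j], i, j]. As a result, a_max_values matches a.max(axis=0), and b_max_values holds the values of b at the positions where a is maximal.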
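A self-contained version of the snippet, runnable as-is, is shown below. The shapes and random data for a and b are assumptions for illustration. It also checks the shapes of the open grid, which is what allows np.ogrid to avoid materializing a full (m, n) index array for each axis.

```python
import numpy as np

# Hypothetical data: a and b share the shape (k, m, n) = (4, 3, 5).
rng = np.random.default_rng(1)
a = rng.random((4, 3, 5))
b = rng.random((4, 3, 5))

idx = a.argmax(axis=0)  # shape (m, n): positions of the maxima along axis 0

m, n = a.shape[1:]
I, J = np.ogrid[:m, :n]  # I has shape (m, 1), J has shape (1, n)

# idx, I and J broadcast together to shape (m, n), so each result has shape (m, n).
a_max_values = a[idx, I, J]
b_max_values = b[idx, I, J]  # values of b where a is maximal along axis 0

assert I.shape == (m, 1) and J.shape == (1, n)
assert np.array_equal(a_max_values, a.max(axis=0))
```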
General Case with Function
For a more general solution that works for any specified axis, we can define a function:
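<code class="python">import numpy as np

def argmax_to_max(arr, argmax, axis):
    # Shape of arr with the indexed axis removed, i.e. the shape of argmax
    new_shape = list(arr.shape)
    del new_shape[axis]
    # Open grid over the remaining axes; list() because np.ogrid may return a tuple
    grid = list(np.ogrid[tuple(map(slice, new_shape))])
    # Put the index array in the position of the indexed axis
    grid.insert(axis, argmax)
    return arr[tuple(grid)]</code>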
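This function takes an array, an index array along one axis (such as the output of arr.argmax(axis=axis)), and that axis. It builds an open grid over all the other axes, inserts the index array at position axis, and indexes arr with the resulting tuple. The result has the same shape as argmax.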
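The function can be exercised on every axis of a 3-D array and compared against the built-in reductions. This is a sketch with made-up data. It wraps np.ogrid in list() because, depending on the NumPy version, np.ogrid returns either a list or a tuple, and only a list supports insert().

```python
import numpy as np

def argmax_to_max(arr, argmax, axis):
    """Return the values of arr at the positions given by argmax along axis."""
    new_shape = list(arr.shape)
    del new_shape[axis]
    # Open grid over every axis except `axis`
    grid = list(np.ogrid[tuple(map(slice, new_shape))])
    grid.insert(axis, argmax)
    return arr[tuple(grid)]

rng = np.random.default_rng(2)
arr = rng.random((4, 3, 5))

# Check every axis against arr.max(axis=...).
for ax in range(arr.ndim):
    vals = argmax_to_max(arr, arr.argmax(axis=ax), ax)
    assert np.array_equal(vals, arr.max(axis=ax))
```

For axis=1, for example, the index tuple is (rows of shape (4, 1), argmax of shape (4, 5), columns of shape (1, 5)), so element [i, j] of the result is arr[i, argmax[i, j], j].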
Simplified Indexing with Custom Function
Because the grid depends only on the shape of idx, the array being indexed is not needed to build it. A helper can therefore generate the full index tuple from idx alone:
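<code class="python">import numpy as np

def all_idx(idx, axis):
    # Open grid over the shape of idx; list() because np.ogrid may return a tuple
    grid = list(np.ogrid[tuple(map(slice, idx.shape))])
    # Insert idx as the index for the chosen axis
    grid.insert(axis, idx)
    return tuple(grid)</code>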
The returned tuple can be used directly to index any array with the same shape as a, so it can be reused for both a and b:
<code class="python">axis = 0  # idx = a.argmax(axis=axis)
index = all_idx(idx, axis=axis)  # build the index tuple once and reuse it
a_max_values = a[index]
b_max_values = b[index]</code>
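As a cross-check, the sketch below runs all_idx on axis 1 of made-up data and compares the result with np.take_along_axis, which NumPy has provided since version 1.15. Unlike the ogrid approach, np.take_along_axis expects the index array to keep the indexed axis (with length 1), hence the np.expand_dims and np.squeeze calls. The all_idx definition here wraps np.ogrid in list() so that insert() works whether NumPy returns a list or a tuple.

```python
import numpy as np

def all_idx(idx, axis):
    """Build an index tuple that selects idx along `axis`."""
    grid = list(np.ogrid[tuple(map(slice, idx.shape))])
    grid.insert(axis, idx)
    return tuple(grid)

# Hypothetical data: a and b share the shape (4, 3, 5).
rng = np.random.default_rng(3)
a = rng.random((4, 3, 5))
b = rng.random((4, 3, 5))

axis = 1
idx = a.argmax(axis=axis)        # shape (4, 5)
index = all_idx(idx, axis=axis)  # build the index tuple once, reuse it
a_max_values = a[index]
b_max_values = b[index]

# np.take_along_axis needs idx with the reduced axis restored as length 1.
expected_b = np.take_along_axis(b, np.expand_dims(idx, axis), axis=axis)
assert np.array_equal(a_max_values, a.max(axis=axis))
assert np.array_equal(b_max_values, np.squeeze(expected_b, axis=axis))
```

The ogrid-based helper and np.take_along_axis produce the same values; which one reads better mostly depends on whether the index array already has the reduced axis removed (as argmax output does by default) or kept.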