
How to Access N-Dimensional Array with (N-1)-Dimensional Array Efficiently?

Susan Sarandon
Release: 2024-10-21 11:57:03

Access N-Dimensional Array with (N-1)-Dimensional Array

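Given an N-dimensional array a and an (N-1)-dimensional array idx that stores, for every position along the remaining axes, an index into one axis of a (for example the result of a.argmax(axis=0)), a common task is to pick out the elements of a that idx points to. This is useful for recovering the maxima themselves, or for reading the values at the same positions from a second array b with the same shape as a.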
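
To make the task concrete, here is a minimal loop-based sketch for the 3-D case, assuming idx holds positions along axis 0. The array values and the name picked are illustrative assumptions, not part of the original article; the loop is only a slow reference for what the vectorized solutions compute.

```python
import numpy as np

# Illustrative 3-D array; shape and values are assumptions for this demo.
a = np.arange(24).reshape(2, 3, 4)
a[0, 1, 2] = 100  # force one maximum to sit in the first slab

# idx has shape a.shape[1:] and stores, for each (i, j), a position along axis 0.
idx = a.argmax(axis=0)

# Slow reference version: picked[i, j] = a[idx[i, j], i, j]
picked = np.empty(idx.shape, dtype=a.dtype)
for i in range(idx.shape[0]):
    for j in range(idx.shape[1]):
        picked[i, j] = a[idx[i, j], i, j]
```

Because idx came from argmax here, picked should match a.max(axis=0); with any other valid idx the same loop simply gathers the selected elements.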

Elegant Solution Using Advanced Indexing

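An elegant solution combines advanced (integer-array) indexing with NumPy's ogrid. The snippet below assumes that a and b are 3-D arrays of the same shape and that idx indexes axis 0: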

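<code class="python">import numpy as np

# a and b are 3-D arrays of the same shape; idx = a.argmax(axis=0)
m, n = a.shape[1:]
I, J = np.ogrid[:m, :n]      # open grids of shape (m, 1) and (1, n)
a_max_values = a[idx, I, J]  # a[idx[i, j], i, j] for every (i, j)
b_max_values = b[idx, I, J]  # values of b where a is maximal</code>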

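np.ogrid returns open (sparse) index arrays of shapes (m, 1) and (1, n). Together with idx of shape (m, n) they broadcast to a common shape, so a[idx, I, J] selects a[idx[i, j], i, j] for every (i, j). Indexing b the same way returns the values of b at the positions where a is maximal.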
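
The broadcasting involved can be checked with a small, self-contained sketch. The random data, the seed, and the name b_at_a_max are assumptions made for the demonstration.

```python
import numpy as np

rng = np.random.default_rng(0)   # seed chosen arbitrarily for the demo
a = rng.random((5, 3, 4))
b = rng.random((5, 3, 4))        # second array with the same shape as a

idx = a.argmax(axis=0)           # shape (3, 4)

m, n = a.shape[1:]
I, J = np.ogrid[:m, :n]          # I has shape (3, 1), J has shape (1, 4)

# idx (3, 4), I (3, 1) and J (1, 4) broadcast to the common shape (3, 4).
a_max_values = a[idx, I, J]
b_at_a_max = b[idx, I, J]

# Sanity check against NumPy's own reduction.
matches_max = np.array_equal(a_max_values, a.max(axis=0))
```

If the indexing is right, matches_max is True and b_at_a_max[i, j] equals b[idx[i, j], i, j] for every position.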

General Case with Function

For a more general solution that works for any specified axis, we can define a function:

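<code class="python">def argmax_to_max(arr, argmax, axis):
    """Return arr's values at the positions given by argmax along axis."""
    axis = axis % arr.ndim  # normalize negative axes so insert() lands correctly
    new_shape = list(arr.shape)
    del new_shape[axis]

    # list() because np.ogrid may return a tuple, which has no insert()
    grid = list(np.ogrid[tuple(map(slice, new_shape))])
    grid.insert(axis, argmax)

    return arr[tuple(grid)]</code>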

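This function takes an array, its argmax along a specified axis, and the axis itself. It builds open index grids for the remaining axes, inserts argmax at the position of the reduced axis, and indexes arr with the resulting tuple to extract the corresponding elements.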
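
As a usage sketch, the function can be compared against arr.max for every axis. The function is restated here so the snippet runs on its own; the list() wrapper and the axis normalization are defensive assumptions, and the test data and the results dictionary are illustrative.

```python
import numpy as np

def argmax_to_max(arr, argmax, axis):
    axis = axis % arr.ndim                       # allow negative axes
    new_shape = list(arr.shape)
    del new_shape[axis]                          # shape of the remaining axes
    # list() so that insert() works whether ogrid returns a list or a tuple
    grid = list(np.ogrid[tuple(map(slice, new_shape))])
    grid.insert(axis, argmax)                    # argmax indexes the reduced axis
    return arr[tuple(grid)]

rng = np.random.default_rng(42)                  # arbitrary seed
arr = rng.integers(0, 100, size=(2, 3, 4))

# Recover the maxima from the argmax along each axis in turn.
results = {
    axis: argmax_to_max(arr, arr.argmax(axis=axis), axis)
    for axis in range(arr.ndim)
}
```

Each results[axis] should equal arr.max(axis=axis), which is a convenient check when adapting the function.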

Simplified Indexing with Custom Function

To further simplify the indexing process, we can create a helper function that builds the index tuple from idx alone, so the same tuple can be reused for several arrays:

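<code class="python">def all_idx(idx, axis):
    axis = axis % (idx.ndim + 1)  # the indexed array has one more axis than idx
    # list() because np.ogrid may return a tuple, which has no insert()
    grid = list(np.ogrid[tuple(map(slice, idx.shape))])
    grid.insert(axis, idx)
    return tuple(grid)</code>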

This function returns a tuple of indices that can be used directly to index into input arrays:

<code class="python">axis = 0
a_max_values = a[all_idx(idx, axis=axis)]
b_max_values = b[all_idx(idx, axis=axis)]</code>
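
For comparison, NumPy also ships np.take_along_axis, which performs the same gather once idx is given back the indexed axis with length 1. This is a different technique from the ogrid-based helper, shown here only as a cross-check; the data, the seed, and the names via_grid and via_take are assumptions for the sketch.

```python
import numpy as np

def all_idx(idx, axis):
    axis = axis % (idx.ndim + 1)                 # allow negative axes
    # list() so that insert() works whether ogrid returns a list or a tuple
    grid = list(np.ogrid[tuple(map(slice, idx.shape))])
    grid.insert(axis, idx)
    return tuple(grid)

rng = np.random.default_rng(1)                   # arbitrary seed
a = rng.random((4, 3, 5))
b = rng.random((4, 3, 5))

axis = 2
idx = a.argmax(axis=axis)                        # shape (4, 3)

# ogrid-based helper from the article
via_grid = b[all_idx(idx, axis=axis)]

# Built-in alternative: idx must keep the indexed axis with length 1,
# and the result is squeezed back to idx's shape.
via_take = np.take_along_axis(b, np.expand_dims(idx, axis), axis=axis).squeeze(axis)
```

Both routes should produce the same array; the helper has the advantage that the index tuple can be built once and reused for several arrays.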
