minuet.nn.functional.indexing

Functions

arg_sort_coordinates(coordinates[, ...])

Sort the coordinates in the given coordinate tensor.

build_sorted_index(coordinates[, ...])

Build a sorted index for the given coordinate tensor(s).

query_sorted_index_with_offsets(sources, ...)

Build the kernel map with the sorted coordinates by querying the sorted indices.

arg_sort_coordinates(coordinates: Tensor, batch_dims: Tensor | None = None, dtype: dtype | None = None, enable_flattening: bool = True)

Sort the coordinates in the given coordinate tensor. Multiple coordinate tensors can be handled together by specifying the batch_dims tensor, which stores the start and end indices of each coordinate tensor. cub::DeviceRadixSort does not natively sort multi-dimensional coordinates, so sorting \(N\)-D coordinates is achieved by \(N\) independent launches of cub::DeviceRadixSort, one per dimension.
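The per-dimension strategy can be modelled in plain Python, with the stable `list.sort` standing in for each `cub::DeviceRadixSort` pass. This is an illustrative sketch, not minuet's implementation: `argsort_coordinates_multipass` is a hypothetical name, and it assumes the passes run from the last dimension to the first so that the first dimension ends up as the primary sort key.

```python
from typing import List, Sequence


def argsort_coordinates_multipass(coords: Sequence[Sequence[int]]) -> List[int]:
    """Lexicographically argsort N-D coordinates with N stable passes.

    Hypothetical helper: one stable pass per dimension, least significant
    (last) dimension first, so dimension 0 becomes the primary key.
    """
    if not coords:
        return []
    ndim = len(coords[0])
    order = list(range(len(coords)))
    for d in reversed(range(ndim)):
        # list.sort is stable, like an LSD radix pass: ties keep the order
        # established by the previous, less significant passes.
        order.sort(key=lambda i: coords[i][d])
    return order


coords = [(1, 2, 0), (0, 5, 1), (1, 0, 3), (0, 5, 0)]
order = argsort_coordinates_multipass(coords)  # [3, 1, 2, 0]
```

Stability is what makes the \(N\) passes compose into a lexicographic order; an unstable per-dimension sort would scramble the ordering established by earlier passes.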

To optimize this, coordinates with small ranges can be compressed into a single int64 or int32 value so that one cub::DeviceRadixSort launch suffices. The flag enable_flattening controls whether this optimization is enabled. If the coordinate range is too large to compress into the largest available integer type, sorting falls back to the naive approach of \(N\) cub::DeviceRadixSort kernel launches.
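A minimal sketch of the flattening idea, assuming a bit-packing scheme in which each dimension is shifted by its minimum and given just enough bits for its range, with earlier dimensions placed in higher bits. The helper names `pack_coordinates` and `argsort_with_flattening` and the `max_bits` parameter are hypothetical; minuet's actual encoding may differ.

```python
from typing import List, Optional, Sequence


def pack_coordinates(
    coords: Sequence[Sequence[int]], max_bits: int = 63
) -> Optional[List[int]]:
    """Pack each N-D coordinate into one non-negative integer key.

    Returns None when the packed key would need more than max_bits bits,
    which signals that the caller must fall back to per-dimension sorting.
    """
    if not coords:
        return []
    ndim = len(coords[0])
    lows = [min(c[d] for c in coords) for d in range(ndim)]
    highs = [max(c[d] for c in coords) for d in range(ndim)]
    # Bits needed for the shifted range [0, high - low] of each dimension.
    bits = [(highs[d] - lows[d]).bit_length() for d in range(ndim)]
    # 63 bits keeps keys non-negative in a signed int64 (use 31 for int32).
    if sum(bits) > max_bits:
        return None
    keys = []
    for c in coords:
        key = 0
        for d in range(ndim):
            # Earlier dimensions occupy higher bits, so they dominate the order.
            key = (key << bits[d]) | (c[d] - lows[d])
        keys.append(key)
    return keys


def argsort_with_flattening(
    coords: Sequence[Sequence[int]], max_bits: int = 63
) -> List[int]:
    keys = pack_coordinates(coords, max_bits)
    if keys is None:
        # Fallback: one stable pass per dimension, last dimension first.
        order = list(range(len(coords)))
        for d in reversed(range(len(coords[0]))):
            order.sort(key=lambda i: coords[i][d])
        return order
    # Fast path: a single sort over the packed integer keys.
    return sorted(range(len(coords)), key=keys.__getitem__)


coords = [(1, 2, 0), (0, 5, 1), (1, 0, 3), (0, 5, 0)]
keys = pack_coordinates(coords)                         # [40, 21, 35, 20]
order = argsort_with_flattening(coords)                 # [3, 1, 2, 0]
fallback = argsort_with_flattening(coords, max_bits=4)  # also [3, 1, 2, 0]
```

Because every field fits inside its own bit slice, comparing packed keys numerically gives the same result as comparing the coordinate tuples lexicographically; the fast path and the fallback therefore agree.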

Parameters:
  • coordinates – the coordinate tensor(s) to be sorted

  • batch_dims – the batch_dims tensor if there are multiple coordinate tensors

  • enable_flattening – whether to enable coordinate flattening to speed up sorting.

Returns:

A tensor specifying the order of the sorted coordinates
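How batch_dims might partition the input can be sketched as follows. This assumes batch_dims holds \(B + 1\) boundaries so that coordinate tensor \(b\) spans rows `batch_dims[b]` to `batch_dims[b + 1] - 1`, and that the returned order uses global row indices; both the layout and the helper `batched_argsort` are assumptions for illustration, not taken from minuet.

```python
from typing import List, Sequence


def batched_argsort(
    coords: Sequence[Sequence[int]], batch_dims: Sequence[int]
) -> List[int]:
    """Sort each segment of coords independently (hypothetical helper).

    Assumes segment b spans coords[batch_dims[b]:batch_dims[b + 1]].
    Returned indices are global and never cross a segment boundary.
    """
    order: List[int] = []
    for b in range(len(batch_dims) - 1):
        start, end = batch_dims[b], batch_dims[b + 1]
        # Tuples compare lexicographically, so dimension 0 is the primary key.
        order.extend(sorted(range(start, end), key=lambda i: coords[i]))
    return order


coords = [(2, 1), (0, 3), (1, 1), (5, 0), (4, 4)]
batch_dims = [0, 3, 5]  # two coordinate tensors: rows 0-2 and rows 3-4
order = batched_argsort(coords, batch_dims)  # [1, 2, 0, 4, 3]
```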

build_sorted_index(coordinates: Tensor, batch_dims: Tensor | None = None, enable_flattening: bool = True)

Build a sorted index for the given coordinate tensor(s). Multiple coordinate tensors can be handled together by specifying the batch_dims tensor, which stores the start and end indices of each coordinate tensor.

Parameters:
  • coordinates – the coordinate tensor(s) for which index is to be built

  • batch_dims – the batch_dims tensor if there are multiple coordinate tensors

  • enable_flattening – whether to enable coordinate flattening to speed up sorting.

Returns:

A tensor specifying the order of the sorted coordinates
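One way to see why the sort order is useful as an index: once the coordinates are gathered in sorted order, membership lookups become binary searches. The sketch below is an assumption-laden illustration; `build_sorted_index_sketch` and `lookup` are hypothetical helpers, and Python's `bisect` stands in for whatever search the GPU kernels perform.

```python
import bisect
from typing import List, Sequence, Tuple

Coord = Tuple[int, ...]


def build_sorted_index_sketch(coords: Sequence[Coord]) -> List[int]:
    # The "index" here is the permutation that sorts coords lexicographically.
    return sorted(range(len(coords)), key=lambda i: coords[i])


def lookup(sorted_coords: Sequence[Coord], index: Sequence[int], query: Coord) -> int:
    """Return the original row of `query`, or -1 if it is absent."""
    pos = bisect.bisect_left(sorted_coords, query)
    if pos < len(sorted_coords) and sorted_coords[pos] == query:
        # Map the position in sorted order back to the original row.
        return index[pos]
    return -1


coords = [(1, 2, 0), (0, 5, 1), (1, 0, 3), (0, 5, 0)]
index = build_sorted_index_sketch(coords)      # [3, 1, 2, 0]
sorted_coords = [coords[i] for i in index]     # gather into sorted order
hit = lookup(sorted_coords, index, (1, 0, 3))  # 2
miss = lookup(sorted_coords, index, (9, 9, 9)) # -1
```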

query_sorted_index_with_offsets(sources: Tensor, targets: Tensor, offsets: Tensor, source_batch_dims: Tensor | None = None, target_batch_dims: Tensor | None = None)

Build the kernel map with the sorted coordinates by querying the sorted indices. Multiple requests can be handled together by specifying the source_batch_dims and target_batch_dims tensors, which store the start and end indices of each input coordinate tensor and each output coordinate tensor, respectively.

Parameters:
  • sources – the sorted input coordinates of shape \([N_\text{in}, D]\)

  • targets – the sorted output coordinates of shape \([N_\text{out}, D]\)

  • offsets – the offsets of the weights of shape \([N_\text{weight}, D]\)

  • source_batch_dims – the batch_dims tensor (of the input coordinates) if there are multiple coordinate tensors

  • target_batch_dims – the batch_dims tensor (of the output coordinates) if there are multiple coordinate tensors

Returns:

A tensor specifying the kernel map of shape \([N_\text{weight}, N_\text{out}]\)
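A reference model of the returned kernel map, assuming the convention that output \(t\) connects to input \(s\) through weight \(w\) when `sources[s] == targets[t] + offsets[w]`, and that missing neighbours are marked with `-1`. Both conventions, and the helper `query_kernel_map`, are assumptions for illustration; it uses a plain per-query binary search (Python's `bisect`), which the actual GPU kernel may well organize differently, and it ignores the batch_dims arguments.

```python
import bisect
from typing import List, Sequence, Tuple

Coord = Tuple[int, ...]


def query_kernel_map(
    sources: Sequence[Coord], targets: Sequence[Coord], offsets: Sequence[Coord]
) -> List[List[int]]:
    """Return a [len(offsets)][len(targets)] map of source rows (-1 = no match).

    Assumes sources is sorted lexicographically and stored as tuples.
    """
    kernel_map = []
    for off in offsets:
        row = []
        for tgt in targets:
            # Neighbour coordinate this weight would read from.
            query = tuple(t + o for t, o in zip(tgt, off))
            pos = bisect.bisect_left(sources, query)
            hit = pos < len(sources) and sources[pos] == query
            row.append(pos if hit else -1)
        kernel_map.append(row)
    return kernel_map


sources = [(0, 0), (0, 1), (1, 0), (1, 1)]  # sorted input coordinates
targets = [(0, 0), (1, 1)]                  # output coordinates
offsets = [(-1, 0), (0, 0), (0, 1)]         # one offset per weight
kernel_map = query_kernel_map(sources, targets, offsets)
# [[-1, 1], [0, 3], [1, -1]]  -> shape [N_weight, N_out] = [3, 2]
```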