minuet.nn.functional.scan
Functions
| compute_kernel_map_masks | Compute the input and output masks (metadata tables) from the kernel map. |
| compute_kernel_map_sizes | Compute the size of the kernel map for each weight. |
- compute_kernel_map_masks(num_sources: int, kernel_map: Tensor, kernel_map_sizes: Tensor)
Compute the input and output masks (metadata tables) from the kernel map. This function is internal and is not meant to be called directly by users.
- Parameters:
num_sources – the number of source coordinates
kernel_map – the tensor of shape \((N_\text{weight}, D)\) representing the kernel map
kernel_map_sizes – the tensor that stores the size of the kernel map for each weight
- Returns:
the tensor of input masks and the tensor of output masks
- compute_kernel_map_sizes(kernel_map: Tensor)
Compute the size of the kernel map for each weight. This function is internal and is not meant to be called directly by users.
- Parameters:
kernel_map – the tensor of shape \((N_\text{weight}, D)\) representing the kernel map
- Returns:
a tensor of shape \((N_\text{weight})\) representing the size of the kernel map for each weight
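The idea of a per-weight kernel map size can be illustrated with a small, dependency-free sketch. It is not Minuet's implementation and does not call the library. It assumes, purely for illustration, that each row of the kernel map corresponds to one weight and that a sentinel value of -1 marks entries with no match; neither the sentinel nor the helper name `kernel_map_sizes_sketch` comes from the documentation above.

```python
# Illustrative only: NOT Minuet's implementation and not its API.
# Assumption: kernel_map[w][j] holds an index for weight w, or NO_MATCH
# when there is no match. The sentinel value is hypothetical.
NO_MATCH = -1


def kernel_map_sizes_sketch(kernel_map):
    """Count the non-sentinel entries in each row (one row per weight)."""
    return [sum(1 for entry in row if entry != NO_MATCH) for row in kernel_map]


# A toy kernel map with N_weight = 3 rows and D = 3 columns.
kernel_map = [
    [0, NO_MATCH, 2],
    [NO_MATCH, NO_MATCH, NO_MATCH],
    [1, 0, NO_MATCH],
]

sizes = kernel_map_sizes_sketch(kernel_map)
print(sizes)  # [2, 0, 2]
```

Under these assumptions the result has one entry per weight, which matches the documented output shape \((N_\text{weight})\).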