File size: 340 Bytes
ec9a6bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch


def knn_gather(x, idx):
    """
    Gather the features of each query's neighbours.

    For every batch ``b``, query ``m`` and neighbour slot ``k`` the output is
    ``x[b, idx[b, m, k], :]``.

    :param x: (B, N, C) feature tensor.

    :param idx: (B, M, K) integer tensor of indices into the ``N`` axis of
        ``x``. ``M`` does not have to equal ``N`` (e.g. when the queries are a
        subset of the points).

    :return: (B, M, K, C)

    :raises ValueError: if ``x`` or ``idx`` is not 3-D, or if their batch
        sizes differ.
    """
    if x.dim() != 3 or idx.dim() != 3:
        raise ValueError(
            f"expected x of shape (B, N, C) and idx of shape (B, M, K), "
            f"got {tuple(x.shape)} and {tuple(idx.shape)}"
        )
    if x.shape[0] != idx.shape[0]:
        # torch.gather only requires index.size(d) <= input.size(d) on the
        # non-gathered dims, so a smaller idx batch would silently read from
        # the first batches of x instead of failing.
        raise ValueError(
            f"batch size mismatch: x has {x.shape[0]}, idx has {idx.shape[0]}"
        )

    C = x.shape[-1]
    K = idx.shape[-1]
    # (B, M, K) -> (B, M, K, C): repeat each index across the channel axis.
    idx_expanded = idx.unsqueeze(-1).expand(-1, -1, -1, C)
    # (B, N, C) -> (B, N, K, C) as an expanded view (no copy), so gathering
    # along dim 1 yields out[b, m, k, c] = x[b, idx[b, m, k], c].
    return x.unsqueeze(2).expand(-1, -1, K, -1).gather(1, idx_expanded)