Gathers values from input along the axis specified by dim, at the positions given by index.
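As a rough illustration of the indexing rule for the 2-D case, here is a pure-Python sketch. The helper name gather_2d is hypothetical and is not part of PyTorch; it only mirrors the documented rule out[i][j] = input[index[i][j]][j] for dim=0 and out[i][j] = input[i][index[i][j]] for dim=1.

```python
def gather_2d(inp, dim, index):
    """Hypothetical helper: mimics the gather rule on 2-D nested lists."""
    out = []
    for i, index_row in enumerate(index):
        out_row = []
        for j, k in enumerate(index_row):
            if dim == 0:
                # The index value replaces the row coordinate.
                out_row.append(inp[k][j])
            else:
                # The index value replaces the column coordinate.
                out_row.append(inp[i][k])
        out.append(out_row)
    return out


inp = [[1, 2], [3, 4]]
index = [[0, 0], [1, 0]]

along_cols = gather_2d(inp, 1, index)  # [[1, 1], [4, 3]]
along_rows = gather_2d(inp, 0, index)  # [[1, 2], [3, 2]]
print(along_cols, along_rows)
```

In both cases the result has the shape of index, not of the input.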
torch.gather — PyTorch 2.3 documentation
input and index must have the same number of dimensions.
It is also required that index.size(d) <= input.size(d) for all
dimensions d != dim. out will have the same shape as index.
Note that input and index do not broadcast against each other.
https://pytorch.org/docs/stable/generated/torch.gather.html
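The shape rules above can be checked with a small example. This is a minimal sketch assuming PyTorch is installed; the tensor values are made up for illustration.

```python
import torch

# input has shape (2, 3); we gather along dim=1.
inp = torch.tensor([[10, 11, 12],
                    [20, 21, 22]])

# index has the same number of dimensions as input (2-D).
# In dim 0 (the non-gather dimension) its size is 1 <= 2, which is allowed.
index = torch.tensor([[2, 0]])

out = torch.gather(inp, 1, index)
print(out.shape)  # same shape as index: torch.Size([1, 2])
print(out)        # picks inp[0][2] and inp[0][0]

# No broadcasting: a 1-D index against a 2-D input is expected to be rejected.
try:
    torch.gather(inp, 1, torch.tensor([2, 0]))
except RuntimeError as err:
    print("rejected:", err)
```

To gather with a lower-dimensional index, add the missing dimension explicitly first, for example with unsqueeze, so that both tensors have the same number of dimensions.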

Seonglae Cho