torch.nn.functional.grouped_mm — PyTorch 2.10 documentation
mat_a (Tensor) – Left operand. When 2D, its leading dimension is sliced into groups
according to offs. When 3D, its first dimension enumerates the groups
directly and offs must be None.
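
The grouping rule above can be sketched in plain Python, as an illustration rather than PyTorch's implementation. The helper names `matmul` and `grouped_mm_reference` are hypothetical. The sketch assumes that `offs` holds cumulative end offsets along the leading dimension of `mat_a` and that `mat_b` supplies one matrix per group; neither detail is stated in the excerpt above.

```python
def matmul(a, b):
    """Plain matrix product of two matrices stored as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def grouped_mm_reference(mat_a, mat_b, offs=None):
    """Hypothetical reference for the grouping rule (not the PyTorch kernel).

    Assumptions: mat_b is a list with one matrix per group, and offs
    contains the cumulative END row index of each group in a 2D mat_a.
    """
    is_3d = isinstance(mat_a[0][0], list)
    if is_3d:
        # 3D case: the first dimension enumerates the groups directly.
        if offs is not None:
            raise ValueError("offs must be None when mat_a is 3D")
        return [matmul(a, b) for a, b in zip(mat_a, mat_b)]

    # 2D case: slice the leading dimension into groups according to offs.
    if offs is None:
        raise ValueError("offs is required when mat_a is 2D")
    out, start = [], 0
    for end, b in zip(offs, mat_b):
        out.extend(matmul(mat_a[start:end], b))  # rows start..end-1 form one group
        start = end
    return out


# 2D mat_a with three rows, split into groups of 1 and 2 rows.
a2d = [[1, 2], [3, 4], [5, 6]]
weights = [[[1, 0], [0, 1]], [[2, 0], [0, 2]]]
result_2d = grouped_mm_reference(a2d, weights, offs=[1, 3])

# 3D mat_a: one matrix per group, so offs stays None.
a3d = [[[1, 2]], [[3, 4]]]
result_3d = grouped_mm_reference(a3d, weights)
```

In the 2D call, the first row is multiplied by the first weight matrix and the remaining two rows by the second, which is what slicing by `offs` means under the assumptions stated here.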
https://docs.pytorch.org/docs/2.10/generated/torch.nn.functional.grouped_mm.html


Seonglae Cho