From d3617b06ea4d67b68b70391e8a574d938828f23b Mon Sep 17 00:00:00 2001 From: Legendx4060 Date: Fri, 13 Feb 2026 02:20:07 +0530 Subject: [PATCH 1/2] Add aten::_grouped_mm converter implementation Implements the converter for aten::_grouped_mm.default to address issue #2795. Handles the batch/dense mode where groups are implicit in the batch dimension using MatMul, with optional bias addition and dtype casting. --- .../function_libs/torch_lib/ops/core.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 483e0ea46f..f2e8d3b603 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4503,6 +4503,29 @@ def aten_grid_sampler_3d_backward( raise NotImplementedError() +@torch_op("aten::_grouped_mm") +def aten_grouped_mm( + self: TFloat, + mat2: TFloat, + offs: Optional[TInt] = None, + bias: Optional[TFloat] = None, + out_dtype: Optional[int] = None, +) -> TFloat: + """_grouped_mm(Tensor self, Tensor mat2, *, Tensor? offs=None, Tensor? bias=None, int? out_dtype=None) -> Tensor""" + + # If offs is None, it uses the "dense" / "batch" mode where groups are implicit in the batch dimension. 
+ # self: (G, M, K), mat2: (G, K, N) -> (G, M, N) + if offs is None: + res = op.MatMul(self, mat2) + if bias is not None: + res = op.Add(res, bias) + if out_dtype is not None: + res = op.Cast(res, to=out_dtype) + return res + + raise NotImplementedError("aten::_grouped_mm with 'offs' is not supported.") + + def aten_gru_cell( input: TensorType, hx: TensorType, From 521dc1bc578553bc877f43b8a314e82aa1dfc081 Mon Sep 17 00:00:00 2001 From: Legendx4060 Date: Thu, 12 Mar 2026 18:16:49 +0530 Subject: [PATCH 2/2] Fix grouped_mm: Handle optional arguments correctly and add basic test cases --- .../function_libs/torch_lib/ops/core.py | 18 +++++----- tests/function_libs/torch_lib/extra_opinfo.py | 36 +++++++++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 1 + 3 files changed, 45 insertions(+), 10 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index f2e8d3b603..0cd3db26a2 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4503,7 +4503,7 @@ def aten_grid_sampler_3d_backward( raise NotImplementedError() -@torch_op("aten::_grouped_mm") +@torch_op("aten::_grouped_mm", trace_only=True) def aten_grouped_mm( self: TFloat, mat2: TFloat, @@ -4515,15 +4515,13 @@ def aten_grouped_mm( # If offs is None, it uses the "dense" / "batch" mode where groups are implicit in the batch dimension. # self: (G, M, K), mat2: (G, K, N) -> (G, M, N) - if offs is None: - res = op.MatMul(self, mat2) - if bias is not None: - res = op.Add(res, bias) - if out_dtype is not None: - res = op.Cast(res, to=out_dtype) - return res - - raise NotImplementedError("aten::_grouped_mm with 'offs' is not supported.") + # TODO: Implement sparse mode when offs is not None. 
+    res = op.MatMul(self, mat2)
+    if bias is not None:
+        res = op.Add(res, bias)
+    if out_dtype is not None:
+        res = op.Cast(res, to=out_dtype)
+    return res
 
 
 def aten_gru_cell(
diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py
index a28a6c9cd9..19c3db8c4c 100644
--- a/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/tests/function_libs/torch_lib/extra_opinfo.py
@@ -23,6 +23,34 @@
 M = 10
 
 
+def sample_inputs_grouped_mm(op_info, device, dtype, requires_grad, **kwargs):
+    del op_info
+    del kwargs
+
+    make_arg = functools.partial(
+        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    cases = [
+        # (G, M, K), (G, K, N)
+        ((2, 3, 4), (2, 4, 5)),
+        ((1, 2, 2), (1, 2, 1)),
+    ]
+
+    for self_shape, mat2_shape in cases:
+        self_t = make_arg(self_shape)
+        mat2_t = make_arg(mat2_shape)
+
+        # Test without bias
+        yield opinfo_core.SampleInput(self_t, args=(mat2_t,))
+
+def _mock_grouped_mm(self, mat2, offs=None, bias=None, out_dtype=None):
+    res = torch.matmul(self, mat2)
+    if bias is not None:
+        res = res + bias
+    return res
+
+
 def sample_inputs_scalar_tensor(op_info, device, dtype, requires_grad, **kwargs):
     del op_info
     del kwargs
@@ -3144,4 +3172,12 @@ def __init__(self):
         sample_inputs_func=sample_inputs_roi_pool,
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten._grouped_mm",
+        aten_name="_grouped_mm",
+        op=_mock_grouped_mm,
+        dtypes=common_dtype.floating_types(),
+        sample_inputs_func=sample_inputs_grouped_mm,
+        supports_out=False,
+    ),
 ]
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index a40535f4ba..61e5d6f537 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -750,6 +750,7 @@ def _where_input_wrangler(
         reason="fixme: ORT does not support empty tensors as input",
     ),
     TorchLibOpInfo("ge", core_ops.aten_ge),
+    TorchLibOpInfo("ops.aten._grouped_mm", core_ops.aten_grouped_mm),
     TorchLibOpInfo("gt", core_ops.aten_gt),
     TorchLibOpInfo("histc", core_ops.aten_histc)
     .skip(