# NOTE(review): SOURCE is a unified diff collapsed onto two physical lines.
# Reconstructed below are the Python definitions the patch adds, properly
# formatted. The patch touches three files:
#   1. onnxscript/function_libs/torch_lib/ops/core.py  -> aten_grouped_mm
#   2. tests/function_libs/torch_lib/extra_opinfo.py   -> sample inputs + mock
#   3. tests/function_libs/torch_lib/ops_test_data.py  -> TorchLibOpInfo entry


@torch_op("aten::_grouped_mm", trace_only=True)
def aten_grouped_mm(
    self: TFloat,
    mat2: TFloat,
    offs: Optional[TInt] = None,
    bias: Optional[TFloat] = None,
    out_dtype: Optional[int] = None,
) -> TFloat:
    """_grouped_mm(Tensor self, Tensor mat2, *, Tensor? offs=None, Tensor? bias=None, int? out_dtype=None) -> Tensor"""
    if offs is not None:
        # Sparse ("jagged") mode is not implemented. The original patch
        # silently ignored `offs`, which would trace an incorrect dense
        # MatMul for sparse-mode inputs; fail loudly instead.
        raise NotImplementedError(
            "aten::_grouped_mm with offs (sparse mode) is not supported"
        )
    # Dense / "batch" mode: groups are implicit in the leading batch dimension.
    # self: (G, M, K), mat2: (G, K, N) -> (G, M, N)
    res = op.MatMul(self, mat2)
    if bias is not None:
        res = op.Add(res, bias)
    if out_dtype is not None:
        res = op.Cast(res, to=out_dtype)
    return res


def sample_inputs_grouped_mm(op_info, device, dtype, requires_grad, **kwargs):
    """Yield (self, mat2) sample pairs for aten::_grouped_mm (dense mode only)."""
    del op_info
    del kwargs

    make_arg = functools.partial(
        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )

    # (G, M, K) x (G, K, N) -> (G, M, N)
    cases = [
        ((2, 3, 4), (2, 4, 5)),
        ((1, 2, 2), (1, 2, 1)),
    ]

    for self_shape, mat2_shape in cases:
        # Bias-less samples only; offs/out_dtype are not exercised here.
        yield opinfo_core.SampleInput(make_arg(self_shape), args=(make_arg(mat2_shape),))


def _mock_grouped_mm(self, mat2, offs=None, bias=None, out_dtype=None):
    """Eager reference for aten::_grouped_mm used by the OpInfo registration.

    Mirrors the ONNX decomposition: batched matmul, optional bias add,
    optional output-dtype cast. The original patch dropped the `out_dtype`
    cast, which would desynchronize this reference from the ONNX function
    as soon as an out_dtype sample were added.
    """
    res = torch.matmul(self, mat2)
    if bias is not None:
        res = res + bias
    if out_dtype is not None:
        # assumes out_dtype is a torch.dtype when provided -- TODO confirm
        res = res.to(out_dtype)
    return res


# Registration entries added by the patch (kept verbatim):
#
# tests/function_libs/torch_lib/extra_opinfo.py (appended to the op_db list):
#     opinfo_core.OpInfo(
#         "ops.aten._grouped_mm",
#         aten_name="_grouped_mm",
#         op=_mock_grouped_mm,
#         dtypes=common_dtype.floating_types(),
#         sample_inputs_func=sample_inputs_grouped_mm,
#         supports_out=False,
#     ),
#
# tests/function_libs/torch_lib/ops_test_data.py (in the TorchLibOpInfo list,
# between the "ge" and "gt" entries):
#     TorchLibOpInfo("ops.aten._grouped_mm", core_ops.aten_grouped_mm),