Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4503,6 +4503,27 @@ def aten_grid_sampler_3d_backward(
raise NotImplementedError()


@torch_op("aten::_grouped_mm", trace_only=True)
def aten_grouped_mm(
    self: TFloat,
    mat2: TFloat,
    offs: Optional[TInt] = None,
    bias: Optional[TFloat] = None,
    out_dtype: Optional[int] = None,
) -> TFloat:
    """_grouped_mm(Tensor self, Tensor mat2, *, Tensor? offs=None, Tensor? bias=None, int? out_dtype=None) -> Tensor"""

    # Only the "dense"/"batch" mode is supported, where groups are implicit in
    # the leading batch dimension: self (G, M, K) @ mat2 (G, K, N) -> (G, M, N).
    if offs is not None:
        # The jagged/grouped mode partitions rows by `offs`; a plain MatMul
        # would silently compute the wrong result, so fail loudly instead.
        raise NotImplementedError(
            "aten::_grouped_mm with offs (jagged/grouped mode) is not implemented"
        )

    res = op.MatMul(self, mat2)
    if bias is not None:
        res = op.Add(res, bias)
    if out_dtype is not None:
        # Honor the requested output dtype, mirroring the ATen signature.
        res = op.Cast(res, to=out_dtype)
    return res


def aten_gru_cell(
input: TensorType,
hx: TensorType,
Expand Down
36 changes: 36 additions & 0 deletions tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,34 @@
M = 10


def sample_inputs_grouped_mm(op_info, device, dtype, requires_grad, **kwargs):
    """Yield dense-mode (offs=None) samples for aten::_grouped_mm."""
    del op_info
    del kwargs

    make_arg = functools.partial(
        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )

    # Shape pairs follow (G, M, K) @ (G, K, N).
    shape_pairs = (
        ((2, 3, 4), (2, 4, 5)),
        ((1, 2, 2), (1, 2, 1)),
    )

    for lhs_shape, rhs_shape in shape_pairs:
        lhs = make_arg(lhs_shape)
        rhs = make_arg(rhs_shape)
        # Bias-free sample only; bias/offs variants are not exercised here.
        yield opinfo_core.SampleInput(lhs, args=(rhs,))

def _mock_grouped_mm(self, mat2, offs=None, bias=None, out_dtype=None):

Check warning

Code scanning / lintrunner

PYLINT/W0613 Warning test

Unused argument 'offs' (unused-argument)
See unused-argument. To disable, use # pylint: disable=unused-argument

Check warning

Code scanning / lintrunner

PYLINT/W0613 Warning test

Unused argument 'out_dtype' (unused-argument)
See unused-argument. To disable, use # pylint: disable=unused-argument
res = torch.matmul(self, mat2)
if bias is not None:
res = res + bias
Comment on lines +46 to +50
Comment on lines +47 to +50
return res


def sample_inputs_scalar_tensor(op_info, device, dtype, requires_grad, **kwargs):
del op_info
del kwargs
Expand Down Expand Up @@ -3144,4 +3172,12 @@
sample_inputs_func=sample_inputs_roi_pool,
supports_out=False,
),
opinfo_core.OpInfo(
    "ops.aten._grouped_mm",
    aten_name="_grouped_mm",
    # NOTE(review): a reviewer suggested `op=torch.ops.aten._grouped_mm`;
    # the mock is kept here — presumably because the real op has restrictive
    # device/dtype requirements. Confirm and switch to the real op if it runs
    # on the tested configurations.
    op=_mock_grouped_mm,
    dtypes=common_dtype.floating_types(),
    sample_inputs_func=sample_inputs_grouped_mm,
    supports_out=False,
),
]
1 change: 1 addition & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,7 @@ def _where_input_wrangler(
reason="fixme: ORT does not support empty tensors as input",
),
TorchLibOpInfo("ge", core_ops.aten_ge),
TorchLibOpInfo("ops.aten._grouped_mm", core_ops.aten_grouped_mm),
TorchLibOpInfo("gt", core_ops.aten_gt),
TorchLibOpInfo("histc", core_ops.aten_histc)
.skip(
Expand Down
Loading