169 changes: 142 additions & 27 deletions backends/cadence/aot/tests/test_quantizer_ops.py
@@ -46,15 +46,18 @@

 # Type alias for graph builder functions.
 # These functions take a test instance and return a graph module and the target op node.
+# For fused patterns (e.g., conv+relu), an optional third element specifies the node
+# whose args contain the quantized inputs (e.g., conv node for conv+relu fusion).
 GraphBuilderFn = Callable[
-    ["QuantizerAnnotationTest"], tuple[torch.fx.GraphModule, torch.fx.Node]
+    ["QuantizerAnnotationTest"],
+    tuple[torch.fx.GraphModule, torch.fx.Node]
+    | tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node],
 ]


 # Quantizers intentionally excluded from annotation testing.
 # These should be explicitly justified when added.
 EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
-    CadenceFusedConvReluQuantizer,  # TODO: T247438151 Add test coverage
     CadenceNopQuantizer,  # No-op quantizer, doesn't annotate anything
     CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
     CadenceRmsNormNopQuantizer,  # No-op quantizer, doesn't annotate anything, preserves rms_norm from decomposition
@@ -64,14 +67,15 @@
 # Test case definitions for quantizer annotation tests.
 # Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
 # Adding a new quantizer test only requires adding a tuple to this list.
+# Note: Use None in expected_input_qspecs to skip comparison for that input (e.g., for DerivedQuantizationSpec).
 QUANTIZER_ANNOTATION_TEST_CASES: list[
     tuple[
         str,
         GraphBuilderFn,
         CadenceQuantizer,
         OpOverload,
         QuantizationSpec,
-        list[QuantizationSpec],
+        list[QuantizationSpec | None],
     ]
 ] = [
     (
@@ -192,6 +196,26 @@
         # For relu: only input_activation
         [qconfig_A8W8.input_activation],
     ),
+    (
+        "default_addmm_A8W8",
+        lambda self: self._build_addmm_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.addmm.default,
+        qconfig_A8W8.output_activation,
+        # For addmm: [bias (DerivedQuantizationSpec), mat1, mat2]
+        # Use None to skip comparison for bias since it's a DerivedQuantizationSpec
+        [None, qconfig_A8W8.input_activation, qconfig_A8W8.weight],
+    ),
+    # CadenceFusedConvReluQuantizer test cases
+    (
+        "fused_conv2d_relu_A8W8sym",
+        lambda self: self._build_conv2d_relu_graph(),
+        CadenceFusedConvReluQuantizer(),
+        torch.ops.aten.relu.default,
+        qconfig_A8W8sym.output_activation,
+        # For fused conv2d+relu: [input_activation, weight] from conv2d node
+        [qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight],
+    ),
 ]

 # Derive the set of tested quantizer classes from the test cases.
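The None placeholder for the addmm bias above exists because the bias qspec is a DerivedQuantizationSpec: its quantization parameters are computed from the activation and weight qparams rather than fixed in the config, so a direct equality check against a static spec is not meaningful. A minimal sketch of the usual int8 convention (hypothetical helper, not the quantizer's actual derivation code):

def derive_bias_qparams(act_scale: float, weight_scale: float) -> tuple[float, int]:
    # Common convention for an int32 bias feeding an int8 matmul/conv kernel:
    # bias_scale = act_scale * weight_scale, zero_point = 0.
    return act_scale * weight_scale, 0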
@@ -408,6 +432,77 @@ def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
         return gm, relu_nodes[0]

+    def _build_addmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with an addmm operation."""
+        builder = GraphBuilder()
+        # addmm: bias + (mat1 @ mat2)
+        # args: (bias, mat1, mat2)
+        bias = builder.placeholder("bias", torch.randn(5))
+        mat1 = builder.placeholder("mat1", torch.randn(1, 10))
+        mat2 = builder.placeholder("mat2", torch.randn(10, 5))
+        addmm = builder.call_operator(
+            op=torch.ops.aten.addmm.default,
+            args=(bias, mat1, mat2),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("addmm", torch.ops.aten.addmm.default)]}
+            ),
+        )
+        builder.output([addmm])
+        gm = builder.get_graph_module()
+
+        addmm_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.addmm.default,
+        )
+        self.assertEqual(len(addmm_nodes), 1, "Should find exactly one addmm node")
+        return gm, addmm_nodes[0]
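For reference, aten.addmm computes bias + mat1 @ mat2, broadcasting the 1-D bias across rows; a quick standalone check of the shapes used above (plain torch, independent of GraphBuilder):

import torch

bias, mat1, mat2 = torch.randn(5), torch.randn(1, 10), torch.randn(10, 5)
out = torch.addmm(bias, mat1, mat2)  # (1, 10) @ (10, 5) -> (1, 5), bias broadcast over rows
assert out.shape == (1, 5)
assert torch.allclose(out, bias + mat1 @ mat2)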

+    def _build_conv2d_relu_graph(
+        self,
+    ) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]:
+        """Build a graph with a conv2d followed by relu (fused pattern).
+
+        Returns:
+            A tuple of (graph_module, relu_node, conv_node).
+            The relu_node is the target node where the annotation is placed.
+            The conv_node is the input source node whose args contain the quantized inputs.
+        """
+        builder = GraphBuilder()
+        # Input shape: (batch, in_channels, height, width)
+        x = builder.placeholder("x", torch.randn(1, 3, 8, 8))
+        # Weight shape: (out_channels, in_channels, kernel_h, kernel_w)
+        weight = builder.placeholder("weight", torch.randn(6, 3, 3, 3))
+        conv2d = builder.call_operator(
+            op=torch.ops.aten.conv2d.default,
+            args=(x, weight),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("conv2d", torch.ops.aten.conv2d.default)]}
+            ),
+        )
+        relu = builder.call_operator(
+            op=torch.ops.aten.relu.default,
+            args=(conv2d,),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("relu", torch.ops.aten.relu.default)]}
+            ),
+        )
+        builder.output([relu])
+        gm = builder.get_graph_module()
+
+        relu_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.relu.default,
+        )
+        self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
+
+        conv2d_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.conv2d.default,
+        )
+        self.assertEqual(len(conv2d_nodes), 1, "Should find exactly one conv2d node")
+
+        return gm, relu_nodes[0], conv2d_nodes[0]
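As a sanity check on the shapes above: a 3x3 kernel with the default stride 1 and padding 0 maps an 8x8 input to 6x6, so the fused pattern's output is (1, 6, 6, 6). In plain torch, outside the FX graph:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
weight = torch.randn(6, 3, 3, 3)
out = torch.relu(F.conv2d(x, weight))  # stride=1, padding=0 defaults
assert out.shape == (1, 6, 6, 6)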

     @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,
@@ -416,36 +511,56 @@ def test_quantizer_annotation(
         quantizer: CadenceQuantizer,
         target: OpOverload,
         expected_output_qspec: QuantizationSpec,
-        expected_input_qspecs: list[QuantizationSpec],
+        expected_input_qspecs: list[QuantizationSpec | None],
     ) -> None:
         """Parameterized test for quantizer annotations."""
-        gm, op_node = graph_builder_fn(self)
+        result = graph_builder_fn(self)
+        # Handle both 2-element and 3-element returns from graph builders.
+        # For fused patterns, the 3rd element specifies the node whose args
+        # contain the quantized inputs (e.g., conv node for conv+relu fusion).
+        if len(result) == 3:
+            gm = result[0]
+            output_node = result[1]
+            input_source_node = result[2]
+        else:
+            gm = result[0]
+            output_node = result[1]
+            input_source_node = output_node

         quantizer.annotate(gm)

-        annotation: QuantizationAnnotation = op_node.meta[Q_ANNOTATION_KEY]
-        self.assertTrue(annotation._annotated)
-
-        # Verify output annotation
-        self.assertEqual(annotation.output_qspec, expected_output_qspec)
-
-        # Verify input annotations
-        self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
-        for i, (input_node, input_qspec) in enumerate(
-            annotation.input_qspec_map.items()
-        ):
-            expected_arg = op_node.args[i]
-            assert isinstance(expected_arg, torch.fx.Node)
-            self.assertEqual(
-                input_node,
-                expected_arg,
-                f"Input node mismatch at index {i}",
-            )
-            self.assertEqual(
-                input_qspec,
-                expected_input_qspecs[i],
-                f"Input qspec mismatch at index {i}",
+        # Verify output annotation (always on the output node)
+        output_annotation: QuantizationAnnotation = output_node.meta[Q_ANNOTATION_KEY]
+        self.assertTrue(output_annotation._annotated)
+        self.assertEqual(output_annotation.output_qspec, expected_output_qspec)
+
+        # Verify input annotations (on the input source node, which may differ for fused patterns)
+        input_annotation: QuantizationAnnotation = input_source_node.meta[
+            Q_ANNOTATION_KEY
+        ]
+        self.assertEqual(
+            len(input_annotation.input_qspec_map), len(expected_input_qspecs)
+        )
+        for input_node, input_qspec in input_annotation.input_qspec_map.items():
+            # Find the index of this input node in the input source node's args
+            arg_index = None
+            args = input_source_node.args
+            assert isinstance(args, tuple)
Copilot AI commented on Dec 22, 2025:

Using Python's assert statement in test code is not ideal because it can be disabled with optimization flags. Consider using self.assertIsInstance(args, tuple) instead to ensure the check always runs.

Suggested change:
-            assert isinstance(args, tuple)
+            self.assertIsInstance(args, tuple)
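For context on the suggestion: assert statements are compiled out when Python runs with -O (since __debug__ becomes False), while unittest assertion methods always execute. A tiny standalone demonstration (hypothetical file name):

# demo.py -- run as `python -O demo.py`: the assert below is stripped and never raises.
assert False, "never raised under -O"
print("reached despite the failing assert")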
+            for i, arg in enumerate(args):
+                if arg is input_node:
+                    arg_index = i
+                    break
+            self.assertIsNotNone(
+                arg_index,
+                f"Input node {input_node} not found in input_source_node.args",
+            )
+            # Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
+            if expected_input_qspecs[arg_index] is not None:
+                self.assertEqual(
+                    input_qspec,
+                    expected_input_qspecs[arg_index],
+                    f"Input qspec mismatch at arg index {arg_index}",
                 )
Comment on lines +558 to +563

Copilot AI commented on Dec 30, 2025:

There's a potential IndexError when accessing expected_input_qspecs[arg_index]. While the code checks that len(input_annotation.input_qspec_map) equals len(expected_input_qspecs), it doesn't guarantee that arg_index will be within bounds. For example, if input_source_node has args at positions [0, 1, 2] and the input_qspec_map contains entries for args at positions [1, 2], the arg_index could be 2, but expected_input_qspecs might only have 2 elements (indices 0 and 1). Consider adding a bounds check before accessing expected_input_qspecs[arg_index].
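One way to address this (a sketch against this test's names, not part of the PR) would be an explicit bounds assertion before the lookup:

# Hypothetical guard, placed after the assertIsNotNone check above:
self.assertLess(
    arg_index,
    len(expected_input_qspecs),
    f"arg index {arg_index} out of range for expected_input_qspecs",
)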

     def test_all_quantizers_have_annotation_tests(self) -> None:
         """Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""