Adding Tests for CadenceFusedConvReluQuantizer #16358
Changes from all commits: 2c0af7d, 636930a, 4217e25, f289f26
```diff
@@ -46,15 +46,18 @@
 # Type alias for graph builder functions.
 # These functions take a test instance and return a graph module and the target op node.
+# For fused patterns (e.g., conv+relu), an optional third element specifies the node
+# whose args contain the quantized inputs (e.g., conv node for conv+relu fusion).
 GraphBuilderFn = Callable[
-    ["QuantizerAnnotationTest"], tuple[torch.fx.GraphModule, torch.fx.Node]
+    ["QuantizerAnnotationTest"],
+    tuple[torch.fx.GraphModule, torch.fx.Node]
+    | tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node],
 ]


 # Quantizers intentionally excluded from annotation testing.
 # These should be explicitly justified when added.
 EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
-    CadenceFusedConvReluQuantizer,  # TODO: T247438151 Add test coverage
     CadenceNopQuantizer,  # No-op quantizer, doesn't annotate anything
     CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
     CadenceRmsNormNopQuantizer,  # No-op quantizer, doesn't annotate anything, preserves rms_norm from decomposition
```
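For orientation, here is a minimal sketch of the two builder shapes the widened alias admits. It is not code from this PR: it uses plain torch.fx tracing instead of the suite's GraphBuilder, and the builder names are illustrative.

```python
# Illustrative only: a 2-tuple builder (single-op pattern) and a 3-tuple
# builder (fused pattern), both valid under the widened GraphBuilderFn.
import torch
import torch.fx


def _relu_builder(test) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
    # Single op: the annotation target also holds the quantized inputs.
    gm = torch.fx.symbolic_trace(lambda x: torch.relu(x))
    (relu,) = gm.graph.find_nodes(op="call_function", target=torch.relu)
    return gm, relu


def _conv_relu_builder(
    test,
) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]:
    # Fused pattern: relu is the annotation target, but the quantized
    # inputs (activation, weight) are args of the conv node.
    gm = torch.fx.symbolic_trace(lambda x, w: torch.relu(torch.conv2d(x, w)))
    (conv,) = gm.graph.find_nodes(op="call_function", target=torch.conv2d)
    (relu,) = gm.graph.find_nodes(op="call_function", target=torch.relu)
    return gm, relu, conv
```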
```diff
@@ -64,14 +67,15 @@
 # Test case definitions for quantizer annotation tests.
 # Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
 # Adding a new quantizer test only requires adding a tuple to this list.
+# Note: Use None in expected_input_qspecs to skip comparison for that input (e.g., for DerivedQuantizationSpec).
 QUANTIZER_ANNOTATION_TEST_CASES: list[
     tuple[
         str,
         GraphBuilderFn,
         CadenceQuantizer,
         OpOverload,
         QuantizationSpec,
-        list[QuantizationSpec],
+        list[QuantizationSpec | None],
     ]
 ] = [
     (
```
```diff
@@ -192,6 +196,26 @@
             # For relu: only input_activation
             [qconfig_A8W8.input_activation],
         ),
+        (
+            "default_addmm_A8W8",
+            lambda self: self._build_addmm_graph(),
+            CadenceDefaultQuantizer(),
+            torch.ops.aten.addmm.default,
+            qconfig_A8W8.output_activation,
+            # For addmm: [bias (DerivedQuantizationSpec), mat1, mat2]
+            # Use None to skip comparison for bias since it's a DerivedQuantizationSpec
+            [None, qconfig_A8W8.input_activation, qconfig_A8W8.weight],
+        ),
+        # CadenceFusedConvReluQuantizer test cases
+        (
+            "fused_conv2d_relu_A8W8sym",
+            lambda self: self._build_conv2d_relu_graph(),
+            CadenceFusedConvReluQuantizer(),
+            torch.ops.aten.relu.default,
+            qconfig_A8W8sym.output_activation,
+            # For fused conv2d+relu: [input_activation, weight] from conv2d node
+            [qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight],
+        ),
     ]

 # Derive the set of tested quantizer classes from the test cases.
```
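Since each tuple above drives one generated test, it may help to recall how the expansion works; a standalone sketch (assuming the parameterized package used by this suite, with toy values):

```python
import unittest

from parameterized import parameterized


class ExpansionExample(unittest.TestCase):
    # Each tuple becomes one generated test; element 0 is the name suffix,
    # and all elements are passed positionally to the method.
    @parameterized.expand([
        ("one_plus_two", 1, 2, 3),
        ("four_plus_five", 4, 5, 9),
    ])
    def test_add(self, name: str, x: int, y: int, expected: int) -> None:
        self.assertEqual(x + y, expected)  # e.g. test_add_0_one_plus_two
```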
```diff
@@ -408,6 +432,77 @@ def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
         return gm, relu_nodes[0]

+    def _build_addmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with an addmm operation."""
+        builder = GraphBuilder()
+        # addmm: bias + (mat1 @ mat2)
+        # args: (bias, mat1, mat2)
+        bias = builder.placeholder("bias", torch.randn(5))
+        mat1 = builder.placeholder("mat1", torch.randn(1, 10))
+        mat2 = builder.placeholder("mat2", torch.randn(10, 5))
+        addmm = builder.call_operator(
+            op=torch.ops.aten.addmm.default,
+            args=(bias, mat1, mat2),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("addmm", torch.ops.aten.addmm.default)]}
+            ),
+        )
+        builder.output([addmm])
+        gm = builder.get_graph_module()
+
+        addmm_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.addmm.default,
+        )
+        self.assertEqual(len(addmm_nodes), 1, "Should find exactly one addmm node")
+        return gm, addmm_nodes[0]
+
+    def _build_conv2d_relu_graph(
+        self,
+    ) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]:
+        """Build a graph with a conv2d followed by relu (fused pattern).
+
+        Returns:
+            A tuple of (graph_module, relu_node, conv_node).
+            The relu_node is the target node where the annotation is placed.
+            The conv_node is the input source node whose args contain the quantized inputs.
+        """
+        builder = GraphBuilder()
+        # Input shape: (batch, in_channels, height, width)
+        x = builder.placeholder("x", torch.randn(1, 3, 8, 8))
+        # Weight shape: (out_channels, in_channels, kernel_h, kernel_w)
+        weight = builder.placeholder("weight", torch.randn(6, 3, 3, 3))
+        conv2d = builder.call_operator(
+            op=torch.ops.aten.conv2d.default,
+            args=(x, weight),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("conv2d", torch.ops.aten.conv2d.default)]}
+            ),
+        )
+        relu = builder.call_operator(
+            op=torch.ops.aten.relu.default,
+            args=(conv2d,),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("relu", torch.ops.aten.relu.default)]}
+            ),
+        )
+        builder.output([relu])
+        gm = builder.get_graph_module()
+
+        relu_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.relu.default,
+        )
+        self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
+
+        conv2d_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.conv2d.default,
+        )
+        self.assertEqual(len(conv2d_nodes), 1, "Should find exactly one conv2d node")
+
+        return gm, relu_nodes[0], conv2d_nodes[0]
+
     @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,
```
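For intuition about why the fused builder surfaces the conv node separately, a vanilla torch.fx trace of the same conv2d→relu chain (independent of the suite's GraphBuilder; output shown approximately) makes the wiring visible:

```python
import torch
import torch.fx


def conv_relu(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return torch.relu(torch.conv2d(x, w))


gm = torch.fx.symbolic_trace(conv_relu)
for node in gm.graph.nodes:
    print(node.op, node.target, node.args)
# Roughly:
#   placeholder    x        ()
#   placeholder    w        ()
#   call_function  conv2d   (x, w)      <- quantized inputs live here
#   call_function  relu     (conv2d,)   <- annotation target
#   output         output   (relu,)
```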
```diff
@@ -416,36 +511,56 @@ def test_quantizer_annotation(
         quantizer: CadenceQuantizer,
         target: OpOverload,
         expected_output_qspec: QuantizationSpec,
-        expected_input_qspecs: list[QuantizationSpec],
+        expected_input_qspecs: list[QuantizationSpec | None],
     ) -> None:
         """Parameterized test for quantizer annotations."""
-        gm, op_node = graph_builder_fn(self)
+        result = graph_builder_fn(self)
+        # Handle both 2-element and 3-element returns from graph builders.
+        # For fused patterns, the 3rd element specifies the node whose args
+        # contain the quantized inputs (e.g., conv node for conv+relu fusion).
+        if len(result) == 3:
+            gm = result[0]
+            output_node = result[1]
+            input_source_node = result[2]
+        else:
+            gm = result[0]
+            output_node = result[1]
+            input_source_node = output_node

         quantizer.annotate(gm)

-        annotation: QuantizationAnnotation = op_node.meta[Q_ANNOTATION_KEY]
-        self.assertTrue(annotation._annotated)
-
-        # Verify output annotation
-        self.assertEqual(annotation.output_qspec, expected_output_qspec)
-
-        # Verify input annotations
-        self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
-        for i, (input_node, input_qspec) in enumerate(
-            annotation.input_qspec_map.items()
-        ):
-            expected_arg = op_node.args[i]
-            assert isinstance(expected_arg, torch.fx.Node)
-            self.assertEqual(
-                input_node,
-                expected_arg,
-                f"Input node mismatch at index {i}",
-            )
-            self.assertEqual(
-                input_qspec,
-                expected_input_qspecs[i],
-                f"Input qspec mismatch at index {i}",
-            )
+        # Verify output annotation (always on the output node)
+        output_annotation: QuantizationAnnotation = output_node.meta[Q_ANNOTATION_KEY]
+        self.assertTrue(output_annotation._annotated)
+        self.assertEqual(output_annotation.output_qspec, expected_output_qspec)
+
+        # Verify input annotations (on the input source node, which may differ for fused patterns)
+        input_annotation: QuantizationAnnotation = input_source_node.meta[
+            Q_ANNOTATION_KEY
+        ]
+        self.assertEqual(
+            len(input_annotation.input_qspec_map), len(expected_input_qspecs)
+        )
+        for input_node, input_qspec in input_annotation.input_qspec_map.items():
+            # Find the index of this input node in the input source node's args
+            arg_index = None
+            args = input_source_node.args
+            assert isinstance(args, tuple)
+            for i, arg in enumerate(args):
+                if arg is input_node:
+                    arg_index = i
+                    break
+            self.assertIsNotNone(
+                arg_index,
+                f"Input node {input_node} not found in input_source_node.args",
+            )
+            # Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
+            if expected_input_qspecs[arg_index] is not None:
+                self.assertEqual(
+                    input_qspec,
+                    expected_input_qspecs[arg_index],
+                    f"Input qspec mismatch at arg index {arg_index}",
+                )

     def test_all_quantizers_have_annotation_tests(self) -> None:
         """Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""
```
Review comment on lines +558 to +563:

Using Python's `assert` statement in test code is not ideal because it can be disabled with optimization flags. Consider using `self.assertIsInstance(args, tuple)` instead to ensure the check always runs.