diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 1dc0875871..fc0267271e 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -229,6 +229,8 @@ def __init__(self, graph: ir.Graph) -> None: # and allows sharing them across different layers/contexts. self._constant_cache: dict[tuple[Any, ir.DataType | None], ir.Value] = {} + self._functions: dict[ir.OperatorIdentifier, ir.Function] = {} + def opset(self, domain: str, version: int = 1) -> OpBuilder: """Create an OpBuilder bound to the given domain and version.""" return OpBuilder(self, domain, version) @@ -241,6 +243,10 @@ def op(self) -> OpBuilder: def graph(self) -> ir.Graph: return self._graph + @property + def functions(self) -> dict[ir.OperatorIdentifier, ir.Function]: + return self._functions + def initializer( self, tensor: ir.TensorProtocol, name: str | None = None, *, qualify: bool = True ) -> ir.Value: @@ -543,16 +549,18 @@ def call_op( def call( self, - function, + function: ir.Function | onnxscript.OnnxFunction, *args, _outputs: Sequence[str] | None = None, _prefix: str = "", + _inline: bool = True, **kwargs, ): if isinstance(function, ir.Function): graph = function.graph elif isinstance(function, onnxscript.OnnxFunction): graph = function.graph() + function = function.function_ir else: raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction") output_renaming: dict[str, str] = {} @@ -567,18 +575,36 @@ def call( else: for output in graph.outputs: output_renaming[output.name] = self._qualify_value_name(output.name) - nodes, outputs = _inliner.instantiate(graph, args, kwargs) if _prefix: self.push_module(_prefix) - for node in nodes: - node.name = self._qualify_node_name(node.name) - for output in node.outputs: - if output.name: - if output.name in output_renaming: - output.name = output_renaming[output.name] - else: - output.name = self._qualify_value_name(output.name) + + if _inline: + nodes, outputs = _inliner.instantiate(graph, args, kwargs) + + for node in nodes: + node.name = self._qualify_node_name(node.name) + for output in node.outputs: + if output.name: + if output.name in output_renaming: + output.name = output_renaming[output.name] + else: + output.name = self._qualify_value_name(output.name) + self.add_node(node) + else: + node = ir.node( + op_type=function.name, + inputs=args, + attributes=kwargs or None, + outputs=[ + ir.Value(name=output_renaming[output.name]) for output in graph.outputs + ], + domain=function.domain, + name=self._qualify_node_name(function.name), + ) + outputs = node.outputs self.add_node(node) + self._functions[function.identifier()] = function + if _prefix: self.pop_module() return outputs if len(outputs) > 1 else outputs[0] @@ -684,15 +710,19 @@ def __getattr__(self, op_type: str) -> Callable: def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value: return self._builder.initializer(tensor, name) + def functions(self) -> dict[ir.OperatorIdentifier, ir.Function]: + return self._builder.functions + def call( self, function, *args, _outputs: Sequence[str] | None = None, _prefix: str = "", + _inline: bool = True, **kwargs, ): - """Call a function and inline it into the graph. + """Call a function and optionally inline it into the graph. Args: function: The function to call (ir.Function or onnxscript.OnnxFunction). @@ -700,11 +730,14 @@ def call( _outputs: Optional sequence of output names. If provided, must match the number of function outputs. _prefix: Optional prefix for module scoping (e.g., "layers.0"). + _inline: If True, the function body is inlined into the caller graph instead of being + called as a separate node. When False, the function will be added + to the ``.functions`` dictionary. Defaults to True. **kwargs: Keyword arguments to pass to the function. Returns: The output value(s) from the function call. """ return self._builder.call( - function, *args, _outputs=_outputs, _prefix=_prefix, **kwargs + function, *args, _outputs=_outputs, _prefix=_prefix, _inline=_inline, **kwargs ) diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index f6f301954b..710e61056c 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -848,6 +848,157 @@ def add_mul(X, Y): self.assertIn("does not match", str(cm.exception)) + def test_call_inline_false_creates_single_function_node(self): + """Test that _inline=False creates a single function call node instead of inlining.""" + op, x, y = _create_builder_with_inputs() + + @script(default_opset=op) + def mul_add_relu(X, Y): + tmp = X * Y + tmp = tmp + X + return op.Relu(tmp) + + result = op.call(mul_add_relu, x, y, _inline=False) + + # With _inline=False, only a single node should be created (the function call) + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 1) + + node = nodes[0] + self.assertEqual(node.op_type, "mul_add_relu") + self.assertEqual(list(node.inputs), [x, y]) + + # The result should be a single ir.Value + self.assertIsInstance(result, ir.Value) + self.assertIs(result, node.outputs[0]) + + def test_call_inline_false_registers_function(self): + """Test that _inline=False registers the function in GraphBuilder.functions.""" + op, x, y = _create_builder_with_inputs() + + @script(default_opset=op) + def simple_add(X, Y): + return op.Add(X, Y) + + op.call(simple_add, x, y, _inline=False) + + # The function should be registered + self.assertEqual(len(op.builder.functions), 1) + registered = next(iter(op.builder.functions.values())) + self.assertEqual(registered.name, "simple_add") + + def test_call_inline_true_does_not_register_function(self): + """Test that _inline=True (default) does not register the function.""" + op, x, y = _create_builder_with_inputs() + + @script(default_opset=op) + def simple_add(X, Y): + return op.Add(X, Y) + + op.call(simple_add, x, y, _inline=True) + + # No function should be registered when inlining + self.assertEqual(len(op.builder.functions), 0) + + def test_call_inline_false_with_outputs_option(self): + """Test that _inline=False respects the _outputs option for renaming.""" + op, x, y = _create_builder_with_inputs() + + @script(default_opset=op) + def add_mul(X, Y): + a = X + Y + b = X * Y + return a, b + + result = op.call( + add_mul, x, y, _outputs=["sum_result", "product_result"], _inline=False + ) + + # The result should be a sequence of 2 ir.Values + self.assertEqual(len(result), 2) + sum_result, product_result = result + + # Verify output names + self.assertEqual(sum_result.name, "v_sum_result") + self.assertEqual(product_result.name, "v_product_result") + + # Only one node (the function call), not inlined + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 1) + self.assertEqual(nodes[0].op_type, "add_mul") + + def test_call_inline_false_with_prefix_option(self): + """Test that _inline=False respects the _prefix option for hierarchical naming.""" + op, x, y = _create_builder_with_inputs() + + @script(default_opset=op) + def mul_add_relu(X, Y): + tmp = X * Y + tmp = tmp + X + return op.Relu(tmp) + + result = op.call(mul_add_relu, x, y, _prefix="layer1", _inline=False) + + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 1) + + # The node name should have the prefix + self.assertTrue( + nodes[0].name.startswith("layer1/"), + f"Node name {nodes[0].name} should start with layer1/", + ) + + self.assertIsInstance(result, ir.Value) + + def test_call_inline_false_via_op_builder(self): + """Test that _inline=False works when called through OpBuilder.call.""" + op, x, y = _create_builder_with_inputs() + + @script(default_opset=op) + def simple_add(X, Y): + return op.Add(X, Y) + + # Call through OpBuilder (not GraphBuilder directly) + result = op.call(simple_add, x, y, _inline=False) + + # Should produce a single function call node + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 1) + self.assertEqual(nodes[0].op_type, "simple_add") + self.assertIsInstance(result, ir.Value) + + # Function should be registered + self.assertEqual(len(op.builder.functions), 1) + + def test_call_inline_true_produces_more_nodes_than_inline_false(self): + """Test that inlining produces individual op nodes while non-inlining produces one.""" + # Inline version + op1, x1, y1 = _create_builder_with_inputs() + + @script(default_opset=op1) + def mul_add(X, Y): + tmp = X * Y + return op1.Add(tmp, X) + + op1.call(mul_add, x1, y1, _inline=True) + inline_nodes = list(op1.builder.graph) + + # Non-inline version + op2, x2, y2 = _create_builder_with_inputs() + + @script(default_opset=op2) + def mul_add2(X, Y): + tmp = X * Y + return op2.Add(tmp, X) + + op2.call(mul_add2, x2, y2, _inline=False) + non_inline_nodes = list(op2.builder.graph) + + # Inlining should produce 2 nodes (Mul, Add), non-inlining should produce 1 + self.assertEqual(len(inline_nodes), 2) + self.assertEqual(len(non_inline_nodes), 1) + self.assertEqual(non_inline_nodes[0].op_type, "mul_add2") + class BuildSubgraphTest(unittest.TestCase): """Tests for GraphBuilder.subgraph().""" diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index 051cb3e686..c3fd1b3dad 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -222,6 +222,10 @@ def __call__(self, *args, **kwargs): def name(self) -> str: return self._name + @property + def domain(self) -> str: + return self._opset.domain + @property def opset(self) -> Opset: return self._opset