Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 45 additions & 12 deletions onnxscript/_internal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ def __init__(self, graph: ir.Graph) -> None:
# and allows sharing them across different layers/contexts.
self._constant_cache: dict[tuple[Any, ir.DataType | None], ir.Value] = {}

self._functions: dict[ir.OperatorIdentifier, ir.Function] = {}

def opset(self, domain: str, version: int = 1) -> OpBuilder:
"""Create an OpBuilder bound to the given domain and version."""
return OpBuilder(self, domain, version)
Expand All @@ -241,6 +243,10 @@ def op(self) -> OpBuilder:
def graph(self) -> ir.Graph:
return self._graph

@property
def functions(self) -> dict[ir.OperatorIdentifier, ir.Function]:
return self._functions

def initializer(
self, tensor: ir.TensorProtocol, name: str | None = None, *, qualify: bool = True
) -> ir.Value:
Expand Down Expand Up @@ -543,16 +549,18 @@ def call_op(

def call(
self,
function,
function: ir.Function | onnxscript.OnnxFunction,
*args,
_outputs: Sequence[str] | None = None,
_prefix: str = "",
_inline: bool = True,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we enforce that _prefix is "" if _inline is False?

**kwargs,
):
if isinstance(function, ir.Function):
graph = function.graph
elif isinstance(function, onnxscript.OnnxFunction):
graph = function.graph()
function = function.function_ir
else:
raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction")
output_renaming: dict[str, str] = {}
Expand All @@ -567,18 +575,36 @@ def call(
else:
for output in graph.outputs:
output_renaming[output.name] = self._qualify_value_name(output.name)
nodes, outputs = _inliner.instantiate(graph, args, kwargs)
if _prefix:
self.push_module(_prefix)
for node in nodes:
node.name = self._qualify_node_name(node.name)
for output in node.outputs:
if output.name:
if output.name in output_renaming:
output.name = output_renaming[output.name]
else:
output.name = self._qualify_value_name(output.name)

if _inline:
nodes, outputs = _inliner.instantiate(graph, args, kwargs)

for node in nodes:
node.name = self._qualify_node_name(node.name)
for output in node.outputs:
if output.name:
if output.name in output_renaming:
output.name = output_renaming[output.name]
else:
output.name = self._qualify_value_name(output.name)
self.add_node(node)
else:
node = ir.node(
op_type=function.name,
inputs=args,
attributes=kwargs or None,
outputs=[
ir.Value(name=output_renaming[output.name]) for output in graph.outputs
],
domain=function.domain,
name=self._qualify_node_name(function.name),
)
Comment on lines +594 to +603
outputs = node.outputs
Comment on lines +598 to +604
self.add_node(node)
self._functions[function.identifier()] = function

if _prefix:
self.pop_module()
return outputs if len(outputs) > 1 else outputs[0]
Expand Down Expand Up @@ -684,27 +710,34 @@ def __getattr__(self, op_type: str) -> Callable:
def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
return self._builder.initializer(tensor, name)

def functions(self) -> dict[ir.OperatorIdentifier, ir.Function]:
return self._builder.functions

def call(
self,
function,
*args,
_outputs: Sequence[str] | None = None,
_prefix: str = "",
_inline: bool = True,
**kwargs,
):
"""Call a function and inline it into the graph.
"""Call a function and optionally inline it into the graph.

Args:
function: The function to call (ir.Function or onnxscript.OnnxFunction).
*args: Positional arguments to pass to the function.
_outputs: Optional sequence of output names. If provided, must match the
number of function outputs.
_prefix: Optional prefix for module scoping (e.g., "layers.0").
_inline: If True, the function body is inlined into the caller graph instead of being
called as a separate node. When False, the function will be added
to the ``.functions`` dictionary. Defaults to True.
Comment on lines +734 to +735
**kwargs: Keyword arguments to pass to the function.

Returns:
The output value(s) from the function call.
"""
return self._builder.call(
function, *args, _outputs=_outputs, _prefix=_prefix, **kwargs
function, *args, _outputs=_outputs, _prefix=_prefix, _inline=_inline, **kwargs
)
151 changes: 151 additions & 0 deletions onnxscript/_internal/builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,157 @@ def add_mul(X, Y):

self.assertIn("does not match", str(cm.exception))

def test_call_inline_false_creates_single_function_node(self):
"""Test that _inline=False creates a single function call node instead of inlining."""
op, x, y = _create_builder_with_inputs()

@script(default_opset=op)
def mul_add_relu(X, Y):
tmp = X * Y
tmp = tmp + X
return op.Relu(tmp)

result = op.call(mul_add_relu, x, y, _inline=False)

# With _inline=False, only a single node should be created (the function call)
nodes = list(op.builder.graph)
self.assertEqual(len(nodes), 1)

node = nodes[0]
self.assertEqual(node.op_type, "mul_add_relu")
self.assertEqual(list(node.inputs), [x, y])

# The result should be a single ir.Value
self.assertIsInstance(result, ir.Value)
self.assertIs(result, node.outputs[0])

def test_call_inline_false_registers_function(self):
"""Test that _inline=False registers the function in GraphBuilder.functions."""
op, x, y = _create_builder_with_inputs()

@script(default_opset=op)
def simple_add(X, Y):
return op.Add(X, Y)

op.call(simple_add, x, y, _inline=False)

# The function should be registered
self.assertEqual(len(op.builder.functions), 1)
registered = next(iter(op.builder.functions.values()))
self.assertEqual(registered.name, "simple_add")

def test_call_inline_true_does_not_register_function(self):
"""Test that _inline=True (default) does not register the function."""
op, x, y = _create_builder_with_inputs()

@script(default_opset=op)
def simple_add(X, Y):
return op.Add(X, Y)

op.call(simple_add, x, y, _inline=True)

# No function should be registered when inlining
self.assertEqual(len(op.builder.functions), 0)

def test_call_inline_false_with_outputs_option(self):
"""Test that _inline=False respects the _outputs option for renaming."""
op, x, y = _create_builder_with_inputs()

@script(default_opset=op)
def add_mul(X, Y):
a = X + Y
b = X * Y
return a, b

result = op.call(
add_mul, x, y, _outputs=["sum_result", "product_result"], _inline=False
)

# The result should be a sequence of 2 ir.Values
self.assertEqual(len(result), 2)
sum_result, product_result = result

# Verify output names
self.assertEqual(sum_result.name, "v_sum_result")
self.assertEqual(product_result.name, "v_product_result")

# Only one node (the function call), not inlined
nodes = list(op.builder.graph)
self.assertEqual(len(nodes), 1)
self.assertEqual(nodes[0].op_type, "add_mul")

def test_call_inline_false_with_prefix_option(self):
"""Test that _inline=False respects the _prefix option for hierarchical naming."""
op, x, y = _create_builder_with_inputs()

@script(default_opset=op)
def mul_add_relu(X, Y):
tmp = X * Y
tmp = tmp + X
return op.Relu(tmp)

result = op.call(mul_add_relu, x, y, _prefix="layer1", _inline=False)

nodes = list(op.builder.graph)
self.assertEqual(len(nodes), 1)

# The node name should have the prefix
self.assertTrue(
nodes[0].name.startswith("layer1/"),
f"Node name {nodes[0].name} should start with layer1/",
)

self.assertIsInstance(result, ir.Value)

def test_call_inline_false_via_op_builder(self):
"""Test that _inline=False works when called through OpBuilder.call."""
op, x, y = _create_builder_with_inputs()

@script(default_opset=op)
def simple_add(X, Y):
return op.Add(X, Y)

# Call through OpBuilder (not GraphBuilder directly)
result = op.call(simple_add, x, y, _inline=False)

# Should produce a single function call node
nodes = list(op.builder.graph)
self.assertEqual(len(nodes), 1)
self.assertEqual(nodes[0].op_type, "simple_add")
self.assertIsInstance(result, ir.Value)

# Function should be registered
self.assertEqual(len(op.builder.functions), 1)

def test_call_inline_true_produces_more_nodes_than_inline_false(self):
"""Test that inlining produces individual op nodes while non-inlining produces one."""
# Inline version
op1, x1, y1 = _create_builder_with_inputs()

@script(default_opset=op1)
def mul_add(X, Y):
tmp = X * Y
return op1.Add(tmp, X)

op1.call(mul_add, x1, y1, _inline=True)
inline_nodes = list(op1.builder.graph)

# Non-inline version
op2, x2, y2 = _create_builder_with_inputs()

@script(default_opset=op2)
def mul_add2(X, Y):
tmp = X * Y
return op2.Add(tmp, X)

op2.call(mul_add2, x2, y2, _inline=False)
non_inline_nodes = list(op2.builder.graph)

# Inlining should produce 2 nodes (Mul, Add), non-inlining should produce 1
self.assertEqual(len(inline_nodes), 2)
self.assertEqual(len(non_inline_nodes), 1)
self.assertEqual(non_inline_nodes[0].op_type, "mul_add2")


class BuildSubgraphTest(unittest.TestCase):
"""Tests for GraphBuilder.subgraph()."""
Expand Down
4 changes: 4 additions & 0 deletions onnxscript/_internal/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,10 @@ def __call__(self, *args, **kwargs):
def name(self) -> str:
return self._name

@property
def domain(self) -> str:
return self._opset.domain

@property
def opset(self) -> Opset:
return self._opset
Expand Down
Loading