diff --git a/docs/tutorial/builder/graph_builder.md b/docs/tutorial/builder/graph_builder.md index 81b1e7df87..55bd83a90e 100644 --- a/docs/tutorial/builder/graph_builder.md +++ b/docs/tutorial/builder/graph_builder.md @@ -406,3 +406,110 @@ def build_linear(op, x, weight, bias_value): This pattern keeps function signatures simple while preserving access to the full builder API when needed. +## Calling Script Functions from OpBuilder + +The `OpBuilder` provides a `call()` method to inline `@script`-decorated ONNX functions directly into the builder's graph. This enables composition of both imperative (builder) and declarative (`@script`) code within a single graph. + +### Basic function inlining + +Define an ONNX script function and then call it through `op.call()`: + +```python +from onnxscript import script, opset23 as op23 + +# Define a reusable script function +@script(default_opset=op23) +def mul_add_relu(X, Y): + tmp = X * Y + tmp = tmp + X + return op23.Relu(tmp) + +# Now build a graph using OpBuilder +graph = ir.Graph( + name="my_graph", + inputs=[], + outputs=[], + nodes=[], + opset_imports={"": 23}, +) +x = ir.Value(name="x", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape([3, 4])) +y = ir.Value(name="y", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape([3, 4])) +graph.inputs.extend([x, y]) + +builder = onnxscript.GraphBuilder(graph) +op = builder.op + +# Call the script function — it gets inlined into the graph +result = op.call(mul_add_relu, x, y) +graph.outputs.append(result) +``` + +The function body (three nodes: Mul, Add, Relu) is inlined directly into the graph. + +### Renaming outputs with `_outputs` + +By default, inlined function outputs keep their original names, qualified by the +current naming context. You can rename them explicitly with `_outputs`: + +```python +@script(default_opset=op23) +def add_mul(X, Y): + a = X + Y + b = X * Y + return a, b + +# Inline with custom output names +result_sum, result_prod = op.call( + add_mul, x, y, + _outputs=["custom_sum", "custom_product"] +) +``` + +### Adding hierarchical context with `_prefix` + +Use `_prefix` to add a naming context to all nodes and intermediate values created +by the inlined function: + +```python +result = op.call( + mul_add_relu, x, y, + _prefix="layer1" +) +# Node names will be "layer1.Mul_n...", "layer1.Add_n...", "layer1.Relu_n..." +# Intermediate value names will also start with "layer1." +``` + +You can combine both options: + +```python +result_a, result_b = op.call( + add_mul, x, y, + _outputs=["sum_out", "prod_out"], + _prefix="math_ops" +) +# Final outputs: "sum_out", "prod_out" (renamed before prefix context) +# Intermediate values: "math_ops.Add_n...", "math_ops.Mul_n..." (with prefix) +``` + +### Using OpBuilder as the default_opset + +`OpBuilder` can be passed directly as the `default_opset` when decorating a script +function. This enables scripted functions to use the same opset version as the +builder they will be inlined into: + +```python +builder = onnxscript.GraphBuilder(graph) +op = builder.op + +# Define the function *after* creating the builder, using op as default_opset +@script(default_opset=op) +def my_func(X, Y): + t = X + Y + return op.Relu(t) # Uses the op directly + +# Inline it +result = op.call(my_func, x, y) +``` + +This pattern ensures consistency: the script function operates in the same domain +and opset version as the builder. diff --git a/onnxscript/_internal/_inliner.py b/onnxscript/_internal/_inliner.py new file mode 100644 index 0000000000..f7e69c6fc0 --- /dev/null +++ b/onnxscript/_internal/_inliner.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import Mapping, Sequence + +import onnx_ir as ir +from onnx_ir._cloner import Cloner + + +def instantiate( + function: ir.Function, + inputs: Sequence[ir.Value | None], + attributes: Mapping[str, ir.Attr], + *, + prefix: str = "", +) -> tuple[list[ir.Node], list[ir.Value | None]]: + """Instantiate (inline) a function, substituting inputs and attributes. + + Args: + function: The function to instantiate. + inputs: Actual input values to bind to the function's formal parameters. + attributes: Attribute values to substitute for reference attributes. + prefix: Optional prefix to prepend to node and output names. + + Returns: + A tuple of (nodes, outputs) where nodes are the cloned function body + and outputs are the values corresponding to the function's outputs. + """ + formal_inputs = function.inputs + if len(inputs) > len(formal_inputs): + raise ValueError( + f"Too many inputs: got {len(inputs)}, " + f"but function has {len(formal_inputs)} parameters." + ) + value_map: dict[ir.Value, ir.Value | None] = { + formal: actual for formal, actual in zip(formal_inputs, inputs) + } + + def rename(node: ir.Node) -> None: + if prefix: + node.name = prefix + (node.name or "") + for output in node.outputs: + if output is not None: + output.name = prefix + (output.name or "") + + cloner = Cloner( + attr_map=attributes, + value_map=value_map, + metadata_props={}, + post_process=rename, + resolve_ref_attrs=True, + ) + nodes = [cloner.clone_node(n) for n in function] + outputs = [value_map.get(v) for v in function.outputs] + return nodes, outputs + diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 0b626bd100..ca185a1028 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -9,6 +9,7 @@ import onnx_ir as ir import onnxscript._internal._inference as inference +from onnxscript._internal import _inliner as inliner import onnxscript.optimizer # A permissible value for an op input, which can be converted to an ir.Value. @@ -182,12 +183,13 @@ def _adapt_outputs( if isinstance(outputs, int): if outputs < 0: raise ValueError(f"Number of outputs must be non-negative, got {outputs}") + count = self.graph.num_nodes() if outputs == 1: - name = f"{op_type}_output" if op_type else "output" + name = f"{op_type}_n{count}_output" if op_type else f"n{count}_output" return [ir.Value(name=self.qualify_name(name))] else: names = [ - f"{op_type}_output{i}" if op_type else f"output{i}" for i in range(outputs) + f"{op_type}_n{count}_output{i}" if op_type else f"n{count}_output{i}" for i in range(outputs) ] return [ir.Value(name=self.qualify_name(n)) for n in names] adapted_outputs = [] @@ -326,6 +328,42 @@ def call_op( return node.outputs if len(node.outputs) > 1 else node.outputs[0] + def call(self, function, *args, _outputs: Sequence[str] | None = None, _prefix: str = "", **kwargs): + if isinstance(function, ir.Function): + function_ir = function + elif isinstance(function, onnxscript.OnnxFunction): + function_proto = function.to_function_proto() + function_ir = ir.serde.deserialize_function(function_proto) + else: + raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction") + output_renaming: dict[str, str] = {} + if _outputs is not None: + if len(_outputs) != len(function_ir.outputs): + raise ValueError( + f"Number of provided output names {_outputs} does not match " + f"number of function outputs {len(function_ir.outputs)}." + ) + for output, name in zip(function_ir.outputs, _outputs): + output_renaming[output.name] = self.qualify_name(name) + else: + for output in function_ir.outputs: + output_renaming[output.name] = self.qualify_name(output.name) + nodes, outputs = inliner.instantiate(function_ir, args, kwargs) + if _prefix: + self.push_module(_prefix) + for node in nodes: + node.name = self.qualify_name(node.name) + for output in node.outputs: + if output.name: + if output.name in output_renaming: + output.name = output_renaming[output.name] + else: + output.name = self.qualify_name(output.name) + self.add_node(node) + if _prefix: + self.pop_module() + return outputs if len(outputs) > 1 else outputs[0] + def push_module(self, module: str) -> None: """Push a new naming context onto the stack (e.g. a layer or module name).""" current = self.context_name() @@ -365,6 +403,14 @@ def __init__( def builder(self) -> GraphBuilder: return self._builder + @property + def domain(self) -> str: + return self._domain + + @property + def version(self) -> int | None: + return self._version + def _call_op(self, op_type: str, inputs: Sequence[Any], kwargs: dict[str, Any]): if "_domain" not in kwargs: kwargs["_domain"] = self._domain @@ -377,3 +423,6 @@ def __getattr__(self, op_type: str) -> Callable: def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value: return self._builder.initializer(tensor, name) + + def call(self, function, *args, _outputs: Sequence[str] | None = None, _prefix: str = "", **kwargs): + return self._builder.call(function, *args, _outputs=_outputs, _prefix=_prefix, **kwargs) diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 9edd0b68b4..ef33336157 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -9,6 +9,8 @@ import onnx_ir as ir import onnxscript._internal.builder as builder +from onnxscript import script + _default_opset_version = 23 @@ -191,12 +193,12 @@ def test_value_naming_with_ir_value_objects(self): self.assertIs(t3, out3) def test_default_output_naming_strategy(self): - """Test the default naming strategy for generated output values using op_type_output format.""" + """Test the default naming strategy for generated output values using op_type_nX_output format.""" def _ops_with_default_names( op: builder.OpBuilder, x: ir.Value, y: ir.Value ) -> ir.Value: - # Single output operations should be named {op_type}_output + # Single output operations should be named {op_type}_nX_output where X is node count t1 = op.Add(x, y) t2 = op.Mul(x, y) z = op.Add(t1, t2) @@ -214,14 +216,14 @@ def _ops_with_default_names( nodes = list(graph) self.assertEqual(len(nodes), 3) - # Check output names follow the {op_type}_output pattern for single outputs - self.assertEqual(nodes[0].outputs[0].name, "Add_output") - self.assertEqual(nodes[1].outputs[0].name, "Mul_output") - self.assertEqual(nodes[2].outputs[0].name, "Add_output") + # Check output names follow the {op_type}_nX_output pattern for single outputs + self.assertEqual(nodes[0].outputs[0].name, "Add_n0_output") + self.assertEqual(nodes[1].outputs[0].name, "Mul_n1_output") + self.assertEqual(nodes[2].outputs[0].name, "Add_n2_output") # Verify the final output has the correct name self.assertEqual(len(graph.outputs), 1) - self.assertEqual(graph.outputs[0].name, "Add_output") + self.assertEqual(graph.outputs[0].name, "Add_n2_output") def test_hierarchical_naming(self): """Test the hierarchical naming strategy (for value and node names).""" @@ -229,35 +231,35 @@ def test_hierarchical_naming(self): # Test node and value naming at root level t1 = op.Add(x, y) - self.assertEqual(t1.name, "Add_output") + self.assertEqual(t1.name, "Add_n0_output") self.assertEqual(t1.producer().name, "Add_node_0") t2 = op.Mul(t1, y) - self.assertEqual(t2.name, "Mul_output") + self.assertEqual(t2.name, "Mul_n1_output") self.assertEqual(t2.producer().name, "Mul_node_1") # Test node and value naming with hierarchical context prefix op.builder.push_module("layer1") t3 = op.Add(t2, x) - self.assertEqual(t3.name, "layer1.Add_output") + self.assertEqual(t3.name, "layer1.Add_n2_output") self.assertEqual(t3.producer().name, "layer1.Add_node_2") # Test nested hierarchical context op.builder.push_module("attention") t4 = op.Mul(t3, y) - self.assertEqual(t4.name, "layer1.attention.Mul_output") + self.assertEqual(t4.name, "layer1.attention.Mul_n3_output") self.assertEqual(t4.producer().name, "layer1.attention.Mul_node_3") # Pop back to layer1 and verify naming continues correctly op.builder.pop_module() t5 = op.Add(t4, x) - self.assertEqual(t5.name, "layer1.Add_output") + self.assertEqual(t5.name, "layer1.Add_n4_output") self.assertEqual(t5.producer().name, "layer1.Add_node_4") # Pop back to root context op.builder.pop_module() t6 = op.Mul(t5, y) - self.assertEqual(t6.name, "Mul_output") + self.assertEqual(t6.name, "Mul_n5_output") self.assertEqual(t6.producer().name, "Mul_node_5") def test_shape_inference_add(self): @@ -275,6 +277,9 @@ def test_shape_inference_add(self): self.assertIsNotNone(result.shape) self.assertEqual(list(result.shape), [2, 3, 4]) + # Verify the default name uses the node count + self.assertEqual(result.name, "Add_n0_output") + def test_custom_domain_explicit(self): """Test using operations from custom domains with explicit _domain parameter.""" op, x, y = _create_builder_with_inputs() @@ -311,7 +316,7 @@ def test_custom_domain_with_version(self): # Verify output value is created self.assertIsNotNone(result) - self.assertEqual(result.name, "MicrosoftOp_output") + self.assertEqual(result.name, "MicrosoftOp_n0_output") def test_multiple_custom_domain_operations(self): """Test mixing operations from multiple domains.""" @@ -565,6 +570,167 @@ def test_attributes_are_created_properly(self): self.assertEqual(strs_attr.type, ir.AttributeType.STRINGS) self.assertEqual(list(strs_attr.value), ["a", "b", "c"]) + def test_call_inlines_onnxscript_function(self): + """Test that GraphBuilder.call inlines an @onnxscript.script function.""" + # Create a GraphBuilder first + op, x, y = _create_builder_with_inputs() + + # Define the script function after creating op, using op as default_opset + @script(default_opset=op) + def mul_add_relu(X, Y): + tmp = X * Y + tmp = tmp + X + return op.Relu(tmp) + + result = op.call(mul_add_relu, x, y) + + # The inlined function should produce 3 nodes: Mul, Add, Relu + nodes = list(op.builder.graph) + op_types = [n.op_type for n in nodes] + self.assertEqual(op_types, ["Mul", "Add", "Relu"]) + + # The result should be a single ir.Value (the Relu output) + self.assertIsInstance(result, ir.Value) + + # Verify connectivity: Relu takes the Add output + relu_node = nodes[2] + add_node = nodes[1] + self.assertIs(relu_node.inputs[0], add_node.outputs[0]) + + # Verify the Add takes the Mul output and original input x + mul_node = nodes[0] + self.assertIs(add_node.inputs[0], mul_node.outputs[0]) + self.assertIs(add_node.inputs[1], x) + + # Verify the Mul takes the original inputs x and y + self.assertIs(mul_node.inputs[0], x) + self.assertIs(mul_node.inputs[1], y) + + def test_call_with_outputs_option(self): + """Test that GraphBuilder.call respects the _outputs option for renaming.""" + # Create a GraphBuilder first + op, x, y = _create_builder_with_inputs() + + # Define the script function after creating op, using op as default_opset + @script(default_opset=op) + def add_mul(X, Y): + a = X + Y + b = X * Y + return a, b + + result = op.call(add_mul, x, y, _outputs=["sum_result", "product_result"]) + + # The result should be a list of 2 ir.Values (when function returns multiple outputs) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + sum_result, product_result = result + + # Verify output names are correctly set + self.assertEqual(sum_result.name, "sum_result") + self.assertEqual(product_result.name, "product_result") + + # Verify the nodes were created correctly + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 2) + self.assertEqual(nodes[0].op_type, "Add") + self.assertEqual(nodes[1].op_type, "Mul") + + def test_call_with_prefix_option(self): + """Test that GraphBuilder.call respects the _prefix option for hierarchical naming.""" + # Create a GraphBuilder first + op, x, y = _create_builder_with_inputs() + + # Define the script function after creating op, using op as default_opset + @script(default_opset=op) + def mul_add_relu(X, Y): + tmp = X * Y + tmp = tmp + X + return op.Relu(tmp) + + result = op.call(mul_add_relu, x, y, _prefix="layer1") + + # The nodes should have the prefix in their names + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 3) + + # Check that all node names start with the prefix + for node in nodes: + self.assertTrue(node.name.startswith("layer1."), f"Node name {node.name} should start with layer1.") + + # Verify the result is a single ir.Value + self.assertIsInstance(result, ir.Value) + + def test_call_with_outputs_and_prefix_options(self): + """Test that GraphBuilder.call respects both _outputs and _prefix options together. + + Note: _outputs names are set before the prefix context is applied, so they don't get + the prefix in their names. However, the inlined nodes do get the prefix applied, and + intermediate values (not renamed by _outputs) do get the prefix applied. + """ + # Create a GraphBuilder first + op, x, y = _create_builder_with_inputs() + + # Define the script function after creating op, using op as default_opset + @script(default_opset=op) + def add_mul(X, Y): + # Intermediate values that are not explicitly renamed by _outputs + XSquare = X * X + YSquare = Y * Y + # Final outputs that will be renamed by _outputs + a = XSquare + Y + b = XSquare * YSquare + return a, b + + result = op.call( + add_mul, x, y, + _outputs=["custom_sum", "custom_product"], + _prefix="math_ops" + ) + + # The result should be a list of 2 ir.Values + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + sum_result, product_result = result + + # Verify output names are set (without prefix, as _outputs renaming happens before prefix context) + self.assertEqual(sum_result.name, "custom_sum") + self.assertEqual(product_result.name, "custom_product") + + # Verify all nodes have the prefix applied to their names + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 4) # Mul (XSquare), Mul (YSquare), Add, Mul (final) + + # All node names should start with prefix + for node in nodes: + self.assertTrue(node.name.startswith("math_ops."), f"Node name {node.name} should start with math_ops.") + + # Verify intermediate value names also get the prefix + # The first Mul produces XSquare + x_square = nodes[0].outputs[0] + self.assertTrue(x_square.name.startswith("math_ops."), f"Intermediate value {x_square.name} should have prefix") + + # The second Mul produces YSquare + y_square = nodes[1].outputs[0] + self.assertTrue(y_square.name.startswith("math_ops."), f"Intermediate value {y_square.name} should have prefix") + + def test_call_outputs_mismatch_error(self): + """Test that GraphBuilder.call raises an error if _outputs has wrong count.""" + # Create a GraphBuilder first + op, x, y = _create_builder_with_inputs() + + # Define the script function after creating op, using op as default_opset + @script(default_opset=op) + def add_mul(X, Y): + a = X + Y + b = X * Y + return a, b + + # The function returns 2 outputs, but we provide only 1 name + with self.assertRaises(ValueError) as cm: + op.call(add_mul, x, y, _outputs=["only_one_name"]) + + self.assertIn("does not match", str(cm.exception)) + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index 6ebe5bda4a..921b1e0fe5 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -20,6 +20,7 @@ analysis, ast_utils, autocast, + builder, irbuilder, param_manipulation, sourceinfo, @@ -171,14 +172,18 @@ def __init__( opset: values.Opset | None = None, global_names: dict[str, Any] | None = None, source: str | None = None, - default_opset: values.Opset | None = None, + default_opset: Union[values.Opset, builder.OpBuilder, None] = None, ): self.source = source if global_names is not None: # We make a copy in case function eval modifies it. self.globals = global_names.copy() self.this_module = opset - self.default_opset_ = default_opset + # Convert OpBuilder to Opset if necessary and store the converted value + if isinstance(default_opset, builder.OpBuilder): + self.default_opset_ = values.Opset(default_opset.domain, default_opset.version) + else: + self.default_opset_ = default_opset # States initialized by `_init_function_translation` self._outer: list[irbuilder.IRFunction] = [] @@ -231,8 +236,13 @@ def _find_onnx_opset(self, node: ast.AST) -> values.Opset | None: if isinstance(opset_expr, ast.Name): if opset_expr.id in self.globals: opset = self.globals[opset_expr.id] - if isinstance(opset, values.Opset) and opset.domain == "": - return opset + # Accept both values.Opset and builder.OpBuilder + if isinstance(opset, values.Opset): + if opset.domain == "": + return opset + elif isinstance(opset, builder.OpBuilder): + if opset.domain == "": + return values.Opset(opset.domain, opset.version) for child in ast.iter_child_nodes(node): res = self._find_onnx_opset(child) if res is not None: @@ -954,12 +964,14 @@ def _translate_opset_expr(self, node: ast.Attribute) -> values.Opset: val = self._lookup(node.id, self._source_of(node), raise_exception=False) if isinstance(val, values.Opset): return val + elif isinstance(val, builder.OpBuilder): + # Convert OpBuilder to Opset for compatibility + return values.Opset(val.domain, val.version) self.fail(node, f"'{node.id}' is not an instance of type Opset but {type(val)}.") elif isinstance(node, ast.Attribute): self.fail(node, "Nested module unimplemented.") # TODO else: self.fail(node, "Invalid opset expression.") - # pylint: enable=inconsistent-return-statements def _translate_callee_expr(self, node: ast.AST) -> values.Op: # pylint: disable=R1710 """Return an Op"""