From 50ca21bff3fe533805a51c283968b634a6a8b575 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 19 Feb 2026 16:59:31 -0800 Subject: [PATCH 01/14] Add call function Signed-off-by: Ganesan Ramalingam --- onnxscript/_internal/_inliner.py | 147 +++++++++++++++++++++++++++++++ onnxscript/_internal/builder.py | 16 ++++ 2 files changed, 163 insertions(+) create mode 100644 onnxscript/_internal/_inliner.py diff --git a/onnxscript/_internal/_inliner.py b/onnxscript/_internal/_inliner.py new file mode 100644 index 0000000000..bf65d6460d --- /dev/null +++ b/onnxscript/_internal/_inliner.py @@ -0,0 +1,147 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import Callable, Mapping, Sequence +import onnx_ir as ir + +class _CopyReplace: + """Utilities for creating a copy of IR objects with substitutions for attributes/input values.""" + + def __init__( + self, + attr_map: Mapping[str, ir.Attr], + value_map: dict[ir.Value, ir.Value | None], + metadata_props: dict[str, str], + post_process: Callable[[ir.Node], None], + ) -> None: + self._value_map = value_map + self._attr_map = attr_map + self._metadata_props = metadata_props + self._post_process = post_process + + def clone_value(self, value: ir.Value) -> ir.Value | None: + if value in self._value_map: + return self._value_map[value] + # If the value is not in the value map, it must be a graph input. + assert value.producer() is None, f"Value {value} has no entry in the value map" + new_value = ir.Value( + name=value.name, + type=value.type, + shape=value.shape, + doc_string=value.doc_string, + const_value=value.const_value, + ) + self._value_map[value] = new_value + return new_value + + def clone_optional_value(self, value: ir.Value | None) -> ir.Value | None: + if value is None: + return None + return self.clone_value(value) + + def clone_attr(self, key: str, attr: ir.Attr) -> ir.Attr | None: + if not attr.is_ref(): + if attr.type == ir.AttributeType.GRAPH: + graph = self.clone_graph(attr.as_graph()) + return ir.Attr(key, ir.AttributeType.GRAPH, graph, doc_string=attr.doc_string) + elif attr.type == ir.AttributeType.GRAPHS: + graphs = [self.clone_graph(graph) for graph in attr.as_graphs()] + return ir.Attr( + key, ir.AttributeType.GRAPHS, graphs, doc_string=attr.doc_string + ) + return attr + assert attr.is_ref() + ref_attr_name = attr.ref_attr_name + assert ref_attr_name is not None, "Reference attribute must have a name" + if ref_attr_name in self._attr_map: + ref_attr = self._attr_map[ref_attr_name] + if not ref_attr.is_ref(): + return ir.Attr( + key, ref_attr.type, ref_attr.value, doc_string=ref_attr.doc_string + ) + assert ref_attr.ref_attr_name is not None + return ir.RefAttr( + key, ref_attr.ref_attr_name, ref_attr.type, doc_string=ref_attr.doc_string + ) + # Note that if a function has an attribute-parameter X, and a call (node) to the function + # has no attribute X, all references to X in nodes inside the function body will be + # removed. This is just the ONNX representation of optional-attributes. + return None + + def clone_node(self, node: ir.Node) -> ir.Node: + new_inputs = [self.clone_optional_value(input) for input in node.inputs] + new_attributes = [ + new_value + for key, value in node.attributes.items() + if (new_value := self.clone_attr(key, value)) is not None + ] + + new_metadata = {**self._metadata_props, **node.metadata_props} + # TODO: For now, node metadata overrides callnode metadata if there is a conflict. + # Do we need to preserve both? + + new_node = ir.Node( + node.domain, + node.op_type, + new_inputs, + new_attributes, + overload=node.overload, + num_outputs=len(node.outputs), + graph=None, + name=node.name, + doc_string=node.doc_string, # type: ignore + metadata_props=new_metadata, + ) + new_outputs = new_node.outputs + for i, output in enumerate(node.outputs): + self._value_map[output] = new_outputs[i] + new_outputs[i].name = output.name if output.name is not None else f"output_{i}" + + self._post_process(new_node) + return new_node + + def clone_graph(self, graph: ir.Graph) -> ir.Graph: + input_values = [self.clone_value(v) for v in graph.inputs] + nodes = [self.clone_node(node) for node in graph] + initializers = [self.clone_value(init) for init in graph.initializers.values()] + output_values = [ + self.clone_value(v) for v in graph.outputs + ] # Looks up already cloned values + + return ir.Graph( + input_values, # type: ignore + output_values, # type: ignore + nodes=nodes, + initializers=initializers, # type: ignore + doc_string=graph.doc_string, + opset_imports=graph.opset_imports, + name=graph.name, + metadata_props=graph.metadata_props, + ) + +def instantiate ( + function: ir.Function, + inputs: Sequence[ir.Value | None], + attributes: Mapping[str, ir.Attr], + *, + prefix: str = "" +) -> tuple[list[ir.Node], list[ir.Value | None]]: + formal_inputs = function.inputs + if len(inputs) > len(formal_inputs): + raise ValueError("") + value_map = { + formal: actual for (formal, actual) in zip(formal_inputs, inputs) + } + def rename(node: ir.Node): + if prefix != "": + node.name = prefix + node.name + for output in node.outputs: + if output is not None: + output.name = prefix + output.name + cloner = _CopyReplace(attributes, value_map, {}, post_process=rename) + nodes = [cloner.clone_node(n) for n in function] + outputs = [value_map.get(v) for v in function.outputs] + return nodes, outputs + diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 0b626bd100..d00606f83d 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -326,6 +326,19 @@ def call_op( return node.outputs if len(node.outputs) > 1 else node.outputs[0] + def call(self, function, *args, **kwargs): + if isinstance(function, ir.Function): + function_ir = function + elif isinstance(function, onnxscript.values.OnnxFunction): + function_proto = function.to_function_proto() + function_ir = ir.serde.deserialize_function(function_proto) + else: + raise TypeError("Function must be an ir.Function or onnxscript.ONNXFunction") + nodes, outputs = inliner.instantiate(function_ir, args, kwargs) + for node in nodes: + self.add_node(node) + return outputs if len(outputs) > 1 else outputs[0] + def push_module(self, module: str) -> None: """Push a new naming context onto the stack (e.g. a layer or module name).""" current = self.context_name() @@ -377,3 +390,6 @@ def __getattr__(self, op_type: str) -> Callable: def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value: return self._builder.initializer(tensor, name) + + def call(self, function, *args, **kwargs): + return self._builder.call(function, *args, **kwargs) From 255e74787a00d93dd067cf861281140d6e7c0792 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 19 Feb 2026 17:09:32 -0800 Subject: [PATCH 02/14] Add test case Signed-off-by: Ganesan Ramalingam --- onnxscript/_internal/builder.py | 1 + onnxscript/_internal/builder_test.py | 42 ++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index d00606f83d..9735893556 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -9,6 +9,7 @@ import onnx_ir as ir import onnxscript._internal._inference as inference +import onnxscript._internal._inliner as inliner import onnxscript.optimizer # A permissible value for an op input, which can be converted to an ir.Value. diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 9edd0b68b4..0c2dbcbe9f 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -8,7 +8,11 @@ import onnx_ir as ir +import onnxscript import onnxscript._internal.builder as builder +from onnxscript import script +from onnxscript.onnx_opset import opset23 as op23 + _default_opset_version = 23 @@ -565,6 +569,44 @@ def test_attributes_are_created_properly(self): self.assertEqual(strs_attr.type, ir.AttributeType.STRINGS) self.assertEqual(list(strs_attr.value), ["a", "b", "c"]) + def test_call_inlines_onnxscript_function(self): + """Test that GraphBuilder.call inlines an @onnxscript.script function.""" + + @script(default_opset=op23) + def mul_add_relu(X, Y): + tmp = X * Y + tmp = tmp + X + return op23.Relu(tmp) + + # Verify we got an OnnxFunction + self.assertIsInstance(mul_add_relu, onnxscript.values.OnnxFunction) + + # Create a GraphBuilder and call the function + op, x, y = _create_builder_with_inputs() + result = op.call(mul_add_relu, x, y) + + # The inlined function should produce 3 nodes: Mul, Add, Relu + nodes = list(op.builder.graph) + op_types = [n.op_type for n in nodes] + self.assertEqual(op_types, ["Mul", "Add", "Relu"]) + + # The result should be a single ir.Value (the Relu output) + self.assertIsInstance(result, ir.Value) + + # Verify connectivity: Relu takes the Add output + relu_node = nodes[2] + add_node = nodes[1] + self.assertIs(relu_node.inputs[0], add_node.outputs[0]) + + # Verify the Add takes the Mul output and original input x + mul_node = nodes[0] + self.assertIs(add_node.inputs[0], mul_node.outputs[0]) + self.assertIs(add_node.inputs[1], x) + + # Verify the Mul takes the original inputs x and y + self.assertIs(mul_node.inputs[0], x) + self.assertIs(mul_node.inputs[1], y) + if __name__ == "__main__": unittest.main() From ffddfd7aaee41273183c136263a4a3883388fb2e Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 19 Feb 2026 17:18:22 -0800 Subject: [PATCH 03/14] Cleanup Signed-off-by: Ganesan Ramalingam --- onnxscript/_internal/_inliner.py | 161 ++++++--------------------- onnxscript/_internal/builder_test.py | 7 +- 2 files changed, 37 insertions(+), 131 deletions(-) diff --git a/onnxscript/_internal/_inliner.py b/onnxscript/_internal/_inliner.py index bf65d6460d..bef0fc85c4 100644 --- a/onnxscript/_internal/_inliner.py +++ b/onnxscript/_internal/_inliner.py @@ -3,144 +3,55 @@ from __future__ import annotations -from typing import Callable, Mapping, Sequence -import onnx_ir as ir - -class _CopyReplace: - """Utilities for creating a copy of IR objects with substitutions for attributes/input values.""" - - def __init__( - self, - attr_map: Mapping[str, ir.Attr], - value_map: dict[ir.Value, ir.Value | None], - metadata_props: dict[str, str], - post_process: Callable[[ir.Node], None], - ) -> None: - self._value_map = value_map - self._attr_map = attr_map - self._metadata_props = metadata_props - self._post_process = post_process - - def clone_value(self, value: ir.Value) -> ir.Value | None: - if value in self._value_map: - return self._value_map[value] - # If the value is not in the value map, it must be a graph input. - assert value.producer() is None, f"Value {value} has no entry in the value map" - new_value = ir.Value( - name=value.name, - type=value.type, - shape=value.shape, - doc_string=value.doc_string, - const_value=value.const_value, - ) - self._value_map[value] = new_value - return new_value - - def clone_optional_value(self, value: ir.Value | None) -> ir.Value | None: - if value is None: - return None - return self.clone_value(value) - - def clone_attr(self, key: str, attr: ir.Attr) -> ir.Attr | None: - if not attr.is_ref(): - if attr.type == ir.AttributeType.GRAPH: - graph = self.clone_graph(attr.as_graph()) - return ir.Attr(key, ir.AttributeType.GRAPH, graph, doc_string=attr.doc_string) - elif attr.type == ir.AttributeType.GRAPHS: - graphs = [self.clone_graph(graph) for graph in attr.as_graphs()] - return ir.Attr( - key, ir.AttributeType.GRAPHS, graphs, doc_string=attr.doc_string - ) - return attr - assert attr.is_ref() - ref_attr_name = attr.ref_attr_name - assert ref_attr_name is not None, "Reference attribute must have a name" - if ref_attr_name in self._attr_map: - ref_attr = self._attr_map[ref_attr_name] - if not ref_attr.is_ref(): - return ir.Attr( - key, ref_attr.type, ref_attr.value, doc_string=ref_attr.doc_string - ) - assert ref_attr.ref_attr_name is not None - return ir.RefAttr( - key, ref_attr.ref_attr_name, ref_attr.type, doc_string=ref_attr.doc_string - ) - # Note that if a function has an attribute-parameter X, and a call (node) to the function - # has no attribute X, all references to X in nodes inside the function body will be - # removed. This is just the ONNX representation of optional-attributes. - return None - - def clone_node(self, node: ir.Node) -> ir.Node: - new_inputs = [self.clone_optional_value(input) for input in node.inputs] - new_attributes = [ - new_value - for key, value in node.attributes.items() - if (new_value := self.clone_attr(key, value)) is not None - ] +from typing import Mapping, Sequence - new_metadata = {**self._metadata_props, **node.metadata_props} - # TODO: For now, node metadata overrides callnode metadata if there is a conflict. - # Do we need to preserve both? - - new_node = ir.Node( - node.domain, - node.op_type, - new_inputs, - new_attributes, - overload=node.overload, - num_outputs=len(node.outputs), - graph=None, - name=node.name, - doc_string=node.doc_string, # type: ignore - metadata_props=new_metadata, - ) - new_outputs = new_node.outputs - for i, output in enumerate(node.outputs): - self._value_map[output] = new_outputs[i] - new_outputs[i].name = output.name if output.name is not None else f"output_{i}" - - self._post_process(new_node) - return new_node - - def clone_graph(self, graph: ir.Graph) -> ir.Graph: - input_values = [self.clone_value(v) for v in graph.inputs] - nodes = [self.clone_node(node) for node in graph] - initializers = [self.clone_value(init) for init in graph.initializers.values()] - output_values = [ - self.clone_value(v) for v in graph.outputs - ] # Looks up already cloned values +import onnx_ir as ir +from onnx_ir._cloner import Cloner - return ir.Graph( - input_values, # type: ignore - output_values, # type: ignore - nodes=nodes, - initializers=initializers, # type: ignore - doc_string=graph.doc_string, - opset_imports=graph.opset_imports, - name=graph.name, - metadata_props=graph.metadata_props, - ) -def instantiate ( +def instantiate( function: ir.Function, inputs: Sequence[ir.Value | None], attributes: Mapping[str, ir.Attr], *, - prefix: str = "" + prefix: str = "", ) -> tuple[list[ir.Node], list[ir.Value | None]]: + """Instantiate (inline) a function, substituting inputs and attributes. + + Args: + function: The function to instantiate. + inputs: Actual input values to bind to the function's formal parameters. + attributes: Attribute values to substitute for reference attributes. + prefix: Optional prefix to prepend to node and output names. + + Returns: + A tuple of (nodes, outputs) where nodes are the cloned function body + and outputs are the values corresponding to the function's outputs. + """ formal_inputs = function.inputs if len(inputs) > len(formal_inputs): - raise ValueError("") - value_map = { - formal: actual for (formal, actual) in zip(formal_inputs, inputs) + raise ValueError( + f"Too many inputs: got {len(inputs)}, " + f"but function has {len(formal_inputs)} parameters." + ) + value_map: dict[ir.Value, ir.Value | None] = { + formal: actual for formal, actual in zip(formal_inputs, inputs) } - def rename(node: ir.Node): - if prefix != "": - node.name = prefix + node.name + + def rename(node: ir.Node) -> None: + if prefix: + node.name = prefix + (node.name or "") for output in node.outputs: if output is not None: - output.name = prefix + output.name - cloner = _CopyReplace(attributes, value_map, {}, post_process=rename) + output.name = prefix + (output.name or "") + + cloner = Cloner( + attr_map=attributes, + value_map=value_map, + metadata_props={}, + post_process=rename, + resolve_ref_attrs=True, + ) nodes = [cloner.clone_node(n) for n in function] outputs = [value_map.get(v) for v in function.outputs] return nodes, outputs diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 0c2dbcbe9f..12a8556673 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -8,10 +8,8 @@ import onnx_ir as ir -import onnxscript import onnxscript._internal.builder as builder -from onnxscript import script -from onnxscript.onnx_opset import opset23 as op23 +from onnxscript import opset23 as op23, script _default_opset_version = 23 @@ -578,9 +576,6 @@ def mul_add_relu(X, Y): tmp = tmp + X return op23.Relu(tmp) - # Verify we got an OnnxFunction - self.assertIsInstance(mul_add_relu, onnxscript.values.OnnxFunction) - # Create a GraphBuilder and call the function op, x, y = _create_builder_with_inputs() result = op.call(mul_add_relu, x, y) From 2786955baf6997596ce9f8b755c1b9b0106a24bf Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 20 Feb 2026 10:13:26 -0800 Subject: [PATCH 04/14] Add naming options to function call Signed-off-by: Ganesan Ramalingam --- onnxscript/_internal/builder.py | 29 ++++++- onnxscript/_internal/builder_test.py | 122 +++++++++++++++++++++++++++ 2 files changed, 148 insertions(+), 3 deletions(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 9735893556..531190c366 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -327,7 +327,7 @@ def call_op( return node.outputs if len(node.outputs) > 1 else node.outputs[0] - def call(self, function, *args, **kwargs): + def call(self, function, *args, _outputs: Sequence[str] | None = None, _prefix: str = "", **kwargs): if isinstance(function, ir.Function): function_ir = function elif isinstance(function, onnxscript.values.OnnxFunction): @@ -335,9 +335,32 @@ def call(self, function, *args, **kwargs): function_ir = ir.serde.deserialize_function(function_proto) else: raise TypeError("Function must be an ir.Function or onnxscript.ONNXFunction") + output_renaming : dict[str, str] = {} + if _outputs is not None: + if len(_outputs) != len(function_ir.outputs): + raise ValueError( + f"Number of provided output names {_outputs} does not match " + f"number of function outputs {len(function_ir.outputs)}." + ) + for output, name in zip(function_ir.outputs, _outputs): + output_renaming[output.name] = self.qualify_name(name) + else: + for output in function_ir.outputs: + output_renaming[output.name] = self.qualify_name(output.name) nodes, outputs = inliner.instantiate(function_ir, args, kwargs) + if _prefix: + self.push_module(_prefix) for node in nodes: + node.name = self.qualify_name(node.name) + for output in node.outputs: + if output.name: + if output.name in output_renaming: + output.name = output_renaming[output.name] + else: + output.name = self.qualify_name(output.name) self.add_node(node) + if _prefix: + self.pop_module() return outputs if len(outputs) > 1 else outputs[0] def push_module(self, module: str) -> None: @@ -392,5 +415,5 @@ def __getattr__(self, op_type: str) -> Callable: def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value: return self._builder.initializer(tensor, name) - def call(self, function, *args, **kwargs): - return self._builder.call(function, *args, **kwargs) + def call(self, function, *args, _outputs: Sequence[str] | None = None, _prefix: str = "", **kwargs): + return self._builder.call(function, *args, _outputs=_outputs, _prefix=_prefix, **kwargs) diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 12a8556673..54b223d69f 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -602,6 +602,128 @@ def mul_add_relu(X, Y): self.assertIs(mul_node.inputs[0], x) self.assertIs(mul_node.inputs[1], y) + def test_call_with_outputs_option(self): + """Test that GraphBuilder.call respects the _outputs option for renaming.""" + + @script(default_opset=op23) + def add_mul(X, Y): + a = X + Y + b = X * Y + return a, b + + # Create a GraphBuilder and call the function with custom output names + op, x, y = _create_builder_with_inputs() + result = op.call(add_mul, x, y, _outputs=["sum_result", "product_result"]) + + # The result should be a list of 2 ir.Values (when function returns multiple outputs) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + sum_result, product_result = result + + # Verify output names are correctly set + self.assertEqual(sum_result.name, "sum_result") + self.assertEqual(product_result.name, "product_result") + + # Verify the nodes were created correctly + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 2) + self.assertEqual(nodes[0].op_type, "Add") + self.assertEqual(nodes[1].op_type, "Mul") + + def test_call_with_prefix_option(self): + """Test that GraphBuilder.call respects the _prefix option for hierarchical naming.""" + + @script(default_opset=op23) + def mul_add_relu(X, Y): + tmp = X * Y + tmp = tmp + X + return op23.Relu(tmp) + + # Create a GraphBuilder and call the function with a prefix + op, x, y = _create_builder_with_inputs() + result = op.call(mul_add_relu, x, y, _prefix="layer1") + + # The nodes should have the prefix in their names + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 3) + + # Check that all node names start with the prefix + for node in nodes: + self.assertTrue(node.name.startswith("layer1."), f"Node name {node.name} should start with layer1.") + + # Verify the result is a single ir.Value + self.assertIsInstance(result, ir.Value) + + def test_call_with_outputs_and_prefix_options(self): + """Test that GraphBuilder.call respects both _outputs and _prefix options together. + + Note: _outputs names are set before the prefix context is applied, so they don't get + the prefix in their names. However, the inlined nodes do get the prefix applied, and + intermediate values (not renamed by _outputs) do get the prefix applied. + """ + + @script(default_opset=op23) + def add_mul(X, Y): + # Intermediate values that are not explicitly renamed by _outputs + XSquare = X * X + YSquare = Y * Y + # Final outputs that will be renamed by _outputs + a = XSquare + Y + b = XSquare * YSquare + return a, b + + # Create a GraphBuilder and call the function with both options + op, x, y = _create_builder_with_inputs() + result = op.call( + add_mul, x, y, + _outputs=["custom_sum", "custom_product"], + _prefix="math_ops" + ) + + # The result should be a list of 2 ir.Values + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + sum_result, product_result = result + + # Verify output names are set (without prefix, as _outputs renaming happens before prefix context) + self.assertEqual(sum_result.name, "custom_sum") + self.assertEqual(product_result.name, "custom_product") + + # Verify all nodes have the prefix applied to their names + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 4) # Mul (XSquare), Mul (YSquare), Add, Mul (final) + + # All node names should start with prefix + for node in nodes: + self.assertTrue(node.name.startswith("math_ops."), f"Node name {node.name} should start with math_ops.") + + # Verify intermediate value names also get the prefix + # The first Mul produces XSquare + x_square = nodes[0].outputs[0] + self.assertTrue(x_square.name.startswith("math_ops."), f"Intermediate value {x_square.name} should have prefix") + + # The second Mul produces YSquare + y_square = nodes[1].outputs[0] + self.assertTrue(y_square.name.startswith("math_ops."), f"Intermediate value {y_square.name} should have prefix") + + def test_call_outputs_mismatch_error(self): + """Test that GraphBuilder.call raises an error if _outputs has wrong count.""" + + @script(default_opset=op23) + def add_mul(X, Y): + a = X + Y + b = X * Y + return a, b + + # Create a GraphBuilder and try to call with wrong number of output names + op, x, y = _create_builder_with_inputs() + + # The function returns 2 outputs, but we provide only 1 name + with self.assertRaises(ValueError) as cm: + result = op.call(add_mul, x, y, _outputs=["only_one_name"]) + + self.assertIn("does not match", str(cm.exception)) + if __name__ == "__main__": unittest.main() From 52368016a896b6078a09ac18788d4defeeb6c94a Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 20 Feb 2026 14:04:58 -0800 Subject: [PATCH 05/14] Support OpBuilder in converter Signed-off-by: Ganesan Ramalingam --- onnxscript/_internal/builder.py | 13 ++++- onnxscript/_internal/builder_test.py | 71 +++++++++++++++------------- onnxscript/_internal/converter.py | 22 +++++++-- 3 files changed, 67 insertions(+), 39 deletions(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 531190c366..b82c92b450 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -183,12 +183,13 @@ def _adapt_outputs( if isinstance(outputs, int): if outputs < 0: raise ValueError(f"Number of outputs must be non-negative, got {outputs}") + count = self.graph.num_nodes() if outputs == 1: - name = f"{op_type}_output" if op_type else "output" + name = f"{op_type}_n{count}_output" if op_type else f"n{count}_output" return [ir.Value(name=self.qualify_name(name))] else: names = [ - f"{op_type}_output{i}" if op_type else f"output{i}" for i in range(outputs) + f"{op_type}_n{count}_output{i}" if op_type else f"n{count}_output{i}" for i in range(outputs) ] return [ir.Value(name=self.qualify_name(n)) for n in names] adapted_outputs = [] @@ -402,6 +403,14 @@ def __init__( def builder(self) -> GraphBuilder: return self._builder + @property + def domain(self) -> str: + return self._domain + + @property + def version(self) -> int | None: + return self._version + def _call_op(self, op_type: str, inputs: Sequence[Any], kwargs: dict[str, Any]): if "_domain" not in kwargs: kwargs["_domain"] = self._domain diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 54b223d69f..b0472992d4 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -193,12 +193,12 @@ def test_value_naming_with_ir_value_objects(self): self.assertIs(t3, out3) def test_default_output_naming_strategy(self): - """Test the default naming strategy for generated output values using op_type_output format.""" + """Test the default naming strategy for generated output values using op_type_nX_output format.""" def _ops_with_default_names( op: builder.OpBuilder, x: ir.Value, y: ir.Value ) -> ir.Value: - # Single output operations should be named {op_type}_output + # Single output operations should be named {op_type}_nX_output where X is node count t1 = op.Add(x, y) t2 = op.Mul(x, y) z = op.Add(t1, t2) @@ -216,14 +216,14 @@ def _ops_with_default_names( nodes = list(graph) self.assertEqual(len(nodes), 3) - # Check output names follow the {op_type}_output pattern for single outputs - self.assertEqual(nodes[0].outputs[0].name, "Add_output") - self.assertEqual(nodes[1].outputs[0].name, "Mul_output") - self.assertEqual(nodes[2].outputs[0].name, "Add_output") + # Check output names follow the {op_type}_nX_output pattern for single outputs + self.assertEqual(nodes[0].outputs[0].name, "Add_n0_output") + self.assertEqual(nodes[1].outputs[0].name, "Mul_n1_output") + self.assertEqual(nodes[2].outputs[0].name, "Add_n2_output") # Verify the final output has the correct name self.assertEqual(len(graph.outputs), 1) - self.assertEqual(graph.outputs[0].name, "Add_output") + self.assertEqual(graph.outputs[0].name, "Add_n2_output") def test_hierarchical_naming(self): """Test the hierarchical naming strategy (for value and node names).""" @@ -231,35 +231,35 @@ def test_hierarchical_naming(self): # Test node and value naming at root level t1 = op.Add(x, y) - self.assertEqual(t1.name, "Add_output") + self.assertEqual(t1.name, "Add_n0_output") self.assertEqual(t1.producer().name, "Add_node_0") t2 = op.Mul(t1, y) - self.assertEqual(t2.name, "Mul_output") + self.assertEqual(t2.name, "Mul_n1_output") self.assertEqual(t2.producer().name, "Mul_node_1") # Test node and value naming with hierarchical context prefix op.builder.push_module("layer1") t3 = op.Add(t2, x) - self.assertEqual(t3.name, "layer1.Add_output") + self.assertEqual(t3.name, "layer1.Add_n2_output") self.assertEqual(t3.producer().name, "layer1.Add_node_2") # Test nested hierarchical context op.builder.push_module("attention") t4 = op.Mul(t3, y) - self.assertEqual(t4.name, "layer1.attention.Mul_output") + self.assertEqual(t4.name, "layer1.attention.Mul_n3_output") self.assertEqual(t4.producer().name, "layer1.attention.Mul_node_3") # Pop back to layer1 and verify naming continues correctly op.builder.pop_module() t5 = op.Add(t4, x) - self.assertEqual(t5.name, "layer1.Add_output") + self.assertEqual(t5.name, "layer1.Add_n4_output") self.assertEqual(t5.producer().name, "layer1.Add_node_4") # Pop back to root context op.builder.pop_module() t6 = op.Mul(t5, y) - self.assertEqual(t6.name, "Mul_output") + self.assertEqual(t6.name, "Mul_n5_output") self.assertEqual(t6.producer().name, "Mul_node_5") def test_shape_inference_add(self): @@ -276,6 +276,9 @@ def test_shape_inference_add(self): # Verify output shape is inferred correctly self.assertIsNotNone(result.shape) self.assertEqual(list(result.shape), [2, 3, 4]) + + # Verify the default name uses the node count + self.assertEqual(result.name, "Add_n0_output") def test_custom_domain_explicit(self): """Test using operations from custom domains with explicit _domain parameter.""" @@ -313,7 +316,7 @@ def test_custom_domain_with_version(self): # Verify output value is created self.assertIsNotNone(result) - self.assertEqual(result.name, "MicrosoftOp_output") + self.assertEqual(result.name, "MicrosoftOp_n0_output") def test_multiple_custom_domain_operations(self): """Test mixing operations from multiple domains.""" @@ -569,15 +572,16 @@ def test_attributes_are_created_properly(self): def test_call_inlines_onnxscript_function(self): """Test that GraphBuilder.call inlines an @onnxscript.script function.""" + # Create a GraphBuilder first + op, x, y = _create_builder_with_inputs() - @script(default_opset=op23) + # Define the script function after creating op, using op as default_opset + @script(default_opset=op) def mul_add_relu(X, Y): tmp = X * Y tmp = tmp + X - return op23.Relu(tmp) + return op.Relu(tmp) - # Create a GraphBuilder and call the function - op, x, y = _create_builder_with_inputs() result = op.call(mul_add_relu, x, y) # The inlined function should produce 3 nodes: Mul, Add, Relu @@ -604,15 +608,16 @@ def mul_add_relu(X, Y): def test_call_with_outputs_option(self): """Test that GraphBuilder.call respects the _outputs option for renaming.""" + # Create a GraphBuilder first + op, x, y = _create_builder_with_inputs() - @script(default_opset=op23) + # Define the script function after creating op, using op as default_opset + @script(default_opset=op) def add_mul(X, Y): a = X + Y b = X * Y return a, b - # Create a GraphBuilder and call the function with custom output names - op, x, y = _create_builder_with_inputs() result = op.call(add_mul, x, y, _outputs=["sum_result", "product_result"]) # The result should be a list of 2 ir.Values (when function returns multiple outputs) @@ -632,15 +637,16 @@ def add_mul(X, Y): def test_call_with_prefix_option(self): """Test that GraphBuilder.call respects the _prefix option for hierarchical naming.""" + # Create a GraphBuilder first + op, x, y = _create_builder_with_inputs() - @script(default_opset=op23) + # Define the script function after creating op, using op as default_opset + @script(default_opset=op) def mul_add_relu(X, Y): tmp = X * Y tmp = tmp + X - return op23.Relu(tmp) + return op.Relu(tmp) - # Create a GraphBuilder and call the function with a prefix - op, x, y = _create_builder_with_inputs() result = op.call(mul_add_relu, x, y, _prefix="layer1") # The nodes should have the prefix in their names @@ -661,8 +667,11 @@ def test_call_with_outputs_and_prefix_options(self): the prefix in their names. However, the inlined nodes do get the prefix applied, and intermediate values (not renamed by _outputs) do get the prefix applied. """ + # Create a GraphBuilder first + op, x, y = _create_builder_with_inputs() - @script(default_opset=op23) + # Define the script function after creating op, using op as default_opset + @script(default_opset=op) def add_mul(X, Y): # Intermediate values that are not explicitly renamed by _outputs XSquare = X * X @@ -672,8 +681,6 @@ def add_mul(X, Y): b = XSquare * YSquare return a, b - # Create a GraphBuilder and call the function with both options - op, x, y = _create_builder_with_inputs() result = op.call( add_mul, x, y, _outputs=["custom_sum", "custom_product"], @@ -708,16 +715,16 @@ def add_mul(X, Y): def test_call_outputs_mismatch_error(self): """Test that GraphBuilder.call raises an error if _outputs has wrong count.""" + # Create a GraphBuilder first + op, x, y = _create_builder_with_inputs() - @script(default_opset=op23) + # Define the script function after creating op, using op as default_opset + @script(default_opset=op) def add_mul(X, Y): a = X + Y b = X * Y return a, b - # Create a GraphBuilder and try to call with wrong number of output names - op, x, y = _create_builder_with_inputs() - # The function returns 2 outputs, but we provide only 1 name with self.assertRaises(ValueError) as cm: result = op.call(add_mul, x, y, _outputs=["only_one_name"]) diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index 6ebe5bda4a..aeb7e2df87 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -20,6 +20,7 @@ analysis, ast_utils, autocast, + builder, irbuilder, param_manipulation, sourceinfo, @@ -171,14 +172,18 @@ def __init__( opset: values.Opset | None = None, global_names: dict[str, Any] | None = None, source: str | None = None, - default_opset: values.Opset | None = None, + default_opset: Union[values.Opset, builder.OpBuilder, None] = None, ): self.source = source if global_names is not None: # We make a copy in case function eval modifies it. self.globals = global_names.copy() self.this_module = opset - self.default_opset_ = default_opset + # Convert OpBuilder to Opset if necessary and store the converted value + if isinstance(default_opset, builder.OpBuilder): + self.default_opset_ = values.Opset(default_opset.domain, default_opset.version) + else: + self.default_opset_ = default_opset # States initialized by `_init_function_translation` self._outer: list[irbuilder.IRFunction] = [] @@ -231,8 +236,13 @@ def _find_onnx_opset(self, node: ast.AST) -> values.Opset | None: if isinstance(opset_expr, ast.Name): if opset_expr.id in self.globals: opset = self.globals[opset_expr.id] - if isinstance(opset, values.Opset) and opset.domain == "": - return opset + # Accept both values.Opset and builder.OpBuilder + if isinstance(opset, values.Opset): + if opset.domain == "": + return opset + elif isinstance(opset, builder.OpBuilder): + if opset.domain == "": + return opset for child in ast.iter_child_nodes(node): res = self._find_onnx_opset(child) if res is not None: @@ -954,12 +964,14 @@ def _translate_opset_expr(self, node: ast.Attribute) -> values.Opset: val = self._lookup(node.id, self._source_of(node), raise_exception=False) if isinstance(val, values.Opset): return val + elif isinstance(val, builder.OpBuilder): + # Convert OpBuilder to Opset for compatibility + return values.Opset(val.domain, val.version) self.fail(node, f"'{node.id}' is not an instance of type Opset but {type(val)}.") elif isinstance(node, ast.Attribute): self.fail(node, "Nested module unimplemented.") # TODO else: self.fail(node, "Invalid opset expression.") - # pylint: enable=inconsistent-return-statements def _translate_callee_expr(self, node: ast.AST) -> values.Op: # pylint: disable=R1710 """Return an Op""" From f66a197faf61dce3d6ad681705e62cec7cf0dd55 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 20 Feb 2026 14:21:36 -0800 Subject: [PATCH 06/14] Update documentation Signed-off-by: Ganesan Ramalingam --- docs/tutorial/builder/graph_builder.md | 107 +++++++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/docs/tutorial/builder/graph_builder.md b/docs/tutorial/builder/graph_builder.md index 81b1e7df87..cfbeba1357 100644 --- a/docs/tutorial/builder/graph_builder.md +++ b/docs/tutorial/builder/graph_builder.md @@ -406,3 +406,110 @@ def build_linear(op, x, weight, bias_value): This pattern keeps function signatures simple while preserving access to the full builder API when needed. +## Calling Script Functions from OpBuilder + +The `OpBuilder` provides a `call()` method to inline `@script`-decorated ONNX functions directly into the builder's graph. This enables composition of both imperative (builder) and declarative (`@script`) code within a single graph. + +### Basic function inlining + +Define an ONNX script function and then call it through `op.call()`: + +```python +from onnxscript import script, opset23 as op23 + +# Define a reusable script function +@script(default_opset=op23) +def mul_add_relu(X, Y): + tmp = X * Y + tmp = tmp + X + return op23.Relu(tmp) + +# Now build a graph using OpBuilder +graph = ir.Graph( + name="my_graph", + inputs=[], + outputs=[], + nodes=[], + opset_imports={"": 23}, +) +x = ir.Value(name="x", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape([3, 4])) +y = ir.Value(name="y", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape([3, 4])) +graph.inputs.extend([x, y]) + +builder = onnxscript.GraphBuilder(graph) +op = builder.op + +# Call the script function — it gets inlined into the graph +result = op.call(mul_add_relu, x, y) +graph.outputs.append(result) +``` + +The function body (three nodes: Mul, Add, Relu) is inlined directly into the graph. + +### Renaming outputs with `_outputs` + +By default, inlined function outputs keep their original names, qualified by the +current naming context. You can rename them explicitly with `_outputs`: + +```python +@script(default_opset=op23) +def add_mul(X, Y): + a = X + Y + b = X * Y + return a, b + +# Inline with custom output names +result_sum, result_prod = op.call( + add_mul, x, y, + _outputs=["custom_sum", "custom_product"] +) +``` + +### Adding hierarchical context with `_prefix` + +Use `_prefix` to add a naming context to all nodes and intermediate values created +by the inlined function: + +```python +result = op.call( + mul_add_relu, x, y, + _prefix="layer1" +) +# Node names will be "layer1.Mul_n...", "layer1.Add_n...", "layer1.Relu_n..." +# Intermediate value names will also start with "layer1." +``` + +You can combine both options: + +```python +result_a, result_b = op.call( + add_mul, x, y, + _outputs=["sum_out", "prod_out"], + _prefix="math_ops" +) +# Final outputs: "sum_out", "prod_out" (renamed before prefix context) +# Intermediate values: "math_ops.Add_n...", "math_ops.Mul_n..." (with prefix) +``` + +### Using OpBuilder as the default_opset + +`OpBuilder` can be passed directly as the `default_opset` when decorating a script +function. This enables scripted functions to use the same opset version as the +builder they will be inlined into: + +```python +builder = onnxscript.GraphBuilder(graph) +op = builder.op + +# Define the function *after* creating the builder, using op as default_opset +@script(default_opset=op) +def my_func(X, Y): + t = X + Y + return op.Relu(t) # Uses the op directly + +# Inline it +result = op.call(my_func, x, y) +``` + +This pattern ensures consistency: the script function operates in the same domain +and opset version as the builder. \ No newline at end of file From f05335761513fc6001b321c35883326879ff5eea Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Fri, 20 Feb 2026 16:58:12 -0800 Subject: [PATCH 07/14] Potential fix for code scanning alert no. 22056: Unused import Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> --- onnxscript/_internal/builder_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index b0472992d4..eea474bc6e 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -9,7 +9,7 @@ import onnx_ir as ir import onnxscript._internal.builder as builder -from onnxscript import opset23 as op23, script +from onnxscript import script _default_opset_version = 23 From 8d8b2ea1c321228c59233510350f4a10db3a7570 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Fri, 20 Feb 2026 16:58:40 -0800 Subject: [PATCH 08/14] Potential fix for code scanning alert no. 22055: Unused local variable Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> --- onnxscript/_internal/builder_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index eea474bc6e..3ef5993d42 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -727,7 +727,7 @@ def add_mul(X, Y): # The function returns 2 outputs, but we provide only 1 name with self.assertRaises(ValueError) as cm: - result = op.call(add_mul, x, y, _outputs=["only_one_name"]) + op.call(add_mul, x, y, _outputs=["only_one_name"]) self.assertIn("does not match", str(cm.exception)) From c6692fac1cfa51f3c88f03034fc179bdd9dd3a16 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Fri, 20 Feb 2026 16:59:10 -0800 Subject: [PATCH 09/14] Update onnxscript/_internal/builder.py Co-authored-by: Justin Chu --- onnxscript/_internal/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index b82c92b450..dbf45459c6 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -9,7 +9,7 @@ import onnx_ir as ir import onnxscript._internal._inference as inference -import onnxscript._internal._inliner as inliner +from onnxscript._internal import _inliner import onnxscript.optimizer # A permissible value for an op input, which can be converted to an ir.Value. From 241ba428b0697c26145570393fe2362b1537201e Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Fri, 20 Feb 2026 16:59:51 -0800 Subject: [PATCH 10/14] Update onnxscript/_internal/_inliner.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/_internal/_inliner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/_internal/_inliner.py b/onnxscript/_internal/_inliner.py index bef0fc85c4..f7e69c6fc0 100644 --- a/onnxscript/_internal/_inliner.py +++ b/onnxscript/_internal/_inliner.py @@ -1,4 +1,4 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from __future__ import annotations From d6d92f18c315cb5606af3874d92f0a0c582005c5 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Fri, 20 Feb 2026 17:00:14 -0800 Subject: [PATCH 11/14] Update onnxscript/_internal/builder.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/_internal/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index dbf45459c6..c39199637b 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -336,7 +336,7 @@ def call(self, function, *args, _outputs: Sequence[str] | None = None, _prefix: function_ir = ir.serde.deserialize_function(function_proto) else: raise TypeError("Function must be an ir.Function or onnxscript.ONNXFunction") - output_renaming : dict[str, str] = {} + output_renaming: dict[str, str] = {} if _outputs is not None: if len(_outputs) != len(function_ir.outputs): raise ValueError( From d2958f35bb2a723c22ca5b7727888dde63b7319c Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 20 Feb 2026 17:15:35 -0800 Subject: [PATCH 12/14] Address review feedback Signed-off-by: Ganesan Ramalingam --- docs/tutorial/builder/graph_builder.md | 2 +- onnxscript/_internal/builder.py | 6 +++--- onnxscript/_internal/builder_test.py | 16 ++++++++-------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/tutorial/builder/graph_builder.md b/docs/tutorial/builder/graph_builder.md index cfbeba1357..55bd83a90e 100644 --- a/docs/tutorial/builder/graph_builder.md +++ b/docs/tutorial/builder/graph_builder.md @@ -512,4 +512,4 @@ result = op.call(my_func, x, y) ``` This pattern ensures consistency: the script function operates in the same domain -and opset version as the builder. \ No newline at end of file +and opset version as the builder. diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index c39199637b..a51c2722d0 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -331,11 +331,11 @@ def call_op( def call(self, function, *args, _outputs: Sequence[str] | None = None, _prefix: str = "", **kwargs): if isinstance(function, ir.Function): function_ir = function - elif isinstance(function, onnxscript.values.OnnxFunction): + elif isinstance(function, onnxscript.OnnxFunction): function_proto = function.to_function_proto() function_ir = ir.serde.deserialize_function(function_proto) else: - raise TypeError("Function must be an ir.Function or onnxscript.ONNXFunction") + raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction") output_renaming: dict[str, str] = {} if _outputs is not None: if len(_outputs) != len(function_ir.outputs): @@ -363,7 +363,7 @@ def call(self, function, *args, _outputs: Sequence[str] | None = None, _prefix: if _prefix: self.pop_module() return outputs if len(outputs) > 1 else outputs[0] - + def push_module(self, module: str) -> None: """Push a new naming context onto the stack (e.g. a layer or module name).""" current = self.context_name() diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 3ef5993d42..ef33336157 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -276,7 +276,7 @@ def test_shape_inference_add(self): # Verify output shape is inferred correctly self.assertIsNotNone(result.shape) self.assertEqual(list(result.shape), [2, 3, 4]) - + # Verify the default name uses the node count self.assertEqual(result.name, "Add_n0_output") @@ -652,17 +652,17 @@ def mul_add_relu(X, Y): # The nodes should have the prefix in their names nodes = list(op.builder.graph) self.assertEqual(len(nodes), 3) - + # Check that all node names start with the prefix for node in nodes: self.assertTrue(node.name.startswith("layer1."), f"Node name {node.name} should start with layer1.") - + # Verify the result is a single ir.Value self.assertIsInstance(result, ir.Value) def test_call_with_outputs_and_prefix_options(self): """Test that GraphBuilder.call respects both _outputs and _prefix options together. - + Note: _outputs names are set before the prefix context is applied, so they don't get the prefix in their names. However, the inlined nodes do get the prefix applied, and intermediate values (not renamed by _outputs) do get the prefix applied. @@ -682,7 +682,7 @@ def add_mul(X, Y): return a, b result = op.call( - add_mul, x, y, + add_mul, x, y, _outputs=["custom_sum", "custom_product"], _prefix="math_ops" ) @@ -699,7 +699,7 @@ def add_mul(X, Y): # Verify all nodes have the prefix applied to their names nodes = list(op.builder.graph) self.assertEqual(len(nodes), 4) # Mul (XSquare), Mul (YSquare), Add, Mul (final) - + # All node names should start with prefix for node in nodes: self.assertTrue(node.name.startswith("math_ops."), f"Node name {node.name} should start with math_ops.") @@ -708,7 +708,7 @@ def add_mul(X, Y): # The first Mul produces XSquare x_square = nodes[0].outputs[0] self.assertTrue(x_square.name.startswith("math_ops."), f"Intermediate value {x_square.name} should have prefix") - + # The second Mul produces YSquare y_square = nodes[1].outputs[0] self.assertTrue(y_square.name.startswith("math_ops."), f"Intermediate value {y_square.name} should have prefix") @@ -728,7 +728,7 @@ def add_mul(X, Y): # The function returns 2 outputs, but we provide only 1 name with self.assertRaises(ValueError) as cm: op.call(add_mul, x, y, _outputs=["only_one_name"]) - + self.assertIn("does not match", str(cm.exception)) From a35b8dddc69ded6f3682450848329b327234f425 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 20 Feb 2026 17:18:23 -0800 Subject: [PATCH 13/14] Address PR feedback Signed-off-by: Ganesan Ramalingam --- onnxscript/_internal/converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index aeb7e2df87..921b1e0fe5 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -242,7 +242,7 @@ def _find_onnx_opset(self, node: ast.AST) -> values.Opset | None: return opset elif isinstance(opset, builder.OpBuilder): if opset.domain == "": - return opset + return values.Opset(opset.domain, opset.version) for child in ast.iter_child_nodes(node): res = self._find_onnx_opset(child) if res is not None: From 85418f9afb0b8a1eaba0a192a8899c79906a59d3 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 20 Feb 2026 17:19:41 -0800 Subject: [PATCH 14/14] Fix import Signed-off-by: Ganesan Ramalingam --- onnxscript/_internal/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index a51c2722d0..ca185a1028 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -9,7 +9,7 @@ import onnx_ir as ir import onnxscript._internal._inference as inference -from onnxscript._internal import _inliner +from onnxscript._internal import _inliner as inliner import onnxscript.optimizer # A permissible value for an op input, which can be converted to an ir.Value.