Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions docs/tutorial/builder/graph_builder.md
Original file line number Diff line number Diff line change
Expand Up @@ -406,3 +406,110 @@ def build_linear(op, x, weight, bias_value):

This pattern keeps function signatures simple while preserving access to the
full builder API when needed.
## Calling Script Functions from OpBuilder

The `OpBuilder` provides a `call()` method to inline `@script`-decorated ONNX functions directly into the builder's graph. This enables composition of both imperative (builder) and declarative (`@script`) code within a single graph.

### Basic function inlining

Define an ONNX script function and then call it through `op.call()`:

```python
import onnxscript
import onnx_ir as ir
from onnxscript import script, opset23 as op23

# Define a reusable script function
@script(default_opset=op23)
def mul_add_relu(X, Y):
tmp = X * Y
tmp = tmp + X
return op23.Relu(tmp)

# Now build a graph using OpBuilder
graph = ir.Graph(
name="my_graph",
inputs=[],
outputs=[],
nodes=[],
opset_imports={"": 23},
)
x = ir.Value(name="x", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape([3, 4]))
y = ir.Value(name="y", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape([3, 4]))
graph.inputs.extend([x, y])

builder = onnxscript.GraphBuilder(graph)
op = builder.op

# Call the script function — it gets inlined into the graph
result = op.call(mul_add_relu, x, y)
graph.outputs.append(result)
```

The function body (three nodes: Mul, Add, Relu) is inlined directly into the graph.

### Renaming outputs with `_outputs`

By default, inlined function outputs keep their original names, qualified by the
current naming context. You can rename them explicitly with `_outputs`:

```python
@script(default_opset=op23)
def add_mul(X, Y):
a = X + Y
b = X * Y
return a, b

# Inline with custom output names
result_sum, result_prod = op.call(
add_mul, x, y,
_outputs=["custom_sum", "custom_product"]
)
```

### Adding hierarchical context with `_prefix`

Use `_prefix` to add a naming context to all nodes and intermediate values created
by the inlined function:

```python
result = op.call(
mul_add_relu, x, y,
_prefix="layer1"
)
# Node names will be "layer1.Mul_n...", "layer1.Add_n...", "layer1.Relu_n..."
# Intermediate value names will also start with "layer1."
```

You can combine both options:

```python
result_a, result_b = op.call(
add_mul, x, y,
_outputs=["sum_out", "prod_out"],
_prefix="math_ops"
)
# Final outputs: "sum_out", "prod_out" (renamed before prefix context)
# Intermediate values: "math_ops.Add_n...", "math_ops.Mul_n..." (with prefix)
```

### Using OpBuilder as the default_opset

`OpBuilder` can be passed directly as the `default_opset` when decorating a script
function. This enables scripted functions to use the same opset version as the
builder they will be inlined into:

```python
builder = onnxscript.GraphBuilder(graph)
op = builder.op

# Define the function *after* creating the builder, using op as default_opset
@script(default_opset=op)
def my_func(X, Y):
t = X + Y
return op.Relu(t) # Uses the op directly

# Inline it
result = op.call(my_func, x, y)
```

This pattern ensures consistency: the script function operates in the same domain
and opset version as the builder.
58 changes: 58 additions & 0 deletions onnxscript/_internal/_inliner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) Microsoft Corporation.

Check warning

Code scanning / lintrunner

RUFF/format Warning

Run lintrunner -a to apply this patch.

Check warning

Code scanning / lintrunner

RUFF-FORMAT/format Warning

Run lintrunner -a to apply this patch.
# Licensed under the MIT License.

from __future__ import annotations

from typing import Mapping, Sequence

import onnx_ir as ir
from onnx_ir._cloner import Cloner


def instantiate(
    function: ir.Function,
    inputs: Sequence[ir.Value | None],
    attributes: Mapping[str, ir.Attr],
    *,
    prefix: str = "",
) -> tuple[list[ir.Node], list[ir.Value | None]]:
    """Instantiate (inline) a function, substituting inputs and attributes.

    Args:
        function: The function to instantiate.
        inputs: Actual input values to bind to the function's formal parameters.
            May be shorter than the formal parameter list (trailing formals are
            then left unbound), but never longer.
        attributes: Attribute values to substitute for reference attributes.
        prefix: Optional prefix to prepend to node and output names.

    Returns:
        A tuple of (nodes, outputs) where nodes are the cloned function body
        and outputs are the values corresponding to the function's outputs.

    Raises:
        ValueError: If more actual inputs are supplied than the function has
            formal parameters.
    """
    formal_inputs = function.inputs
    if len(inputs) > len(formal_inputs):
        raise ValueError(
            f"Too many inputs: got {len(inputs)}, "
            f"but function has {len(formal_inputs)} parameters."
        )
    # zip truncates at the shorter sequence, so unbound trailing formals are
    # simply absent from the map; value_map.get() below yields None for them.
    value_map: dict[ir.Value, ir.Value | None] = dict(zip(formal_inputs, inputs))

    def rename(node: ir.Node) -> None:
        # Post-processing hook applied to each cloned node: namespace node and
        # output names under the caller-supplied prefix. Guarded so an empty
        # prefix leaves names (including None output names) untouched.
        if prefix:
            node.name = prefix + (node.name or "")
            for output in node.outputs:
                if output is not None:
                    output.name = prefix + (output.name or "")

    cloner = Cloner(
        attr_map=attributes,
        value_map=value_map,
        metadata_props={},
        post_process=rename,
        resolve_ref_attrs=True,
    )
    # Iterating an ir.Function yields its body nodes in order; cloning fills
    # value_map with the clones of each node's outputs.
    nodes = [cloner.clone_node(n) for n in function]
    outputs = [value_map.get(v) for v in function.outputs]
    return nodes, outputs

Check warning

Code scanning / lintrunner

RUFF/W391 Warning

53 changes: 51 additions & 2 deletions onnxscript/_internal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import onnx_ir as ir

import onnxscript._internal._inference as inference
from onnxscript._internal import _inliner as inliner
import onnxscript.optimizer

# A permissible value for an op input, which can be converted to an ir.Value.
Expand Down Expand Up @@ -182,12 +183,13 @@ def _adapt_outputs(
if isinstance(outputs, int):
if outputs < 0:
raise ValueError(f"Number of outputs must be non-negative, got {outputs}")
count = self.graph.num_nodes()
if outputs == 1:
name = f"{op_type}_output" if op_type else "output"
name = f"{op_type}_n{count}_output" if op_type else f"n{count}_output"
return [ir.Value(name=self.qualify_name(name))]
else:
names = [
f"{op_type}_output{i}" if op_type else f"output{i}" for i in range(outputs)
f"{op_type}_n{count}_output{i}" if op_type else f"n{count}_output{i}" for i in range(outputs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There may be some conflicts: I updated the naming patterns in #2819

]
return [ir.Value(name=self.qualify_name(n)) for n in names]
adapted_outputs = []
Expand Down Expand Up @@ -326,6 +328,42 @@ def call_op(

return node.outputs if len(node.outputs) > 1 else node.outputs[0]

def call(self, function, *args, _outputs: Sequence[str] | None = None, _prefix: str = "", **kwargs):
    """Inline a function's body into this builder's graph.

    Args:
        function: An ``ir.Function`` or ``onnxscript.OnnxFunction`` whose body
            is cloned and appended to the graph.
        *args: Actual input values bound to the function's formal inputs.
        _outputs: Optional replacement names for the function's outputs; must
            match the number of function outputs.
        _prefix: Optional naming context pushed while the body is inlined, so
            node and intermediate value names are namespaced under it.
        **kwargs: Attribute values substituted for the function's reference
            attributes.

    Returns:
        The single inlined output value, or the list of values when the
        function has more than one output.

    Raises:
        TypeError: If ``function`` is neither an ``ir.Function`` nor an
            ``onnxscript.OnnxFunction``.
        ValueError: If ``_outputs`` does not match the output count.
    """
    if isinstance(function, ir.Function):
        function_ir = function
    elif isinstance(function, onnxscript.OnnxFunction):
        # Round-trip through FunctionProto to obtain an IR-level function.
        function_proto = function.to_function_proto()
        function_ir = ir.serde.deserialize_function(function_proto)
    else:
        raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction")

    # Map original function-output names to their final graph names. Built
    # *before* any _prefix context is pushed, so explicitly renamed outputs
    # are not prefixed (intermediate values, renamed below, are).
    output_renaming: dict[str, str] = {}
    if _outputs is not None:
        if len(_outputs) != len(function_ir.outputs):
            # Report the count (not the list) so the message reads correctly.
            raise ValueError(
                f"Number of provided output names {len(_outputs)} does not match "
                f"number of function outputs {len(function_ir.outputs)}."
            )
        for output, name in zip(function_ir.outputs, _outputs):
            output_renaming[output.name] = self.qualify_name(name)
    else:
        for output in function_ir.outputs:
            output_renaming[output.name] = self.qualify_name(output.name)

    nodes, outputs = inliner.instantiate(function_ir, args, kwargs)
    if _prefix:
        self.push_module(_prefix)
    try:
        for node in nodes:
            node.name = self.qualify_name(node.name)
            for output in node.outputs:
                if output.name:
                    if output.name in output_renaming:
                        output.name = output_renaming[output.name]
                    else:
                        output.name = self.qualify_name(output.name)
            self.add_node(node)
    finally:
        # Balance push_module even if qualify_name/add_node raises, so the
        # builder's naming context is not corrupted for subsequent calls.
        if _prefix:
            self.pop_module()
    # NOTE(review): assumes the function has at least one output (valid ONNX
    # functions do); mirrors the call_op return convention.
    return outputs if len(outputs) > 1 else outputs[0]

def push_module(self, module: str) -> None:
"""Push a new naming context onto the stack (e.g. a layer or module name)."""
current = self.context_name()
Expand Down Expand Up @@ -365,6 +403,14 @@ def __init__(
def builder(self) -> GraphBuilder:
return self._builder

@property
def domain(self) -> str:
    """Operator domain used for ops created through this builder facade."""
    return self._domain

@property
def version(self) -> int | None:
return self._version

def _call_op(self, op_type: str, inputs: Sequence[Any], kwargs: dict[str, Any]):
if "_domain" not in kwargs:
kwargs["_domain"] = self._domain
Expand All @@ -377,3 +423,6 @@ def __getattr__(self, op_type: str) -> Callable:

def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
    """Register *tensor* as a graph initializer via the underlying builder."""
    builder = self._builder
    return builder.initializer(tensor, name)

def call(self, function, *args, _outputs: Sequence[str] | None = None, _prefix: str = "", **kwargs):
    """Inline a script/IR function into the underlying graph builder.

    Delegates to the wrapped builder's ``call``, forwarding positional
    inputs, attribute kwargs, and the output-renaming/prefix controls
    unchanged.
    """
    builder = self._builder
    return builder.call(
        function,
        *args,
        _outputs=_outputs,
        _prefix=_prefix,
        **kwargs,
    )
Loading
Loading