-
Notifications
You must be signed in to change notification settings - Fork 102
Allow GraphBuilder to call script functions #2820
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
50ca21b
255e747
ffddfd7
2786955
5236801
f66a197
f053357
8d8b2ea
c6692fa
241ba42
d6d92f1
d2958f3
a35b8dd
85418f9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
Check warningCode scanning / lintrunner RUFF-FORMAT/format Warning
Run lintrunner -a to apply this patch.
|
||
| # Licensed under the MIT License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import Mapping, Sequence | ||
|
|
||
| import onnx_ir as ir | ||
| from onnx_ir._cloner import Cloner | ||
|
|
||
|
|
||
| def instantiate( | ||
| function: ir.Function, | ||
| inputs: Sequence[ir.Value | None], | ||
| attributes: Mapping[str, ir.Attr], | ||
| *, | ||
| prefix: str = "", | ||
| ) -> tuple[list[ir.Node], list[ir.Value | None]]: | ||
| """Instantiate (inline) a function, substituting inputs and attributes. | ||
|
|
||
| Args: | ||
| function: The function to instantiate. | ||
| inputs: Actual input values to bind to the function's formal parameters. | ||
| attributes: Attribute values to substitute for reference attributes. | ||
| prefix: Optional prefix to prepend to node and output names. | ||
|
|
||
| Returns: | ||
| A tuple of (nodes, outputs) where nodes are the cloned function body | ||
| and outputs are the values corresponding to the function's outputs. | ||
| """ | ||
| formal_inputs = function.inputs | ||
| if len(inputs) > len(formal_inputs): | ||
| raise ValueError( | ||
| f"Too many inputs: got {len(inputs)}, " | ||
| f"but function has {len(formal_inputs)} parameters." | ||
| ) | ||
| value_map: dict[ir.Value, ir.Value | None] = { | ||
Check noticeCode scanning / lintrunner PYLINT/R1721 Note
Unnecessary use of a comprehension, use dict(zip(formal_inputs, inputs)) instead. (unnecessary-comprehension)
See unnecessary-comprehension. To disable, use # pylint: disable=unnecessary-comprehension Check noticeCode scanning / lintrunner RUFF/C416 Note
Unnecessary dict comprehension (rewrite using dict()).
See https://docs.astral.sh/ruff/rules/unnecessary-comprehension |
||
| formal: actual for formal, actual in zip(formal_inputs, inputs) | ||
| } | ||
|
|
||
| def rename(node: ir.Node) -> None: | ||
| if prefix: | ||
| node.name = prefix + (node.name or "") | ||
| for output in node.outputs: | ||
| if output is not None: | ||
| output.name = prefix + (output.name or "") | ||
|
|
||
| cloner = Cloner( | ||
| attr_map=attributes, | ||
| value_map=value_map, | ||
| metadata_props={}, | ||
| post_process=rename, | ||
| resolve_ref_attrs=True, | ||
| ) | ||
| nodes = [cloner.clone_node(n) for n in function] | ||
| outputs = [value_map.get(v) for v in function.outputs] | ||
| return nodes, outputs | ||
|
|
||
Check warningCode scanning / lintrunner RUFF/W391 Warning
Extra newline at end of file.
See https://docs.astral.sh/ruff/rules/too-many-newlines-at-end-of-file |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |
| import onnx_ir as ir | ||
|
|
||
| import onnxscript._internal._inference as inference | ||
| from onnxscript._internal import _inliner as inliner | ||
| import onnxscript.optimizer | ||
|
|
||
| # A permissible value for an op input, which can be converted to an ir.Value. | ||
|
|
@@ -182,12 +183,13 @@ def _adapt_outputs( | |
| if isinstance(outputs, int): | ||
| if outputs < 0: | ||
| raise ValueError(f"Number of outputs must be non-negative, got {outputs}") | ||
| count = self.graph.num_nodes() | ||
| if outputs == 1: | ||
| name = f"{op_type}_output" if op_type else "output" | ||
| name = f"{op_type}_n{count}_output" if op_type else f"n{count}_output" | ||
| return [ir.Value(name=self.qualify_name(name))] | ||
| else: | ||
| names = [ | ||
| f"{op_type}_output{i}" if op_type else f"output{i}" for i in range(outputs) | ||
| f"{op_type}_n{count}_output{i}" if op_type else f"n{count}_output{i}" for i in range(outputs) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There may be some conflicts: I updated the naming patterns in #2819 |
||
| ] | ||
| return [ir.Value(name=self.qualify_name(n)) for n in names] | ||
| adapted_outputs = [] | ||
|
|
@@ -326,6 +328,42 @@ def call_op( | |
|
|
||
| return node.outputs if len(node.outputs) > 1 else node.outputs[0] | ||
|
|
||
| def call(self, function, *args, _outputs: Sequence[str] | None = None, _prefix: str = "", **kwargs): | ||
| if isinstance(function, ir.Function): | ||
| function_ir = function | ||
| elif isinstance(function, onnxscript.OnnxFunction): | ||
| function_proto = function.to_function_proto() | ||
| function_ir = ir.serde.deserialize_function(function_proto) | ||
| else: | ||
| raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction") | ||
| output_renaming: dict[str, str] = {} | ||
| if _outputs is not None: | ||
| if len(_outputs) != len(function_ir.outputs): | ||
| raise ValueError( | ||
| f"Number of provided output names {_outputs} does not match " | ||
| f"number of function outputs {len(function_ir.outputs)}." | ||
| ) | ||
| for output, name in zip(function_ir.outputs, _outputs): | ||
| output_renaming[output.name] = self.qualify_name(name) | ||
| else: | ||
| for output in function_ir.outputs: | ||
| output_renaming[output.name] = self.qualify_name(output.name) | ||
| nodes, outputs = inliner.instantiate(function_ir, args, kwargs) | ||
| if _prefix: | ||
| self.push_module(_prefix) | ||
| for node in nodes: | ||
| node.name = self.qualify_name(node.name) | ||
| for output in node.outputs: | ||
| if output.name: | ||
| if output.name in output_renaming: | ||
| output.name = output_renaming[output.name] | ||
| else: | ||
| output.name = self.qualify_name(output.name) | ||
| self.add_node(node) | ||
| if _prefix: | ||
| self.pop_module() | ||
| return outputs if len(outputs) > 1 else outputs[0] | ||
|
|
||
| def push_module(self, module: str) -> None: | ||
| """Push a new naming context onto the stack (e.g. a layer or module name).""" | ||
| current = self.context_name() | ||
|
|
@@ -365,6 +403,14 @@ def __init__( | |
| def builder(self) -> GraphBuilder: | ||
| return self._builder | ||
|
|
||
| @property | ||
| def domain(self) -> str: | ||
| return self._domain | ||
|
|
||
| @property | ||
| def version(self) -> int | None: | ||
| return self._version | ||
|
|
||
| def _call_op(self, op_type: str, inputs: Sequence[Any], kwargs: dict[str, Any]): | ||
| if "_domain" not in kwargs: | ||
| kwargs["_domain"] = self._domain | ||
|
|
@@ -377,3 +423,6 @@ def __getattr__(self, op_type: str) -> Callable: | |
|
|
||
| def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value: | ||
| return self._builder.initializer(tensor, name) | ||
|
|
||
| def call(self, function, *args, _outputs: Sequence[str] | None = None, _prefix: str = "", **kwargs): | ||
| return self._builder.call(function, *args, _outputs=_outputs, _prefix=_prefix, **kwargs) | ||
Check warning
Code scanning / lintrunner
RUFF/format Warning