Add onnxscript.nn module with Module and Parameter classes #2819

justinchuby wants to merge 25 commits into main
Conversation
Introduce a PyTorch-like nn.Module interface for building ONNX graphs:

- Parameter(ir.Value): subclasses ir.Value so parameters can be passed directly to ONNX ops. realize() qualifies names and registers them as graph initializers.
- Module: base class with automatic child module/parameter registration via __setattr__, hierarchical naming via push_module/pop_module, and forward() for subclasses to override.
- Iterators: parameters(), named_parameters(), modules(), named_modules()
- Exported via onnxscript.nn

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
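The registration-and-iteration mechanics described in the commit message can be sketched as follows. This is a minimal illustrative stand-in, not the actual onnxscript.nn implementation; the class names mirror the ones above, but Parameter here is a plain toy class rather than an ir.Value subclass.

```python
class Parameter:
    """Toy stand-in for the real Parameter(ir.Value) subclass."""

    def __init__(self, name: str):
        self.name = name


class Module:
    def __init__(self):
        # Use object.__setattr__ so creating the registries does not
        # recurse into our own __setattr__ below.
        object.__setattr__(self, "_parameters", {})
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        # Route Parameter and Module assignments into registries so the
        # iterators can discover them automatically.
        if isinstance(value, Parameter):
            self._parameters[name] = value
        elif isinstance(value, Module):
            self._modules[name] = value
        object.__setattr__(self, name, value)

    def named_parameters(self, prefix=""):
        # Yield (qualified_name, parameter) pairs, recursing into children
        # with a dot-joined hierarchical prefix.
        for name, param in self._parameters.items():
            yield (f"{prefix}.{name}" if prefix else name), param
        for name, child in self._modules.items():
            child_prefix = f"{prefix}.{name}" if prefix else name
            yield from child.named_parameters(child_prefix)


class Linear(Module):
    def __init__(self):
        super().__init__()
        self.weight = Parameter("weight")
        self.bias = Parameter("bias")


class Model(Module):
    def __init__(self):
        super().__init__()
        self.layer1 = Linear()


print([n for n, _ in Model().named_parameters()])
# → ['layer1.weight', 'layer1.bias']
```

The key design point is that `__setattr__` does the bookkeeping, so subclasses only write ordinary attribute assignments.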
- children() / named_children(): iterators over immediate child modules
- state_dict(): returns a dict mapping parameter names to tensor data
- load_state_dict(): loads tensor data into parameters, with a strict mode for missing/unexpected keys

Also refactored Parameter to subclass ir.Value directly, eliminating the swap/restore mechanism in Module.__call__.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
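The state_dict round-trip described above can be sketched with toy stand-ins (these are illustrative classes, not onnxscript's real Module/Parameter, and `const_value` here is any Python object rather than an ir tensor):

```python
class Parameter:
    def __init__(self):
        self.const_value = None  # tensor data; may be unset until loaded


class Module:
    def __init__(self):
        self._parameters = {}
        self._modules = {}

    def state_dict(self, prefix=""):
        # Map fully-qualified parameter names to their tensor data.
        result = {}
        for name, param in self._parameters.items():
            result[f"{prefix}.{name}" if prefix else name] = param.const_value
        for name, child in self._modules.items():
            child_prefix = f"{prefix}.{name}" if prefix else name
            result.update(child.state_dict(child_prefix))
        return result

    def load_state_dict(self, state, strict=True):
        # Strict mode rejects both missing and unexpected keys.
        own = self.state_dict()
        missing = set(own) - set(state)
        unexpected = set(state) - set(own)
        if strict and (missing or unexpected):
            raise KeyError(
                f"missing={sorted(missing)} unexpected={sorted(unexpected)}"
            )
        for name, param in self._parameters.items():
            if name in state:
                param.const_value = state[name]
        for name, child in self._modules.items():
            # Strip this child's prefix and recurse with the sub-dict.
            sub = {
                k[len(name) + 1:]: v
                for k, v in state.items()
                if k.startswith(name + ".")
            }
            child.load_state_dict(sub, strict=False)


root = Module()
child = Module()
child._parameters["weight"] = Parameter()
root._modules["layer1"] = child

root.load_state_dict({"layer1.weight": [1.0, 2.0]})
print(root.state_dict())  # → {'layer1.weight': [1.0, 2.0]}
```

Note that `state_dict()` happily emits `None` for unrealized parameters, matching the review discussion below about `const_value` being filled in later.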
Add tests for:

- Module.name property
- Plain attribute assignment via __setattr__
- NotImplementedError from the base Module.forward()
- Recursive parameters() with child modules

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This is an internal method called by Module.__call__, not part of the public API. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #2819      +/-   ##
==========================================
+ Coverage   70.76%   71.44%    +0.68%
==========================================
  Files         231      236        +5
  Lines       27667    28496      +829
  Branches     2775     2819       +44
==========================================
+ Hits        19579    20360      +781
- Misses       7127     7171       +44
- Partials      961      965        +4

☔ View full report in Codecov by Sentry.
Pull request overview
This pull request introduces a PyTorch-like neural network module interface (onnxscript.nn) for building ONNX graphs in a structured, object-oriented manner. The implementation provides Module and Parameter classes that enable users to define reusable neural network components with automatic parameter registration and hierarchical naming, mirroring PyTorch's nn.Module API.
Changes:
- Added a new onnxscript.nn package with Module and Parameter classes for building ONNX graphs with a PyTorch-like interface
- Implemented automatic parameter and submodule registration via __setattr__, hierarchical parameter naming, and state dict serialization/deserialization
- Exposed the new package in the main onnxscript API for easy access
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| onnxscript/nn/__init__.py | New package initialization exposing Module and Parameter as the public API |
| onnxscript/nn/_parameter.py | Implementation of the Parameter class, which subclasses ir.Value and supports realization as a graph initializer |
| onnxscript/nn/_module.py | Implementation of the Module base class with automatic registration, iterators, and state dict functionality |
| onnxscript/nn/_module_test.py | Comprehensive test suite covering parameter/module registration, forward passes, iterators, and state dict operations |
| onnxscript/__init__.py | Updated to import and expose the new nn package in __all__ |
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
onnxscript/nn/_module.py
Outdated
```python
def __call__(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any:
    builder: GraphBuilder = op.builder
    module_name = self._name or ""
    builder.push_module(module_name)
```
nit: maybe use `if self._name: builder.push_module(self._name)`, and the same for pop?
Just to be explicit here (independent of the implementation of push/pop) about the treatment when no name is specified. (Alternatively, we can make name a required parameter in the constructor ... not sure if people will want modules without names/hierarchical names)
Done, added the conditional
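The conditional push/pop agreed on above can be sketched with a toy scope stack (the names `GraphBuilder`, `push_module`, and `qualify_name` mirror the diff, but this is an illustrative stand-in, not onnxscript's actual builder):

```python
class GraphBuilder:
    def __init__(self):
        self._scope_stack = []

    def push_module(self, name):
        self._scope_stack.append(name)

    def pop_module(self):
        self._scope_stack.pop()

    def qualify_name(self, name):
        # Dot-join the active module scopes with the leaf name.
        return ".".join(self._scope_stack + [name])


builder = GraphBuilder()

# A module only pushes a scope when it actually has a name, so unnamed
# modules are transparent to the hierarchy:
for module_name in ("layer1", "", "attn"):
    if module_name:
        builder.push_module(module_name)

print(builder.qualify_name("weight"))  # → layer1.attn.weight
```

Guarding the push means an unnamed module never contributes an empty segment like `layer1..weight` to qualified names.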
```python
result: dict[str, ir.TensorProtocol | None] = {}
for name, param in self._parameters.items():
    full_name = f"{prefix}.{name}" if prefix else name
    result[full_name] = param.const_value
```
This could be None, right? ... I guess I am just looking for how it will be used, or rather, how load_state_dict will be used in practice with real trained weights coming from somewhere else ... I suppose that will come later.
Do you mean param.const_value can be None? I think that's ok. (not sure what None means yet but seems ok)
onnxscript/nn/_parameter.py
Outdated
```python
)
if self.name:
    self.name = builder.qualify_name(self.name)
    builder.graph.initializers[self.name] = self  # type: ignore[index]
```
Does the IR already support initializers without a value? Can it be serialized?
Yes. It will not be serialized if empty. The idea is the const_values will be filled in later, after the model is built.
Append the node count to auto-generated output names (e.g. Add_output_0 instead of Add_output) to prevent name collisions when the same op type is called multiple times within a module. This matches the existing node-naming strategy, which already uses the count suffix. Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
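Count-suffixed naming as described in this commit can be sketched with a per-op-type counter (an illustrative stand-in; the real builder keeps its own node counts):

```python
from collections import Counter


class NameGenerator:
    def __init__(self):
        self._counts = Counter()

    def output_name(self, op_type):
        # Read the current count for this op type, then bump it, so
        # repeated calls yield distinct, deterministic names.
        n = self._counts[op_type]
        self._counts[op_type] += 1
        return f"{op_type}_output_{n}"


gen = NameGenerator()
print(gen.output_name("Add"))  # → Add_output_0
print(gen.output_name("Add"))  # → Add_output_1
print(gen.output_name("Mul"))  # → Mul_output_0
```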
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
ONNX convention uses / for node scope hierarchies. Add _qualify_node_name() that uses / as separator, keeping qualify_name() with . for parameter/initializer names. Node names and auto-generated value names now use / (e.g. layer1/Add_node_0, layer1/v_Add_0) while parameters keep . (e.g. layer1.weight). Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
_qualify_node_name now replaces all dots in the context prefix with / so nested scopes produce names like layer1/attention/Add_node_0 instead of layer1.attention/Add_node_0. Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
- Change _context_stack from cumulative strings to a list of (name, class_name) tuples in _scope_stack
- push_module(name, class_name) stores individual scope entries; names like 'layers.3' are kept as single scopes (no dot splitting)
- qualify_name uses '.' join for parameters (layers.3.self_attn.weight)
- _qualify_value_name uses '/' join with a v_ prefix (layers.3/self_attn/v_Add_0)
- _qualify_node_name uses '/' join (layers.3/self_attn/Add_node_0)
- Module.__call__ passes type(self).__qualname__ as class_name
- call_op attaches metadata_props to every node:
  - namespace: scope path + op type (e.g. 'layer1/self_attn: Add')
  - pkg.onnxscript.class_hierarchy: list of class names + op type
  - pkg.onnxscript.name_scopes: list of scope names

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Value names use '.' separator with 'v_' prefix (e.g. v_layer1.attention.Add_0). Node names and namespace strings use '/' separator (e.g. layer1/attention/Add_node_0). Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
- namespace: each scope entry is 'name: class_name' joined by '/', e.g. 'layer1: DecoderLayer/self_attn: Attention/Add'
- scope_names() and scope_classes() return all entries (no filtering) so class_hierarchy and name_scopes always have matching lengths
- _scope_name_parts() filters empty names for initializer/value/node qualifying

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
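The separator conventions from these commits can be sketched directly (illustrative free functions over a hardcoded scope stack, not onnxscript's actual methods): parameters join scope names with '.', node names join with '/', and the namespace string joins 'name: ClassName' entries with '/'. Note that 'layers.3' stays a single scope entry, with no dot splitting.

```python
# Scope stack of (name, class_name) tuples, as described in the commit.
scope_stack = [("layers.3", "DecoderLayer"), ("self_attn", "Attention")]


def qualify_param(name):
    # Parameter/initializer names use '.' as the separator.
    return ".".join([s for s, _ in scope_stack] + [name])


def qualify_node(name):
    # Node names use '/' per the ONNX scope-hierarchy convention.
    return "/".join([s for s, _ in scope_stack] + [name])


def namespace():
    # Each entry renders as 'name: ClassName', joined by '/'.
    return "/".join(f"{s}: {c}" for s, c in scope_stack)


print(qualify_param("weight"))     # → layers.3.self_attn.weight
print(qualify_node("Add_node_0"))  # → layers.3/self_attn/Add_node_0
print(namespace())                 # → layers.3: DecoderLayer/self_attn: Attention
```

Keeping 'layers.3' atomic is what makes the node name `layers.3/self_attn/Add_node_0` rather than `layers/3/self_attn/Add_node_0`.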
Force-pushed from f29767e to feb1ff7.
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
namespace now only contains module scopes (e.g. 'layer1: DecoderLayer/self_attn: Attention'). class_hierarchy and name_scopes reflect only the module stack, not the op. Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Updated how scoped names are constructed. @gramalingam
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
This pull request introduces a new PyTorch-like module interface for building ONNX graphs, enabling users to define reusable neural network components and manage parameters in a structured way. The main changes add the onnxscript.nn package, expose its API, and implement core classes for module and parameter management.

Addition of the ONNX neural network module interface:

- Added the onnxscript.nn package and exposed it in the main onnxscript API, allowing users to access neural network module functionality. (onnxscript/__init__.py)
- Introduced the Module and Parameter classes in onnxscript/nn/_module.py and onnxscript/nn/_parameter.py, providing a PyTorch-like interface for defining ONNX graph modules, registering parameters, and managing module hierarchies.
- Added onnxscript/nn/__init__.py to expose Module and Parameter as the public API of the new package.

Core module and parameter functionality:

- The Module class supports automatic registration of parameters and child modules, implements methods for iterating over parameters/modules, and provides state_dict/load_state_dict for parameter serialization/deserialization, mirroring PyTorch's API. (onnxscript/nn/_module.py)
- The Parameter class subclasses ir.Value, allowing direct use in ONNX ops and supporting initialization, realization, and representation of parameter tensors. (onnxscript/nn/_parameter.py)