Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
3033248
Add onnxscript.nn package with Module and Parameter classes
justinchuby Feb 20, 2026
11c880c
Add children, named_children, state_dict, load_state_dict to Module
justinchuby Feb 20, 2026
b6c084f
Improve nn module test coverage to 100%
justinchuby Feb 20, 2026
48b6a14
Make Parameter.realize() private (_realize())
justinchuby Feb 20, 2026
caa95cc
Update onnxscript/nn/_parameter.py
justinchuby Feb 20, 2026
13c8e6d
Update onnxscript/nn/_parameter.py
justinchuby Feb 20, 2026
bb7fba1
fix: ensure unique output names in GraphBuilder._adapt_outputs
justinchuby Feb 21, 2026
391ab4c
test: update builder tests for v_ output name prefix
justinchuby Feb 21, 2026
7f0e6bf
fix: use / separator for node and value scope names
justinchuby Feb 21, 2026
e00bb6d
fix: conditionally push and pop module in __call__ based on module name
justinchuby Feb 21, 2026
21a5c13
fix: use / throughout node scope prefix, not just at the join point
justinchuby Feb 21, 2026
75b0751
feat: redesign module scope system with node metadata
justinchuby Feb 21, 2026
3135535
refactor: rename qualify_name to _qualify_initializer_name
justinchuby Feb 21, 2026
67ffb42
refactor: use '.' delimiter for value names, '/' for node names
justinchuby Feb 21, 2026
feb1ff7
fix: namespace uses 'name: class' pairs, align scope list lengths
justinchuby Feb 21, 2026
1a1a7a6
Update names
justinchuby Feb 21, 2026
9cd8f20
fix: append op_type to name_scopes so lengths match class_hierarchy
justinchuby Feb 21, 2026
77e0d68
fix: remove op_type from namespace, class_hierarchy, and name_scopes
justinchuby Feb 21, 2026
9613200
Format
justinchuby Feb 21, 2026
5667be7
refactor: use ast.literal_eval instead of eval in tests
justinchuby Feb 21, 2026
6bb9931
Update name check
justinchuby Feb 21, 2026
7741555
Implement ModuleList
justinchuby Feb 21, 2026
248742e
Update Parameter to take dtype from data
justinchuby Feb 22, 2026
d5e9710
pylint
justinchuby Feb 22, 2026
eae41ac
fix check error
justinchuby Feb 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion onnxscript/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"script",
"graph",
"ir",
"nn",
"optimizer",
"rewriter",
"version_converter",
Expand Down Expand Up @@ -127,7 +128,7 @@

# isort: on

from . import ir, optimizer, rewriter, version_converter
from . import ir, nn, optimizer, rewriter, version_converter
from ._internal.builder import GraphBuilder
from ._internal.utils import external_tensor
from ._internal.values import OnnxFunction, TracedOnnxFunction
Expand Down
9 changes: 9 additions & 0 deletions onnxscript/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""PyTorch-like module interface for building ONNX graphs."""

from onnxscript.nn._module import Module
from onnxscript.nn._parameter import Parameter

__all__ = ["Module", "Parameter"]
206 changes: 206 additions & 0 deletions onnxscript/nn/_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from typing import Any, Iterator

import onnx_ir as ir

from onnxscript._internal.builder import GraphBuilder, OpBuilder
from onnxscript.nn._parameter import Parameter


class Module:
"""Base class for all onnxscript modules, mirroring PyTorch's nn.Module.

Subclasses define ``forward()`` to build ONNX subgraphs. Child modules
and parameters are registered automatically via ``__setattr__``.
Because ``Parameter`` subclasses ``ir.Value``, parameters like
``self.weight`` can be passed directly to ONNX ops.

Example::

class Linear(onnxscript.nn.Module):
def __init__(self, in_features, out_features, bias=True, name=None):
super().__init__(name)
self.weight = Parameter([out_features, in_features], name="weight")
if bias:
self.bias = Parameter([out_features], name="bias")
else:
self.bias = None

def forward(self, op, x):
w_t = op.Transpose(self.weight, perm=[1, 0])
result = op.MatMul(x, w_t)
if self.bias is not None:
result = op.Add(result, self.bias)
return result
"""

def __init__(self, name: str | None = None) -> None:
# Use object.__setattr__ to avoid triggering our __setattr__ override
# before _parameters and _modules dicts exist.
object.__setattr__(self, "_name", name)
object.__setattr__(self, "_parameters", {})
object.__setattr__(self, "_modules", {})

@property
def name(self) -> str | None:
return self._name

def __setattr__(self, name: str, value: Any) -> None:
if isinstance(value, Parameter):
# Auto-register parameters; set default name from attribute name.
if value.name is None:
value.name = name
self._parameters[name] = value
# Also store on the instance so getattr works outside forward()
object.__setattr__(self, name, value)
elif isinstance(value, Module):
# Auto-register child modules; inherit attribute name if unnamed.
if value._name is None:
object.__setattr__(value, "_name", name)
self._modules[name] = value
object.__setattr__(self, name, value)
else:
object.__setattr__(self, name, value)

def __call__(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any:
builder: GraphBuilder = op.builder
module_name = self._name or ""
builder.push_module(module_name)
Comment thread
justinchuby marked this conversation as resolved.
Outdated
try:
# Realize parameters: qualify names and register as graph initializers.
for param in self._parameters.values():
param._realize(builder) # pylint: disable=protected-access

result = self.forward(op, *args, **kwargs)
finally:
builder.pop_module()
return result

def forward(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any:
"""Define the computation performed by this module.

Must be overridden by subclasses. Receives an ``OpBuilder`` as the
first argument so that ONNX ops can be called as ``op.MatMul(x, w)``.
"""
raise NotImplementedError(f"{type(self).__name__} must implement forward()")

# ------------------------------------------------------------------
# Iterators
# ------------------------------------------------------------------

def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
"""Return an iterator over module parameters."""
yield from self._parameters.values()
if recurse:
for module in self._modules.values():
yield from module.parameters(recurse=True)

def named_parameters(
self, prefix: str = "", recurse: bool = True
) -> Iterator[tuple[str, Parameter]]:
"""Return an iterator over module parameters, yielding (name, Parameter) pairs."""
for name, param in self._parameters.items():
full_name = f"{prefix}.{name}" if prefix else name
yield full_name, param
if recurse:
for mod_name, module in self._modules.items():
sub_prefix = f"{prefix}.{mod_name}" if prefix else mod_name
yield from module.named_parameters(prefix=sub_prefix, recurse=True)

def children(self) -> Iterator[Module]:
"""Return an iterator over immediate child modules."""
yield from self._modules.values()

def named_children(self) -> Iterator[tuple[str, Module]]:
"""Return an iterator over immediate child modules, yielding (name, Module) pairs."""
yield from self._modules.items()

def modules(self) -> Iterator[Module]:
"""Return an iterator over all modules in the tree (including self)."""
yield self
for module in self._modules.values():
yield from module.modules()

def named_modules(self, prefix: str = "") -> Iterator[tuple[str, Module]]:
"""Return an iterator over all modules, yielding (name, Module) pairs."""
yield prefix, self
for name, module in self._modules.items():
sub_prefix = f"{prefix}.{name}" if prefix else name
yield from module.named_modules(prefix=sub_prefix)

# ------------------------------------------------------------------
# State dict
# ------------------------------------------------------------------

def state_dict(self, prefix: str = "") -> dict[str, ir.TensorProtocol | None]:
"""Return a dictionary mapping parameter names to their tensor data.

Mirrors ``torch.nn.Module.state_dict()``. Keys use dot-separated
hierarchical names (e.g. ``"layer1.weight"``). Values are the
``const_value`` of each parameter (``None`` if uninitialized).
"""
result: dict[str, ir.TensorProtocol | None] = {}
for name, param in self._parameters.items():
full_name = f"{prefix}.{name}" if prefix else name
result[full_name] = param.const_value
Comment thread
justinchuby marked this conversation as resolved.
for mod_name, module in self._modules.items():
sub_prefix = f"{prefix}.{mod_name}" if prefix else mod_name
result.update(module.state_dict(prefix=sub_prefix))
return result

def load_state_dict(
self,
state_dict: dict[str, ir.TensorProtocol],
strict: bool = True,
) -> None:
"""Load parameter data from a state dictionary.

Mirrors ``torch.nn.Module.load_state_dict()``. Sets ``const_value``
on each matching parameter.

Args:
state_dict: Mapping of parameter names to tensor data.
strict: If ``True`` (default), raises ``KeyError`` for missing
keys and ``ValueError`` for unexpected keys.
"""
self._load_state_dict_recursive(state_dict, prefix="", strict=strict)

def _load_state_dict_recursive(
self,
state_dict: dict[str, ir.TensorProtocol],
prefix: str,
strict: bool,
) -> set[str]:
"""Recursively load state and return the set of consumed keys."""
consumed: set[str] = set()
for name, param in self._parameters.items():
full_name = f"{prefix}.{name}" if prefix else name
if full_name in state_dict:
param.const_value = state_dict[full_name]
consumed.add(full_name)
elif strict:
raise KeyError(f"Missing key in state_dict: {full_name!r}")
for mod_name, module in self._modules.items():
sub_prefix = f"{prefix}.{mod_name}" if prefix else mod_name
consumed |= module._load_state_dict_recursive( # pylint: disable=protected-access
state_dict, prefix=sub_prefix, strict=strict
)
if strict and prefix == "":
unexpected = set(state_dict.keys()) - consumed
if unexpected:
raise ValueError(f"Unexpected keys in state_dict: {unexpected}")
return consumed

def __repr__(self) -> str:
lines = [f"{type(self).__name__}("]
for name, module in self._modules.items():
mod_repr = repr(module).replace("\n", "\n ")
lines.append(f" ({name}): {mod_repr}")
for name, param in self._parameters.items():
lines.append(f" ({name}): {param!r}")
lines.append(")")
return "\n".join(lines)
Loading
Loading