From 3033248472768ce4277ce2d7e11b6af898d0fb6b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 08:32:25 -0800 Subject: [PATCH 01/25] Add onnxscript.nn package with Module and Parameter classes Introduce a PyTorch-like nn.Module interface for building ONNX graphs: - Parameter(ir.Value): Subclasses ir.Value so parameters can be passed directly to ONNX ops. realize() qualifies names and registers as graph initializers. - Module: Base class with automatic child module/parameter registration via __setattr__, hierarchical naming via push_module/pop_module, and forward() for subclasses to override. - Iterators: parameters(), named_parameters(), modules(), named_modules() - Exported via onnxscript.nn Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- onnxscript/__init__.py | 3 +- onnxscript/nn/__init__.py | 9 + onnxscript/nn/_module.py | 133 ++++++++++++ onnxscript/nn/_module_test.py | 382 ++++++++++++++++++++++++++++++++++ onnxscript/nn/_parameter.py | 66 ++++++ 5 files changed, 592 insertions(+), 1 deletion(-) create mode 100644 onnxscript/nn/__init__.py create mode 100644 onnxscript/nn/_module.py create mode 100644 onnxscript/nn/_module_test.py create mode 100644 onnxscript/nn/_parameter.py diff --git a/onnxscript/__init__.py b/onnxscript/__init__.py index 34ea9a0778..9783e37c24 100644 --- a/onnxscript/__init__.py +++ b/onnxscript/__init__.py @@ -5,6 +5,7 @@ "script", "graph", "ir", + "nn", "optimizer", "rewriter", "version_converter", @@ -127,7 +128,7 @@ # isort: on -from . import ir, optimizer, rewriter, version_converter +from . import ir, nn, optimizer, rewriter, version_converter from ._internal.builder import GraphBuilder from ._internal.utils import external_tensor from ._internal.values import OnnxFunction, TracedOnnxFunction diff --git a/onnxscript/nn/__init__.py b/onnxscript/nn/__init__.py new file mode 100644 index 0000000000..181409812a --- /dev/null +++ b/onnxscript/nn/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""PyTorch-like module interface for building ONNX graphs.""" + +from onnxscript.nn._module import Module +from onnxscript.nn._parameter import Parameter + +__all__ = ["Module", "Parameter"] diff --git a/onnxscript/nn/_module.py b/onnxscript/nn/_module.py new file mode 100644 index 0000000000..034760c9c5 --- /dev/null +++ b/onnxscript/nn/_module.py @@ -0,0 +1,133 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import Any, Iterator + +from onnxscript._internal.builder import GraphBuilder, OpBuilder +from onnxscript.nn._parameter import Parameter + + +class Module: + """Base class for all onnxscript modules, mirroring PyTorch's nn.Module. + + Subclasses define ``forward()`` to build ONNX subgraphs. Child modules + and parameters are registered automatically via ``__setattr__``. + Because ``Parameter`` subclasses ``ir.Value``, parameters like + ``self.weight`` can be passed directly to ONNX ops. + + Example:: + + class Linear(onnxscript.nn.Module): + def __init__(self, in_features, out_features, bias=True, name=None): + super().__init__(name) + self.weight = Parameter([out_features, in_features], name="weight") + if bias: + self.bias = Parameter([out_features], name="bias") + else: + self.bias = None + + def forward(self, op, x): + w_t = op.Transpose(self.weight, perm=[1, 0]) + result = op.MatMul(x, w_t) + if self.bias is not None: + result = op.Add(result, self.bias) + return result + """ + + def __init__(self, name: str | None = None) -> None: + # Use object.__setattr__ to avoid triggering our __setattr__ override + # before _parameters and _modules dicts exist. + object.__setattr__(self, "_name", name) + object.__setattr__(self, "_parameters", {}) + object.__setattr__(self, "_modules", {}) + + @property + def name(self) -> str | None: + return self._name + + def __setattr__(self, name: str, value: Any) -> None: + if isinstance(value, Parameter): + # Auto-register parameters; set default name from attribute name. + if value.name is None: + value.name = name + self._parameters[name] = value + # Also store on the instance so getattr works outside forward() + object.__setattr__(self, name, value) + elif isinstance(value, Module): + # Auto-register child modules; inherit attribute name if unnamed. + if value._name is None: + object.__setattr__(value, "_name", name) + self._modules[name] = value + object.__setattr__(self, name, value) + else: + object.__setattr__(self, name, value) + + def __call__(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: + builder: GraphBuilder = op.builder + module_name = self._name or "" + builder.push_module(module_name) + try: + # Realize parameters: qualify names and register as graph initializers. + for param in self._parameters.values(): + param.realize(builder) + + result = self.forward(op, *args, **kwargs) + finally: + builder.pop_module() + return result + + def forward(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: + """Define the computation performed by this module. + + Must be overridden by subclasses. Receives an ``OpBuilder`` as the + first argument so that ONNX ops can be called as ``op.MatMul(x, w)``. + """ + raise NotImplementedError(f"{type(self).__name__} must implement forward()") + + # ------------------------------------------------------------------ + # Iterators + # ------------------------------------------------------------------ + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + """Return an iterator over module parameters.""" + yield from self._parameters.values() + if recurse: + for module in self._modules.values(): + yield from module.parameters(recurse=True) + + def named_parameters( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[tuple[str, Parameter]]: + """Return an iterator over module parameters, yielding (name, Parameter) pairs.""" + for name, param in self._parameters.items(): + full_name = f"{prefix}.{name}" if prefix else name + yield full_name, param + if recurse: + for mod_name, module in self._modules.items(): + sub_prefix = f"{prefix}.{mod_name}" if prefix else mod_name + yield from module.named_parameters(prefix=sub_prefix, recurse=True) + + def modules(self) -> Iterator[Module]: + """Return an iterator over all modules in the tree (including self).""" + yield self + for module in self._modules.values(): + yield from module.modules() + + def named_modules(self, prefix: str = "") -> Iterator[tuple[str, Module]]: + """Return an iterator over all modules, yielding (name, Module) pairs.""" + yield prefix, self + for name, module in self._modules.items(): + sub_prefix = f"{prefix}.{name}" if prefix else name + yield from module.named_modules(prefix=sub_prefix) + + def __repr__(self) -> str: + lines = [f"{type(self).__name__}("] + for name, module in self._modules.items(): + mod_repr = repr(module).replace("\n", "\n ") + lines.append(f" ({name}): {mod_repr}") + for name, param in self._parameters.items(): + lines.append(f" ({name}): {param!r}") + lines.append(")") + return "\n".join(lines) diff --git a/onnxscript/nn/_module_test.py b/onnxscript/nn/_module_test.py new file mode 100644 index 0000000000..8f4e8d7083 --- /dev/null +++ b/onnxscript/nn/_module_test.py @@ -0,0 +1,382 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import unittest +from typing import Any + +import onnx_ir as ir + +from onnxscript._internal.builder import GraphBuilder, OpBuilder +from onnxscript.nn import Module, Parameter + + +def _create_graph_and_op() -> tuple[ir.Graph, OpBuilder]: + """Create an empty graph and its default OpBuilder.""" + graph = ir.Graph( + name="test_model", + inputs=[], + outputs=[], + nodes=[], + opset_imports={"": 23}, + ) + builder = GraphBuilder(graph) + return graph, builder.op + + +class ParameterTest(unittest.TestCase): + def test_parameter_repr(self): + p = Parameter([3, 4], dtype=ir.DataType.FLOAT, name="weight") + self.assertIn("weight", repr(p)) + self.assertIn("3, 4", repr(p)) + + def test_realize_creates_initializer(self): + graph, op = _create_graph_and_op() + p = Parameter([3, 4], dtype=ir.DataType.FLOAT, name="weight") + value = p.realize(op.builder) + + self.assertIs(value, p) # realize returns self + self.assertIsInstance(value, ir.Value) + self.assertEqual(value.name, "weight") + self.assertEqual(value.type.dtype, ir.DataType.FLOAT) + self.assertEqual(list(value.shape), [3, 4]) + # Should be registered as initializer + self.assertIn("weight", graph.initializers) + + def test_realize_is_idempotent(self): + _, op = _create_graph_and_op() + p = Parameter([3, 4], name="weight") + v1 = p.realize(op.builder) + v2 = p.realize(op.builder) + self.assertIs(v1, v2) + + def test_realize_with_data(self): + graph, op = _create_graph_and_op() + tensor = ir.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=ir.DataType.FLOAT) + p = Parameter([2, 2], name="weight", data=tensor) + value = p.realize(op.builder) + + self.assertIs(value.const_value, tensor) + self.assertIn("weight", graph.initializers) + # The initializer in the graph IS the parameter itself + self.assertIs(graph.initializers["weight"], p) + + def test_realize_qualifies_name(self): + graph, op = _create_graph_and_op() + op.builder.push_module("layer1") + p = Parameter([3], name="bias") + value = p.realize(op.builder) + op.builder.pop_module() + + self.assertEqual(value.name, "layer1.bias") + self.assertIn("layer1.bias", graph.initializers) + + +class ModuleBasicTest(unittest.TestCase): + def test_parameter_auto_registration(self): + """Parameters assigned as attributes are automatically registered.""" + + class MyModule(Module): + def __init__(self): + super().__init__("my_mod") + self.weight = Parameter([3, 4], name="weight") + self.bias = Parameter([3], name="bias") + + def forward(self, op, x): + pass + + m = MyModule() + param_names = list(m._parameters.keys()) # pylint: disable=protected-access + self.assertEqual(param_names, ["weight", "bias"]) + + def test_module_auto_registration(self): + """Child modules assigned as attributes are automatically registered.""" + + class Child(Module): + def __init__(self): + super().__init__("child") + + def forward(self, op): + pass + + class Parent(Module): + def __init__(self): + super().__init__("parent") + self.child = Child() + + def forward(self, op): + pass + + p = Parent() + self.assertIn("child", p._modules) # pylint: disable=protected-access + self.assertIs(p._modules["child"], p.child) # pylint: disable=protected-access + + def test_module_inherits_attribute_name(self): + """Child module with no explicit name inherits the attribute name.""" + + class Child(Module): + def __init__(self): + super().__init__() # name=None + + def forward(self, op): + pass + + class Parent(Module): + def __init__(self): + super().__init__("parent") + self.my_layer = Child() + + def forward(self, op): + pass + + p = Parent() + self.assertEqual(p.my_layer._name, "my_layer") # pylint: disable=protected-access + + def test_parameter_inherits_attribute_name(self): + """Parameter with no explicit name inherits the attribute name.""" + + class MyModule(Module): + def __init__(self): + super().__init__("mod") + self.weight = Parameter([3, 4]) # name=None + + def forward(self, op, x): + pass + + m = MyModule() + self.assertEqual(m.weight.name, "weight") + + +class ModuleForwardTest(unittest.TestCase): + def test_parameter_is_ir_value_in_forward(self): + """Parameters are ir.Value instances usable directly in forward().""" + captured: dict[str, Any] = {} + + class MyModule(Module): + def __init__(self): + super().__init__("mod") + self.weight = Parameter([3, 4], name="weight") + + def forward(self, op): + # self.weight IS a Parameter (which IS an ir.Value) + captured["is_ir_value"] = isinstance(self.weight, ir.Value) + captured["is_parameter"] = isinstance(self.weight, Parameter) + captured["weight_name"] = self.weight.name + + _, op = _create_graph_and_op() + m = MyModule() + m(op) + + self.assertTrue(captured["is_ir_value"]) + self.assertTrue(captured["is_parameter"]) + self.assertEqual(captured["weight_name"], "mod.weight") + + # After forward, self.weight is still the same Parameter + self.assertIsInstance(m.weight, Parameter) + + def test_parameter_naming_with_hierarchy(self): + """Parameters get hierarchically qualified names.""" + captured_names: list[str] = [] + + class Inner(Module): + def __init__(self): + super().__init__() + self.w = Parameter([2, 2], name="w") + + def forward(self, op): + captured_names.append(self.w.name) + + class Outer(Module): + def __init__(self): + super().__init__("outer") + self.layer = Inner() + + def forward(self, op): + self.layer(op) + + graph, op = _create_graph_and_op() + m = Outer() + m(op) + + self.assertEqual(captured_names, ["outer.layer.w"]) + self.assertIn("outer.layer.w", graph.initializers) + + def test_forward_with_ops(self): + """Module forward can use OpBuilder to create ONNX nodes.""" + + class AddBias(Module): + def __init__(self, size): + super().__init__("add_bias") + self.bias = Parameter([size], name="bias") + + def forward(self, op, x): + return op.Add(x, self.bias) + + graph, op = _create_graph_and_op() + x = ir.Value( + name="input", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape([3]), + ) + graph.inputs.append(x) + + m = AddBias(3) + result = m(op, x) + + self.assertIsInstance(result, ir.Value) + nodes = list(graph) + self.assertEqual(len(nodes), 1) + self.assertEqual(nodes[0].op_type, "Add") + + def test_composition_multiple_submodules(self): + """Multiple submodules compose correctly with independent parameters.""" + + class Linear(Module): + def __init__(self, in_f, out_f, name=None): + super().__init__(name) + self.weight = Parameter([out_f, in_f], name="weight") + + def forward(self, op, x): + w_t = op.Transpose(self.weight, perm=[1, 0]) + return op.MatMul(x, w_t) + + class TwoLinear(Module): + def __init__(self): + super().__init__("model") + self.fc1 = Linear(4, 3, name="fc1") + self.fc2 = Linear(3, 2, name="fc2") + + def forward(self, op, x): + h = self.fc1(op, x) + return self.fc2(op, h) + + graph, op = _create_graph_and_op() + x = ir.Value( + name="input", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape([1, 4]), + ) + graph.inputs.append(x) + + m = TwoLinear() + result = m(op, x) + + self.assertIsInstance(result, ir.Value) + # Check initializer names are hierarchical + self.assertIn("model.fc1.weight", graph.initializers) + self.assertIn("model.fc2.weight", graph.initializers) + # Check nodes: Transpose, MatMul, Transpose, MatMul + op_types = [node.op_type for node in graph] + self.assertEqual(op_types, ["Transpose", "MatMul", "Transpose", "MatMul"]) + + +class ModuleIteratorTest(unittest.TestCase): + def test_parameters_iterator(self): + class MyModule(Module): + def __init__(self): + super().__init__("mod") + self.w1 = Parameter([3], name="w1") + self.w2 = Parameter([4], name="w2") + + def forward(self, op): + pass + + m = MyModule() + params = list(m.parameters()) + self.assertEqual(len(params), 2) + + def test_named_parameters_recursive(self): + class Child(Module): + def __init__(self): + super().__init__("child") + self.w = Parameter([3], name="w") + + def forward(self, op): + pass + + class Parent(Module): + def __init__(self): + super().__init__("parent") + self.p = Parameter([2], name="p") + self.child = Child() + + def forward(self, op): + pass + + m = Parent() + named = dict(m.named_parameters()) + self.assertIn("p", named) + self.assertIn("child.w", named) + + def test_modules_iterator(self): + class A(Module): + def __init__(self): + super().__init__("a") + + def forward(self, op): + pass + + class B(Module): + def __init__(self): + super().__init__("b") + self.a = A() + + def forward(self, op): + pass + + m = B() + mods = list(m.modules()) + self.assertEqual(len(mods), 2) + self.assertIs(mods[0], m) + self.assertIs(mods[1], m.a) + + def test_named_modules_iterator(self): + class A(Module): + def __init__(self): + super().__init__("a") + + def forward(self, op): + pass + + class B(Module): + def __init__(self): + super().__init__("b") + self.a = A() + + def forward(self, op): + pass + + m = B() + named = dict(m.named_modules()) + self.assertIn("", named) # self + self.assertIn("a", named) + + +class ModuleReprTest(unittest.TestCase): + def test_repr(self): + class Child(Module): + def __init__(self): + super().__init__("child") + self.w = Parameter([3], name="w") + + def forward(self, op): + pass + + class Parent(Module): + def __init__(self): + super().__init__("parent") + self.child = Child() + + def forward(self, op): + pass + + m = Parent() + r = repr(m) + self.assertIn("Parent", r) + self.assertIn("child", r) + self.assertIn("Child", r) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/nn/_parameter.py b/onnxscript/nn/_parameter.py new file mode 100644 index 0000000000..4591de6a77 --- /dev/null +++ b/onnxscript/nn/_parameter.py @@ -0,0 +1,66 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import Sequence + +import onnx_ir as ir + +from onnxscript._internal.builder import GraphBuilder + + +class Parameter(ir.Value): + """A module parameter that is also an ``ir.Value``. + + Since ``Parameter`` subclasses ``ir.Value``, it can be passed directly + to ONNX ops inside ``Module.forward()`` without any conversion. + Calling :meth:`realize` qualifies the name with the current module + context and registers the parameter as a graph initializer. + + Args: + shape: Shape of the parameter tensor. + dtype: Data type of the parameter. Defaults to FLOAT. + name: Name for the parameter. If None, the attribute name from + the parent Module is used. + data: Optional initial tensor data. If provided, the initializer + will carry this as its const_value. + """ + + def __init__( + self, + shape: Sequence[int | ir.SymbolicDim | None], + dtype: ir.DataType = ir.DataType.FLOAT, + name: str | None = None, + data: ir.TensorProtocol | None = None, + ) -> None: + super().__init__( + name=name, + shape=ir.Shape(shape), + type=ir.TensorType(dtype), + const_value=data, + ) + self._realized = False + + @property + def dtype(self) -> ir.DataType | None: # type: ignore[override] + """Return the element data type of this parameter.""" + return self.type.dtype if self.type is not None else None + + def realize(self, builder: GraphBuilder) -> Parameter: + """Qualify the name and register as a graph initializer. + + Uses direct assignment to ``graph.initializers[...]`` to skip the + const_value check. Idempotent: subsequent calls are no-ops. + """ + if self._realized: + return self + + if self.name: + self.name = builder.qualify_name(self.name) + builder.graph.initializers[self.name] = self # type: ignore[index] + self._realized = True + return self + + def __repr__(self) -> str: + return f"Parameter(shape={list(self.shape)}, dtype={self.dtype}, name={self.name!r})" From 11c880c218ae0728ddda87b8755b3d98403e2a64 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 08:37:27 -0800 Subject: [PATCH 02/25] Add children, named_children, state_dict, load_state_dict to Module - children() / named_children(): iterators for immediate child modules - state_dict(): returns dict mapping param names to tensor data - load_state_dict(): loads tensor data into parameters, with strict mode for missing/unexpected keys Also refactored Parameter to subclass ir.Value directly, eliminating the swap/restore mechanism in Module.__call__. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- onnxscript/nn/_module.py | 73 +++++++++++ onnxscript/nn/_module_test.py | 227 ++++++++++++++++++++++++++++++++++ 2 files changed, 300 insertions(+) diff --git a/onnxscript/nn/_module.py b/onnxscript/nn/_module.py index 034760c9c5..c6483145d5 100644 --- a/onnxscript/nn/_module.py +++ b/onnxscript/nn/_module.py @@ -5,6 +5,8 @@ from typing import Any, Iterator +import onnx_ir as ir + from onnxscript._internal.builder import GraphBuilder, OpBuilder from onnxscript.nn._parameter import Parameter @@ -109,6 +111,14 @@ def named_parameters( sub_prefix = f"{prefix}.{mod_name}" if prefix else mod_name yield from module.named_parameters(prefix=sub_prefix, recurse=True) + def children(self) -> Iterator[Module]: + """Return an iterator over immediate child modules.""" + yield from self._modules.values() + + def named_children(self) -> Iterator[tuple[str, Module]]: + """Return an iterator over immediate child modules, yielding (name, Module) pairs.""" + yield from self._modules.items() + def modules(self) -> Iterator[Module]: """Return an iterator over all modules in the tree (including self).""" yield self @@ -122,6 +132,69 @@ def named_modules(self, prefix: str = "") -> Iterator[tuple[str, Module]]: sub_prefix = f"{prefix}.{name}" if prefix else name yield from module.named_modules(prefix=sub_prefix) + # ------------------------------------------------------------------ + # State dict + # ------------------------------------------------------------------ + + def state_dict(self, prefix: str = "") -> dict[str, ir.TensorProtocol | None]: + """Return a dictionary mapping parameter names to their tensor data. + + Mirrors ``torch.nn.Module.state_dict()``. Keys use dot-separated + hierarchical names (e.g. ``"layer1.weight"``). Values are the + ``const_value`` of each parameter (``None`` if uninitialized). + """ + result: dict[str, ir.TensorProtocol | None] = {} + for name, param in self._parameters.items(): + full_name = f"{prefix}.{name}" if prefix else name + result[full_name] = param.const_value + for mod_name, module in self._modules.items(): + sub_prefix = f"{prefix}.{mod_name}" if prefix else mod_name + result.update(module.state_dict(prefix=sub_prefix)) + return result + + def load_state_dict( + self, + state_dict: dict[str, ir.TensorProtocol], + strict: bool = True, + ) -> None: + """Load parameter data from a state dictionary. + + Mirrors ``torch.nn.Module.load_state_dict()``. Sets ``const_value`` + on each matching parameter. + + Args: + state_dict: Mapping of parameter names to tensor data. + strict: If ``True`` (default), raises ``KeyError`` for missing + keys and ``ValueError`` for unexpected keys. + """ + self._load_state_dict_recursive(state_dict, prefix="", strict=strict) + + def _load_state_dict_recursive( + self, + state_dict: dict[str, ir.TensorProtocol], + prefix: str, + strict: bool, + ) -> set[str]: + """Recursively load state and return the set of consumed keys.""" + consumed: set[str] = set() + for name, param in self._parameters.items(): + full_name = f"{prefix}.{name}" if prefix else name + if full_name in state_dict: + param.const_value = state_dict[full_name] + consumed.add(full_name) + elif strict: + raise KeyError(f"Missing key in state_dict: {full_name!r}") + for mod_name, module in self._modules.items(): + sub_prefix = f"{prefix}.{mod_name}" if prefix else mod_name + consumed |= module._load_state_dict_recursive( # pylint: disable=protected-access + state_dict, prefix=sub_prefix, strict=strict + ) + if strict and prefix == "": + unexpected = set(state_dict.keys()) - consumed + if unexpected: + raise ValueError(f"Unexpected keys in state_dict: {unexpected}") + return consumed + def __repr__(self) -> str: lines = [f"{type(self).__name__}("] for name, module in self._modules.items(): diff --git a/onnxscript/nn/_module_test.py b/onnxscript/nn/_module_test.py index 8f4e8d7083..265c091930 100644 --- a/onnxscript/nn/_module_test.py +++ b/onnxscript/nn/_module_test.py @@ -378,5 +378,232 @@ def forward(self, op): self.assertIn("Child", r) +class ModuleChildrenTest(unittest.TestCase): + def test_children(self): + class A(Module): + def __init__(self): + super().__init__("a") + + def forward(self, op): + pass + + class B(Module): + def __init__(self): + super().__init__("b") + + def forward(self, op): + pass + + class Parent(Module): + def __init__(self): + super().__init__("parent") + self.a = A() + self.b = B() + + def forward(self, op): + pass + + m = Parent() + kids = list(m.children()) + self.assertEqual(len(kids), 2) + self.assertIs(kids[0], m.a) + self.assertIs(kids[1], m.b) + + def test_named_children(self): + class A(Module): + def __init__(self): + super().__init__("a") + + def forward(self, op): + pass + + class Parent(Module): + def __init__(self): + super().__init__("parent") + self.layer = A() + + def forward(self, op): + pass + + m = Parent() + named = dict(m.named_children()) + self.assertIn("layer", named) + self.assertIs(named["layer"], m.layer) + + def test_children_does_not_recurse(self): + """children() only yields immediate children, not grandchildren.""" + + class Grandchild(Module): + def __init__(self): + super().__init__("gc") + + def forward(self, op): + pass + + class Child(Module): + def __init__(self): + super().__init__("child") + self.gc = Grandchild() + + def forward(self, op): + pass + + class Parent(Module): + def __init__(self): + super().__init__("parent") + self.child = Child() + + def forward(self, op): + pass + + m = Parent() + kids = list(m.children()) + self.assertEqual(len(kids), 1) + self.assertIs(kids[0], m.child) + + +class ModuleStateDictTest(unittest.TestCase): + def test_state_dict_flat(self): + class MyModule(Module): + def __init__(self): + super().__init__("mod") + self.w = Parameter([3], name="w") + self.b = Parameter([3], name="b") + + def forward(self, op): + pass + + m = MyModule() + sd = m.state_dict() + self.assertIn("w", sd) + self.assertIn("b", sd) + # Uninitialized parameters have None data + self.assertIsNone(sd["w"]) + self.assertIsNone(sd["b"]) + + def test_state_dict_hierarchical(self): + class Child(Module): + def __init__(self): + super().__init__("child") + self.w = Parameter([3], name="w") + + def forward(self, op): + pass + + class Parent(Module): + def __init__(self): + super().__init__("parent") + self.p = Parameter([2], name="p") + self.child = Child() + + def forward(self, op): + pass + + m = Parent() + sd = m.state_dict() + self.assertIn("p", sd) + self.assertIn("child.w", sd) + + def test_state_dict_with_data(self): + tensor = ir.tensor([1.0, 2.0, 3.0], dtype=ir.DataType.FLOAT) + p = Parameter([3], name="w", data=tensor) + + class MyModule(Module): + def __init__(self): + super().__init__("mod") + self.w = p + + def forward(self, op): + pass + + m = MyModule() + sd = m.state_dict() + self.assertIs(sd["w"], tensor) + + def test_load_state_dict(self): + class MyModule(Module): + def __init__(self): + super().__init__("mod") + self.w = Parameter([3], name="w") + self.b = Parameter([3], name="b") + + def forward(self, op): + pass + + m = MyModule() + w_data = ir.tensor([1.0, 2.0, 3.0], dtype=ir.DataType.FLOAT) + b_data = ir.tensor([0.1, 0.2, 0.3], dtype=ir.DataType.FLOAT) + m.load_state_dict({"w": w_data, "b": b_data}) + + self.assertIs(m.w.const_value, w_data) + self.assertIs(m.b.const_value, b_data) + + def test_load_state_dict_hierarchical(self): + class Child(Module): + def __init__(self): + super().__init__("child") + self.w = Parameter([3], name="w") + + def forward(self, op): + pass + + class Parent(Module): + def __init__(self): + super().__init__("parent") + self.child = Child() + + def forward(self, op): + pass + + m = Parent() + w_data = ir.tensor([1.0, 2.0, 3.0], dtype=ir.DataType.FLOAT) + m.load_state_dict({"child.w": w_data}) + self.assertIs(m.child.w.const_value, w_data) + + def test_load_state_dict_strict_missing_key(self): + class MyModule(Module): + def __init__(self): + super().__init__("mod") + self.w = Parameter([3], name="w") + + def forward(self, op): + pass + + m = MyModule() + with self.assertRaises(KeyError): + m.load_state_dict({}) + + def test_load_state_dict_strict_unexpected_key(self): + class MyModule(Module): + def __init__(self): + super().__init__("mod") + self.w = Parameter([3], name="w") + + def forward(self, op): + pass + + m = MyModule() + w_data = ir.tensor([1.0, 2.0, 3.0], dtype=ir.DataType.FLOAT) + with self.assertRaises(ValueError): + m.load_state_dict({"w": w_data, "extra": w_data}) + + def test_load_state_dict_non_strict(self): + class MyModule(Module): + def __init__(self): + super().__init__("mod") + self.w = Parameter([3], name="w") + self.b = Parameter([3], name="b") + + def forward(self, op): + pass + + m = MyModule() + w_data = ir.tensor([1.0, 2.0, 3.0], dtype=ir.DataType.FLOAT) + # Only load w, skip b — no error because strict=False + m.load_state_dict({"w": w_data}, strict=False) + self.assertIs(m.w.const_value, w_data) + self.assertIsNone(m.b.const_value) + + if __name__ == "__main__": unittest.main() From b6c084f9dc434661d50bec8afc8a4501547dcae3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 08:38:56 -0800 Subject: [PATCH 03/25] Improve nn module test coverage to 100% Add tests for: - Module.name property - Plain attribute assignment via __setattr__ - NotImplementedError from base Module.forward() - Recursive parameters() with child modules Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- onnxscript/nn/_module_test.py | 58 +++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/onnxscript/nn/_module_test.py b/onnxscript/nn/_module_test.py index 265c091930..d64e63b258 100644 --- a/onnxscript/nn/_module_test.py +++ b/onnxscript/nn/_module_test.py @@ -147,6 +147,40 @@ def forward(self, op, x): m = MyModule() self.assertEqual(m.weight.name, "weight") + def test_name_property(self): + """The name property returns the module name.""" + + class MyModule(Module): + def __init__(self): + super().__init__("my_mod") + + def forward(self, op): + pass + + m = MyModule() + self.assertEqual(m.name, "my_mod") + + def test_setattr_plain_attribute(self): + """Non-Parameter, non-Module attributes are stored normally.""" + + class MyModule(Module): + def __init__(self): + super().__init__("mod") + self.hidden_size = 128 + + def forward(self, op): + pass + + m = MyModule() + self.assertEqual(m.hidden_size, 128) + + def test_forward_not_implemented(self): + """Calling forward() on base Module raises NotImplementedError.""" + m = Module("base") + _, op = _create_graph_and_op() + with self.assertRaises(NotImplementedError): + m(op) + class ModuleForwardTest(unittest.TestCase): def test_parameter_is_ir_value_in_forward(self): @@ -286,6 +320,30 @@ def forward(self, op): params = list(m.parameters()) self.assertEqual(len(params), 2) + def test_parameters_recursive(self): + class Child(Module): + def __init__(self): + super().__init__("child") + self.w = Parameter([3], name="w") + + def forward(self, op): + pass + + class Parent(Module): + def __init__(self): + super().__init__("parent") + self.p = Parameter([2], name="p") + self.child = Child() + + def forward(self, op): + pass + + m = Parent() + params = list(m.parameters(recurse=True)) + self.assertEqual(len(params), 2) + params_no_recurse = list(m.parameters(recurse=False)) + self.assertEqual(len(params_no_recurse), 1) + def test_named_parameters_recursive(self): class Child(Module): def __init__(self): From 48b6a14fe45178610e995f0ed49a91e2fc0c7c70 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 08:51:54 -0800 Subject: [PATCH 04/25] Make Parameter.realize() private (_realize()) This is an internal method called by Module.__call__, not part of the public API. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- onnxscript/nn/_module.py | 2 +- onnxscript/nn/_module_test.py | 12 ++++++------ onnxscript/nn/_parameter.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/onnxscript/nn/_module.py b/onnxscript/nn/_module.py index c6483145d5..2ac18dd4de 100644 --- a/onnxscript/nn/_module.py +++ b/onnxscript/nn/_module.py @@ -73,7 +73,7 @@ def __call__(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: try: # Realize parameters: qualify names and register as graph initializers. for param in self._parameters.values(): - param.realize(builder) + param._realize(builder) # pylint: disable=protected-access result = self.forward(op, *args, **kwargs) finally: diff --git a/onnxscript/nn/_module_test.py b/onnxscript/nn/_module_test.py index d64e63b258..172ba16a13 100644 --- a/onnxscript/nn/_module_test.py +++ b/onnxscript/nn/_module_test.py @@ -34,9 +34,9 @@ def test_parameter_repr(self): def test_realize_creates_initializer(self): graph, op = _create_graph_and_op() p = Parameter([3, 4], dtype=ir.DataType.FLOAT, name="weight") - value = p.realize(op.builder) + value = p._realize(op.builder) # pylint: disable=protected-access - self.assertIs(value, p) # realize returns self + self.assertIs(value, p) # _realize returns self self.assertIsInstance(value, ir.Value) self.assertEqual(value.name, "weight") self.assertEqual(value.type.dtype, ir.DataType.FLOAT) @@ -47,15 +47,15 @@ def test_realize_creates_initializer(self): def test_realize_is_idempotent(self): _, op = _create_graph_and_op() p = Parameter([3, 4], name="weight") - v1 = p.realize(op.builder) - v2 = p.realize(op.builder) + v1 = p._realize(op.builder) # pylint: disable=protected-access + v2 = p._realize(op.builder) # pylint: disable=protected-access self.assertIs(v1, v2) def test_realize_with_data(self): graph, op = _create_graph_and_op() tensor = ir.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=ir.DataType.FLOAT) p = Parameter([2, 2], name="weight", data=tensor) - value = p.realize(op.builder) + value = p._realize(op.builder) # pylint: disable=protected-access self.assertIs(value.const_value, tensor) self.assertIn("weight", graph.initializers) @@ -66,7 +66,7 @@ def test_realize_qualifies_name(self): graph, op = _create_graph_and_op() op.builder.push_module("layer1") p = Parameter([3], name="bias") - value = p.realize(op.builder) + value = p._realize(op.builder) # pylint: disable=protected-access op.builder.pop_module() self.assertEqual(value.name, "layer1.bias") diff --git a/onnxscript/nn/_parameter.py b/onnxscript/nn/_parameter.py index 4591de6a77..1d294a90a5 100644 --- a/onnxscript/nn/_parameter.py +++ b/onnxscript/nn/_parameter.py @@ -47,7 +47,7 @@ def dtype(self) -> ir.DataType | None: # type: ignore[override] """Return the element data type of this parameter.""" return self.type.dtype if self.type is not None else None - def realize(self, builder: GraphBuilder) -> Parameter: + def _realize(self, builder: GraphBuilder) -> Parameter: """Qualify the name and register as a graph initializer. Uses direct assignment to ``graph.initializers[...]`` to skip the From caa95cc68a5a304df0b8471cfc7652c5e3bc94ac Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 09:34:30 -0800 Subject: [PATCH 05/25] Update onnxscript/nn/_parameter.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/nn/_parameter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/nn/_parameter.py b/onnxscript/nn/_parameter.py index 1d294a90a5..a62a9e03a9 100644 --- a/onnxscript/nn/_parameter.py +++ b/onnxscript/nn/_parameter.py @@ -15,7 +15,7 @@ class Parameter(ir.Value): Since ``Parameter`` subclasses ``ir.Value``, it can be passed directly to ONNX ops inside ``Module.forward()`` without any conversion. - Calling :meth:`realize` qualifies the name with the current module + Calling :meth:`_realize` qualifies the name with the current module context and registers the parameter as a graph initializer. Args: From 13c8e6d36a2fa1408f8a431528bb5fa6d6c8267f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 09:34:54 -0800 Subject: [PATCH 06/25] Update onnxscript/nn/_parameter.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/nn/_parameter.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/onnxscript/nn/_parameter.py b/onnxscript/nn/_parameter.py index a62a9e03a9..eac6fc8a07 100644 --- a/onnxscript/nn/_parameter.py +++ b/onnxscript/nn/_parameter.py @@ -56,6 +56,12 @@ def _realize(self, builder: GraphBuilder) -> Parameter: if self._realized: return self + if self.name is None: + raise ValueError( + "Parameter._realize() called on a Parameter without a name. " + "Ensure the Parameter is attached to a Module attribute or otherwise " + "initialized with a name before realization." + ) if self.name: self.name = builder.qualify_name(self.name) builder.graph.initializers[self.name] = self # type: ignore[index] From bb7fba12817c52b3c815af5600bf580d5ada26b8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 16:29:39 -0800 Subject: [PATCH 07/25] fix: ensure unique output names in GraphBuilder._adapt_outputs Append the node count to auto-generated output names (e.g. Add_output_0 instead of Add_output) to prevent name collisions when the same op type is called multiple times within a module. This matches the existing node naming strategy which already uses the count suffix. Signed-off-by: Justin Chu --- onnxscript/_internal/builder.py | 6 ++-- onnxscript/_internal/builder_test.py | 54 +++++++++++++++++++++------- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 0b626bd100..bfe3c4da7c 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -180,14 +180,16 @@ def _adapt_outputs( self, outputs: int | Sequence[str | ir.Value], op_type: str = "" ) -> Sequence[ir.Value]: if isinstance(outputs, int): + count = self.graph.num_nodes() if outputs < 0: raise ValueError(f"Number of outputs must be non-negative, got {outputs}") if outputs == 1: - name = f"{op_type}_output" if op_type else "output" + name = f"{op_type}_output_{count}" if op_type else f"output_{count}" return [ir.Value(name=self.qualify_name(name))] else: names = [ - f"{op_type}_output{i}" if op_type else f"output{i}" for i in range(outputs) + (f"{op_type}_output{i}_{count}" if op_type else f"output{i}_{count}") + for i in range(outputs) ] return [ir.Value(name=self.qualify_name(n)) for n in names] adapted_outputs = [] diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 9edd0b68b4..0eedb4e7f3 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -214,14 +214,14 @@ def _ops_with_default_names( nodes = list(graph) self.assertEqual(len(nodes), 3) - # Check output names follow the {op_type}_output pattern for single outputs - self.assertEqual(nodes[0].outputs[0].name, "Add_output") - self.assertEqual(nodes[1].outputs[0].name, "Mul_output") - self.assertEqual(nodes[2].outputs[0].name, "Add_output") + # Check output names follow the {op_type}_output_{count} pattern for single outputs + self.assertEqual(nodes[0].outputs[0].name, "Add_output_0") + self.assertEqual(nodes[1].outputs[0].name, "Mul_output_1") + self.assertEqual(nodes[2].outputs[0].name, "Add_output_2") # Verify the final output has the correct name self.assertEqual(len(graph.outputs), 1) - self.assertEqual(graph.outputs[0].name, "Add_output") + self.assertEqual(graph.outputs[0].name, "Add_output_2") def test_hierarchical_naming(self): """Test the hierarchical naming strategy (for value and node names).""" @@ -229,35 +229,35 @@ def test_hierarchical_naming(self): # Test node and value naming at root level t1 = op.Add(x, y) - self.assertEqual(t1.name, "Add_output") + self.assertEqual(t1.name, "Add_output_0") self.assertEqual(t1.producer().name, "Add_node_0") t2 = op.Mul(t1, y) - self.assertEqual(t2.name, "Mul_output") + self.assertEqual(t2.name, "Mul_output_1") self.assertEqual(t2.producer().name, "Mul_node_1") # Test node and value naming with hierarchical context prefix op.builder.push_module("layer1") t3 = op.Add(t2, x) - self.assertEqual(t3.name, "layer1.Add_output") + self.assertEqual(t3.name, "layer1.Add_output_2") self.assertEqual(t3.producer().name, "layer1.Add_node_2") # Test nested hierarchical context op.builder.push_module("attention") t4 = op.Mul(t3, y) - self.assertEqual(t4.name, "layer1.attention.Mul_output") + self.assertEqual(t4.name, "layer1.attention.Mul_output_3") self.assertEqual(t4.producer().name, "layer1.attention.Mul_node_3") # Pop back to layer1 and verify naming continues correctly op.builder.pop_module() t5 = op.Add(t4, x) - self.assertEqual(t5.name, "layer1.Add_output") + self.assertEqual(t5.name, "layer1.Add_output_4") self.assertEqual(t5.producer().name, "layer1.Add_node_4") # Pop back to root context op.builder.pop_module() t6 = op.Mul(t5, y) - self.assertEqual(t6.name, "Mul_output") + self.assertEqual(t6.name, "Mul_output_5") self.assertEqual(t6.producer().name, "Mul_node_5") def test_shape_inference_add(self): @@ -311,7 +311,7 @@ def test_custom_domain_with_version(self): # Verify output value is created self.assertIsNotNone(result) - self.assertEqual(result.name, "MicrosoftOp_output") + self.assertEqual(result.name, "MicrosoftOp_output_0") def test_multiple_custom_domain_operations(self): """Test mixing operations from multiple domains.""" @@ -519,6 +519,36 @@ def test_pop_module_raises_on_empty_stack(self): with self.assertRaises(RuntimeError): op.builder.pop_module() + def test_output_names_are_unique_for_same_op_type(self): + """Test that repeated calls to the same op produce unique output names.""" + op, x, y = _create_builder_with_inputs() + + t1 = op.Add(x, y) + t2 = op.Add(x, y) + t3 = op.Add(x, y) + + # Each Add output should have a unique name via the node count suffix + self.assertEqual(t1.name, "Add_output_0") + self.assertEqual(t2.name, "Add_output_1") + self.assertEqual(t3.name, "Add_output_2") + + # Verify all names are distinct + names = [t1.name, t2.name, t3.name] + self.assertEqual(len(set(names)), 3) + + def test_multi_output_names_are_unique(self): + """Test that multi-output ops produce unique names with counter suffix.""" + op, x, y = _create_builder_with_inputs() + + # First multi-output call + out1_a, out1_b = op.TopK(x, 1, axis=-1, _outputs=2) + # Second multi-output call + out2_a, out2_b = op.TopK(y, 1, axis=-1, _outputs=2) + + # Each call should produce unique names + self.assertNotEqual(out1_a.name, out2_a.name) + self.assertNotEqual(out1_b.name, out2_b.name) + def test_attributes_are_created_properly(self): """Test that int, float, str, and list attributes are set correctly on a node.""" op, x, y = _create_builder_with_inputs() From 391ab4c308dd1751434d7e060d9478637ce6bde7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 17:03:51 -0800 Subject: [PATCH 08/25] test: update builder tests for v_ output name prefix Signed-off-by: Justin Chu --- onnxscript/_internal/builder.py | 4 ++-- onnxscript/_internal/builder_test.py | 30 ++++++++++++++-------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index bfe3c4da7c..06aebe7832 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -184,11 +184,11 @@ def _adapt_outputs( if outputs < 0: raise ValueError(f"Number of outputs must be non-negative, got {outputs}") if outputs == 1: - name = f"{op_type}_output_{count}" if op_type else f"output_{count}" + name = f"v_{op_type}_{count}" if op_type else f"f_{count}" return [ir.Value(name=self.qualify_name(name))] else: names = [ - (f"{op_type}_output{i}_{count}" if op_type else f"output{i}_{count}") + (f"v_{op_type}_{count}_{i}" if op_type else f"v_{count}_{i}") for i in range(outputs) ] return [ir.Value(name=self.qualify_name(n)) for n in names] diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 0eedb4e7f3..049ecb997a 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -214,14 +214,14 @@ def _ops_with_default_names( nodes = list(graph) self.assertEqual(len(nodes), 3) - # Check output names follow the {op_type}_output_{count} pattern for single outputs - self.assertEqual(nodes[0].outputs[0].name, "Add_output_0") - self.assertEqual(nodes[1].outputs[0].name, "Mul_output_1") - self.assertEqual(nodes[2].outputs[0].name, "Add_output_2") + # Check output names follow the v_{op_type}_{count} pattern for single outputs + self.assertEqual(nodes[0].outputs[0].name, "v_Add_0") + self.assertEqual(nodes[1].outputs[0].name, "v_Mul_1") + self.assertEqual(nodes[2].outputs[0].name, "v_Add_2") # Verify the final output has the correct name self.assertEqual(len(graph.outputs), 1) - self.assertEqual(graph.outputs[0].name, "Add_output_2") + self.assertEqual(graph.outputs[0].name, "v_Add_2") def test_hierarchical_naming(self): """Test the hierarchical naming strategy (for value and node names).""" @@ -229,35 +229,35 @@ def test_hierarchical_naming(self): # Test node and value naming at root level t1 = op.Add(x, y) - self.assertEqual(t1.name, "Add_output_0") + self.assertEqual(t1.name, "v_Add_0") self.assertEqual(t1.producer().name, "Add_node_0") t2 = op.Mul(t1, y) - self.assertEqual(t2.name, "Mul_output_1") + self.assertEqual(t2.name, "v_Mul_1") self.assertEqual(t2.producer().name, "Mul_node_1") # Test node and value naming with hierarchical context prefix op.builder.push_module("layer1") t3 = op.Add(t2, x) - self.assertEqual(t3.name, "layer1.Add_output_2") + self.assertEqual(t3.name, "layer1.v_Add_2") self.assertEqual(t3.producer().name, "layer1.Add_node_2") # Test nested hierarchical context op.builder.push_module("attention") t4 = op.Mul(t3, y) - self.assertEqual(t4.name, "layer1.attention.Mul_output_3") + self.assertEqual(t4.name, "layer1.attention.v_Mul_3") self.assertEqual(t4.producer().name, "layer1.attention.Mul_node_3") # Pop back to layer1 and verify naming continues correctly op.builder.pop_module() t5 = op.Add(t4, x) - self.assertEqual(t5.name, "layer1.Add_output_4") + self.assertEqual(t5.name, "layer1.v_Add_4") self.assertEqual(t5.producer().name, "layer1.Add_node_4") # Pop back to root context op.builder.pop_module() t6 = op.Mul(t5, y) - self.assertEqual(t6.name, "Mul_output_5") + self.assertEqual(t6.name, "v_Mul_5") self.assertEqual(t6.producer().name, "Mul_node_5") def test_shape_inference_add(self): @@ -311,7 +311,7 @@ def test_custom_domain_with_version(self): # Verify output value is created self.assertIsNotNone(result) - self.assertEqual(result.name, "MicrosoftOp_output_0") + self.assertEqual(result.name, "v_MicrosoftOp_0") def test_multiple_custom_domain_operations(self): """Test mixing operations from multiple domains.""" @@ -528,9 +528,9 @@ def test_output_names_are_unique_for_same_op_type(self): t3 = op.Add(x, y) # Each Add output should have a unique name via the node count suffix - self.assertEqual(t1.name, "Add_output_0") - self.assertEqual(t2.name, "Add_output_1") - self.assertEqual(t3.name, "Add_output_2") + self.assertEqual(t1.name, "v_Add_0") + self.assertEqual(t2.name, "v_Add_1") + self.assertEqual(t3.name, "v_Add_2") # Verify all names are distinct names = [t1.name, t2.name, t3.name] From 7f0e6bf5fff81b53a449659adeabd671c67dafc1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 17:09:06 -0800 Subject: [PATCH 09/25] fix: use / separator for node and value scope names ONNX convention uses / for node scope hierarchies. Add _qualify_node_name() that uses / as separator, keeping qualify_name() with . for parameter/initializer names. Node names and auto-generated value names now use / (e.g. layer1/Add_node_0, layer1/v_Add_0) while parameters keep . (e.g. layer1.weight). Signed-off-by: Justin Chu --- onnxscript/_internal/builder.py | 24 ++++++++++++++++++------ onnxscript/_internal/builder_test.py | 22 +++++++++++----------- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 06aebe7832..ff065af17b 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -185,21 +185,21 @@ def _adapt_outputs( raise ValueError(f"Number of outputs must be non-negative, got {outputs}") if outputs == 1: name = f"v_{op_type}_{count}" if op_type else f"f_{count}" - return [ir.Value(name=self.qualify_name(name))] + return [ir.Value(name=self._qualify_node_name(name))] else: names = [ (f"v_{op_type}_{count}_{i}" if op_type else f"v_{count}_{i}") for i in range(outputs) ] - return [ir.Value(name=self.qualify_name(n)) for n in names] + return [ir.Value(name=self._qualify_node_name(n)) for n in names] adapted_outputs = [] for output in outputs: if isinstance(output, ir.Value): if output.name: - output.name = self.qualify_name(output.name) + output.name = self._qualify_node_name(output.name) adapted_outputs.append(output) elif isinstance(output, str): - adapted_outputs.append(ir.Value(name=self.qualify_name(output))) + adapted_outputs.append(ir.Value(name=self._qualify_node_name(output))) else: raise TypeError("Output type not supported.") return adapted_outputs @@ -306,7 +306,7 @@ def call_op( outputs = kwargs.pop("_outputs", 1) count = self.graph.num_nodes() - node_name = self.qualify_name(f"{op_type}_node_{count}") + node_name = self._qualify_node_name(f"{op_type}_node_{count}") output_values = self._adapt_outputs(outputs, op_type) @@ -348,10 +348,22 @@ def context_name(self) -> str: return self._context_stack[-1] if self._context_stack else "" def qualify_name(self, name: str) -> str: - """Prepend the current hierarchical context prefix to the given name.""" + """Prepend the current hierarchical context prefix to the given name. + + Uses ``.`` as separator, appropriate for parameter and initializer names. + """ prefix = self.context_name() return f"{prefix}.{name}" if prefix else name + def _qualify_node_name(self, name: str) -> str: + """Prepend the current hierarchical context prefix to a node name. + + Uses ``/`` as separator, following the ONNX convention for node + scope hierarchies. + """ + prefix = self.context_name() + return f"{prefix}/{name}" if prefix else name + class OpBuilder: """Dynamic op dispatcher that translates attribute access into ONNX node creation via a GraphBuilder.""" diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 049ecb997a..f6ace000c3 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -148,17 +148,17 @@ def test_value_naming_with_hierarchy(self): # Test custom names with hierarchical context op.builder.push_module("layer1") t2 = op.Mul(t1, y, _outputs=["my_mul"]) - self.assertEqual(t2.name, "layer1.my_mul") + self.assertEqual(t2.name, "layer1/my_mul") # Test nested hierarchical context with custom names op.builder.push_module("attention") t3 = op.Add(t2, x, _outputs=["my_nested_add"]) - self.assertEqual(t3.name, "layer1.attention.my_nested_add") + self.assertEqual(t3.name, "layer1.attention/my_nested_add") # Pop back and verify prefix is applied correctly op.builder.pop_module() t4 = op.Mul(t3, y, _outputs=["another_mul"]) - self.assertEqual(t4.name, "layer1.another_mul") + self.assertEqual(t4.name, "layer1/another_mul") op.builder.pop_module() t5 = op.Add(t4, x, _outputs=["final_result"]) @@ -181,13 +181,13 @@ def test_value_naming_with_ir_value_objects(self): # Test with hierarchical context op.builder.push_module("layer1") t2 = op.Mul(t1, y, _outputs=[out2]) - self.assertEqual(t2.name, "layer1.layer_output") + self.assertEqual(t2.name, "layer1/layer_output") self.assertIs(t2, out2) # Test nested hierarchical context op.builder.push_module("attention") t3 = op.Add(t2, x, _outputs=[out3]) - self.assertEqual(t3.name, "layer1.attention.nested_output") + self.assertEqual(t3.name, "layer1.attention/nested_output") self.assertIs(t3, out3) def test_default_output_naming_strategy(self): @@ -239,20 +239,20 @@ def test_hierarchical_naming(self): # Test node and value naming with hierarchical context prefix op.builder.push_module("layer1") t3 = op.Add(t2, x) - self.assertEqual(t3.name, "layer1.v_Add_2") - self.assertEqual(t3.producer().name, "layer1.Add_node_2") + self.assertEqual(t3.name, "layer1/v_Add_2") + self.assertEqual(t3.producer().name, "layer1/Add_node_2") # Test nested hierarchical context op.builder.push_module("attention") t4 = op.Mul(t3, y) - self.assertEqual(t4.name, "layer1.attention.v_Mul_3") - self.assertEqual(t4.producer().name, "layer1.attention.Mul_node_3") + self.assertEqual(t4.name, "layer1.attention/v_Mul_3") + self.assertEqual(t4.producer().name, "layer1.attention/Mul_node_3") # Pop back to layer1 and verify naming continues correctly op.builder.pop_module() t5 = op.Add(t4, x) - self.assertEqual(t5.name, "layer1.v_Add_4") - self.assertEqual(t5.producer().name, "layer1.Add_node_4") + self.assertEqual(t5.name, "layer1/v_Add_4") + self.assertEqual(t5.producer().name, "layer1/Add_node_4") # Pop back to root context op.builder.pop_module() From e00bb6dde36e041344786bb77a4e4c2de89fd710 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 17:14:32 -0800 Subject: [PATCH 10/25] fix: conditionally push and pop module in __call__ based on module name Signed-off-by: Justin Chu --- onnxscript/nn/_module.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxscript/nn/_module.py b/onnxscript/nn/_module.py index 2ac18dd4de..df01e0da3e 100644 --- a/onnxscript/nn/_module.py +++ b/onnxscript/nn/_module.py @@ -68,8 +68,9 @@ def __setattr__(self, name: str, value: Any) -> None: def __call__(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: builder: GraphBuilder = op.builder - module_name = self._name or "" - builder.push_module(module_name) + has_name = bool(self._name) # Only push if we have a name + if has_name: + builder.push_module(self._name) try: # Realize parameters: qualify names and register as graph initializers. for param in self._parameters.values(): @@ -77,7 +78,8 @@ def __call__(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: result = self.forward(op, *args, **kwargs) finally: - builder.pop_module() + if has_name: + builder.pop_module() return result def forward(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: From 21a5c13d60e37eb0ff5a09d4d902251933efd407 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 17:32:52 -0800 Subject: [PATCH 11/25] fix: use / throughout node scope prefix, not just at the join point _qualify_node_name now replaces all dots in the context prefix with / so nested scopes produce names like layer1/attention/Add_node_0 instead of layer1.attention/Add_node_0. Signed-off-by: Justin Chu --- onnxscript/_internal/builder.py | 4 +++- onnxscript/_internal/builder_test.py | 8 ++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index ff065af17b..1dc6a713d0 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -362,7 +362,9 @@ def _qualify_node_name(self, name: str) -> str: scope hierarchies. """ prefix = self.context_name() - return f"{prefix}/{name}" if prefix else name + if not prefix: + return name + return f"{prefix.replace('.', '/')}/{name}" class OpBuilder: diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index f6ace000c3..b218f25589 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -153,7 +153,7 @@ def test_value_naming_with_hierarchy(self): # Test nested hierarchical context with custom names op.builder.push_module("attention") t3 = op.Add(t2, x, _outputs=["my_nested_add"]) - self.assertEqual(t3.name, "layer1.attention/my_nested_add") + self.assertEqual(t3.name, "layer1/attention/my_nested_add") # Pop back and verify prefix is applied correctly op.builder.pop_module() @@ -187,7 +187,7 @@ def test_value_naming_with_ir_value_objects(self): # Test nested hierarchical context op.builder.push_module("attention") t3 = op.Add(t2, x, _outputs=[out3]) - self.assertEqual(t3.name, "layer1.attention/nested_output") + self.assertEqual(t3.name, "layer1/attention/nested_output") self.assertIs(t3, out3) def test_default_output_naming_strategy(self): @@ -245,8 +245,8 @@ def test_hierarchical_naming(self): # Test nested hierarchical context op.builder.push_module("attention") t4 = op.Mul(t3, y) - self.assertEqual(t4.name, "layer1.attention/v_Mul_3") - self.assertEqual(t4.producer().name, "layer1.attention/Mul_node_3") + self.assertEqual(t4.name, "layer1/attention/v_Mul_3") + self.assertEqual(t4.producer().name, "layer1/attention/Mul_node_3") # Pop back to layer1 and verify naming continues correctly op.builder.pop_module() From 75b0751f43124046f622fa2f44783c8498f444c8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 17:42:36 -0800 Subject: [PATCH 12/25] feat: redesign module scope system with node metadata - Change _context_stack from cumulative strings to a list of (name, class_name) tuples in _scope_stack - push_module(name, class_name) stores individual scope entries; names like 'layers.3' are kept as single scopes (no dot splitting) - qualify_name uses '.' join for parameters (layers.3.self_attn.weight) - _qualify_value_name uses '/' join with v_ prefix (layers.3/self_attn/v_Add_0) - _qualify_node_name uses '/' join (layers.3/self_attn/Add_node_0) - Module.__call__ passes type(self).__qualname__ as class_name - call_op attaches metadata_props to every node: - namespace: scope path + op type (e.g. 'layer1/self_attn: Add') - pkg.onnxscript.class_hierarchy: list of class names + op type - pkg.onnxscript.name_scopes: list of scope names Signed-off-by: Justin Chu --- onnxscript/_internal/builder.py | 102 ++++++++++++++++++--------- onnxscript/_internal/builder_test.py | 68 ++++++++++++++---- onnxscript/nn/_module.py | 9 ++- 3 files changed, 128 insertions(+), 51 deletions(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 1dc6a713d0..95018b0f2d 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -80,10 +80,10 @@ def __init__(self, graph: ir.Graph) -> None: self._op_builder = self.opset("", opset_version) - # Context stack to manage hierarchical naming. Each module/layer can push a new context, and pop it when done. - # The current context is used as a prefix for naming values and nodes. - # This allows us to generate names like "layer1.attention.query" - self._context_stack: list[str] = [""] + # Module scope stack. Each entry is (name, class_name) where name is + # the module attribute name (e.g. "layers.0", "self_attn") and + # class_name is the qualified class name (e.g. "Gemma3DecoderLayer"). + self._scope_stack: list[tuple[str, str]] = [] # Cache for constant initializers (scalars and sequences), keyed by (value, dtype). # This avoids creating duplicate initializers for the same constant @@ -184,22 +184,22 @@ def _adapt_outputs( if outputs < 0: raise ValueError(f"Number of outputs must be non-negative, got {outputs}") if outputs == 1: - name = f"v_{op_type}_{count}" if op_type else f"f_{count}" - return [ir.Value(name=self._qualify_node_name(name))] + name = f"{op_type}_{count}" if op_type else f"{count}" + return [ir.Value(name=self._qualify_value_name(name))] else: names = [ - (f"v_{op_type}_{count}_{i}" if op_type else f"v_{count}_{i}") + (f"{op_type}_{count}_{i}" if op_type else f"{count}_{i}") for i in range(outputs) ] - return [ir.Value(name=self._qualify_node_name(n)) for n in names] + return [ir.Value(name=self._qualify_value_name(n)) for n in names] adapted_outputs = [] for output in outputs: if isinstance(output, ir.Value): if output.name: - output.name = self._qualify_node_name(output.name) + output.name = self._qualify_value_name(output.name) adapted_outputs.append(output) elif isinstance(output, str): - adapted_outputs.append(ir.Value(name=self._qualify_node_name(output))) + adapted_outputs.append(ir.Value(name=self._qualify_value_name(output))) else: raise TypeError("Output type not supported.") return adapted_outputs @@ -324,47 +324,83 @@ def call_op( version=version, name=node_name, ) + + # Attach scope metadata to the node + node.metadata_props["namespace"] = self._build_namespace(op_type, domain) + class_hierarchy = self.scope_classes() + op_id = f"{domain}.{op_type}" if domain else op_type + node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr( + class_hierarchy + [op_id] + ) + node.metadata_props["pkg.onnxscript.name_scopes"] = repr( + self.scope_names() + ) + self.add_node(node) return node.outputs if len(node.outputs) > 1 else node.outputs[0] - def push_module(self, module: str) -> None: - """Push a new naming context onto the stack (e.g. a layer or module name).""" - current = self.context_name() - if module: - new_context = f"{current}.{module}" if current else module - else: - new_context = current - self._context_stack.append(new_context) + def push_module(self, module: str, class_name: str = "") -> None: + """Push a new module scope onto the stack. + + Args: + module: The attribute name of the module (e.g. ``"layers.0"``). + class_name: The qualified class name (e.g. ``"Gemma3DecoderLayer"``). + """ + self._scope_stack.append((module, class_name)) def pop_module(self) -> None: - """Pop the most recent naming context off the stack.""" - if len(self._context_stack) <= 1: + """Pop the most recent module scope off the stack.""" + if not self._scope_stack: raise RuntimeError("Cannot pop_module: no module context has been pushed.") - self._context_stack.pop() + self._scope_stack.pop() + + def scope_names(self) -> list[str]: + """Return the list of module attribute names in the current scope.""" + return [name for name, _ in self._scope_stack if name] - def context_name(self) -> str: - """Return the current dot-separated naming context prefix.""" - return self._context_stack[-1] if self._context_stack else "" + def scope_classes(self) -> list[str]: + """Return the list of class names in the current scope.""" + return [cls for _, cls in self._scope_stack if cls] def qualify_name(self, name: str) -> str: """Prepend the current hierarchical context prefix to the given name. Uses ``.`` as separator, appropriate for parameter and initializer names. """ - prefix = self.context_name() - return f"{prefix}.{name}" if prefix else name + parts = self.scope_names() + if parts: + return ".".join(parts) + "." + name + return name + + def _qualify_value_name(self, name: str) -> str: + """Qualify a value name with the current scope using ``/`` separator. + + The name is prefixed with ``v_`` to distinguish values from parameters. + """ + parts = self.scope_names() + qualified = f"v_{name}" + if parts: + return "/".join(parts) + "/" + qualified + return qualified def _qualify_node_name(self, name: str) -> str: - """Prepend the current hierarchical context prefix to a node name. + """Qualify a node name with the current scope using ``/`` separator.""" + parts = self.scope_names() + if parts: + return "/".join(parts) + "/" + name + return name + + def _build_namespace(self, op_type: str, domain: str = "") -> str: + """Build the namespace string for a node. - Uses ``/`` as separator, following the ONNX convention for node - scope hierarchies. + Format: ``scope1/scope2: domain.op_type`` or ``scope1/scope2: op_type``. """ - prefix = self.context_name() - if not prefix: - return name - return f"{prefix.replace('.', '/')}/{name}" + scope = "/".join(self.scope_names()) + op_id = f"{domain}.{op_type}" if domain else op_type + if scope: + return f"{scope}: {op_id}" + return op_id class OpBuilder: diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index b218f25589..5f53d7b849 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -128,14 +128,14 @@ def _add_with_custom_names( nodes = list(graph) self.assertEqual(len(nodes), 3) - # Check output names - self.assertEqual(nodes[0].outputs[0].name, "add_result") - self.assertEqual(nodes[1].outputs[0].name, "mul_result") - self.assertEqual(nodes[2].outputs[0].name, "final_add") + # Check output names (v_ prefix is added to all value names) + self.assertEqual(nodes[0].outputs[0].name, "v_add_result") + self.assertEqual(nodes[1].outputs[0].name, "v_mul_result") + self.assertEqual(nodes[2].outputs[0].name, "v_final_add") # Verify the final output has the correct name self.assertEqual(len(graph.outputs), 1) - self.assertEqual(graph.outputs[0].name, "final_add") + self.assertEqual(graph.outputs[0].name, "v_final_add") def test_value_naming_with_hierarchy(self): """Test that hierarchical naming works with user-specified output names.""" @@ -143,26 +143,26 @@ def test_value_naming_with_hierarchy(self): # Test custom names at root level t1 = op.Add(x, y, _outputs=["my_add"]) - self.assertEqual(t1.name, "my_add") + self.assertEqual(t1.name, "v_my_add") # Test custom names with hierarchical context op.builder.push_module("layer1") t2 = op.Mul(t1, y, _outputs=["my_mul"]) - self.assertEqual(t2.name, "layer1/my_mul") + self.assertEqual(t2.name, "layer1/v_my_mul") # Test nested hierarchical context with custom names op.builder.push_module("attention") t3 = op.Add(t2, x, _outputs=["my_nested_add"]) - self.assertEqual(t3.name, "layer1/attention/my_nested_add") + self.assertEqual(t3.name, "layer1/attention/v_my_nested_add") # Pop back and verify prefix is applied correctly op.builder.pop_module() t4 = op.Mul(t3, y, _outputs=["another_mul"]) - self.assertEqual(t4.name, "layer1/another_mul") + self.assertEqual(t4.name, "layer1/v_another_mul") op.builder.pop_module() t5 = op.Add(t4, x, _outputs=["final_result"]) - self.assertEqual(t5.name, "final_result") + self.assertEqual(t5.name, "v_final_result") def test_value_naming_with_ir_value_objects(self): """Test that hierarchical naming works when passing ir.Value objects as _outputs.""" @@ -175,19 +175,19 @@ def test_value_naming_with_ir_value_objects(self): # Test at root level t1 = op.Add(x, y, _outputs=[out1]) - self.assertEqual(t1.name, "my_output") + self.assertEqual(t1.name, "v_my_output") self.assertIs(t1, out1) # Test with hierarchical context op.builder.push_module("layer1") t2 = op.Mul(t1, y, _outputs=[out2]) - self.assertEqual(t2.name, "layer1/layer_output") + self.assertEqual(t2.name, "layer1/v_layer_output") self.assertIs(t2, out2) # Test nested hierarchical context op.builder.push_module("attention") t3 = op.Add(t2, x, _outputs=[out3]) - self.assertEqual(t3.name, "layer1/attention/nested_output") + self.assertEqual(t3.name, "layer1/attention/v_nested_output") self.assertIs(t3, out3) def test_default_output_naming_strategy(self): @@ -549,6 +549,48 @@ def test_multi_output_names_are_unique(self): self.assertNotEqual(out1_a.name, out2_a.name) self.assertNotEqual(out1_b.name, out2_b.name) + def test_node_metadata_props_namespace(self): + """Test that nodes have namespace metadata matching the scope hierarchy.""" + op, x, y = _create_builder_with_inputs() + + # Root-level node + t1 = op.Add(x, y) + self.assertEqual(t1.producer().metadata_props["namespace"], "Add") + + # Node inside a module scope + op.builder.push_module("layer1", "DecoderLayer") + t2 = op.Mul(t1, y) + self.assertEqual(t2.producer().metadata_props["namespace"], "layer1: Mul") + + # Nested scope + op.builder.push_module("self_attn", "Attention") + t3 = op.Add(t2, x) + self.assertEqual( + t3.producer().metadata_props["namespace"], "layer1/self_attn: Add" + ) + op.builder.pop_module() + op.builder.pop_module() + + def test_node_metadata_props_class_hierarchy(self): + """Test that nodes have class hierarchy metadata.""" + op, x, y = _create_builder_with_inputs() + + op.builder.push_module("layer1", "DecoderLayer") + op.builder.push_module("self_attn", "Attention") + t1 = op.MatMul(x, y) + node = t1.producer() + + self.assertEqual( + node.metadata_props["pkg.onnxscript.class_hierarchy"], + repr(["DecoderLayer", "Attention", "MatMul"]), + ) + self.assertEqual( + node.metadata_props["pkg.onnxscript.name_scopes"], + repr(["layer1", "self_attn"]), + ) + op.builder.pop_module() + op.builder.pop_module() + def test_attributes_are_created_properly(self): """Test that int, float, str, and list attributes are set correctly on a node.""" op, x, y = _create_builder_with_inputs() diff --git a/onnxscript/nn/_module.py b/onnxscript/nn/_module.py index df01e0da3e..0f6abfdb8d 100644 --- a/onnxscript/nn/_module.py +++ b/onnxscript/nn/_module.py @@ -68,9 +68,9 @@ def __setattr__(self, name: str, value: Any) -> None: def __call__(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: builder: GraphBuilder = op.builder - has_name = bool(self._name) # Only push if we have a name - if has_name: - builder.push_module(self._name) + module_name = self._name or "" + class_name = type(self).__qualname__ + builder.push_module(module_name, class_name) try: # Realize parameters: qualify names and register as graph initializers. for param in self._parameters.values(): @@ -78,8 +78,7 @@ def __call__(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: result = self.forward(op, *args, **kwargs) finally: - if has_name: - builder.pop_module() + builder.pop_module() return result def forward(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: From 3135535c3e121f1b3025ea524b0b02543c64b9bc Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 17:43:31 -0800 Subject: [PATCH 13/25] refactor: rename qualify_name to _qualify_initializer_name Signed-off-by: Justin Chu --- onnxscript/_internal/builder.py | 4 ++-- onnxscript/nn/_parameter.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 95018b0f2d..5fe8dd6306 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -109,7 +109,7 @@ def initializer( if name is None: name = tensor.name if qualify: - name = self.qualify_name(name) + name = self._qualify_initializer_name(name) shape = ir.Shape(tensor.shape) value = ir.Value( name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor @@ -363,7 +363,7 @@ def scope_classes(self) -> list[str]: """Return the list of class names in the current scope.""" return [cls for _, cls in self._scope_stack if cls] - def qualify_name(self, name: str) -> str: + def _qualify_initializer_name(self, name: str) -> str: """Prepend the current hierarchical context prefix to the given name. Uses ``.`` as separator, appropriate for parameter and initializer names. diff --git a/onnxscript/nn/_parameter.py b/onnxscript/nn/_parameter.py index eac6fc8a07..8f45d99123 100644 --- a/onnxscript/nn/_parameter.py +++ b/onnxscript/nn/_parameter.py @@ -63,7 +63,7 @@ def _realize(self, builder: GraphBuilder) -> Parameter: "initialized with a name before realization." ) if self.name: - self.name = builder.qualify_name(self.name) + self.name = builder._qualify_initializer_name(self.name) builder.graph.initializers[self.name] = self # type: ignore[index] self._realized = True return self From 67ffb426a6a18776c982b37304828c75e5a0443c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 17:45:54 -0800 Subject: [PATCH 14/25] refactor: use '.' delimiter for value names, '/' for node names Value names use '.' separator with 'v_' prefix (e.g. v_layer1.attention.Add_0). Node names and namespace strings use '/' separator (e.g. layer1/attention/Add_node_0). Signed-off-by: Justin Chu --- onnxscript/_internal/builder.py | 7 +++---- onnxscript/_internal/builder_test.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 5fe8dd6306..4f5ffba541 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -374,15 +374,14 @@ def _qualify_initializer_name(self, name: str) -> str: return name def _qualify_value_name(self, name: str) -> str: - """Qualify a value name with the current scope using ``/`` separator. + """Qualify a value name with the current scope using ``.`` separator. The name is prefixed with ``v_`` to distinguish values from parameters. """ parts = self.scope_names() - qualified = f"v_{name}" if parts: - return "/".join(parts) + "/" + qualified - return qualified + return "v_" + ".".join(parts) + "." + name + return f"v_{name}" def _qualify_node_name(self, name: str) -> str: """Qualify a node name with the current scope using ``/`` separator.""" diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 5f53d7b849..5fea6a11ed 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -148,17 +148,17 @@ def test_value_naming_with_hierarchy(self): # Test custom names with hierarchical context op.builder.push_module("layer1") t2 = op.Mul(t1, y, _outputs=["my_mul"]) - self.assertEqual(t2.name, "layer1/v_my_mul") + self.assertEqual(t2.name, "v_layer1.my_mul") # Test nested hierarchical context with custom names op.builder.push_module("attention") t3 = op.Add(t2, x, _outputs=["my_nested_add"]) - self.assertEqual(t3.name, "layer1/attention/v_my_nested_add") + self.assertEqual(t3.name, "v_layer1.attention.my_nested_add") # Pop back and verify prefix is applied correctly op.builder.pop_module() t4 = op.Mul(t3, y, _outputs=["another_mul"]) - self.assertEqual(t4.name, "layer1/v_another_mul") + self.assertEqual(t4.name, "v_layer1.another_mul") op.builder.pop_module() t5 = op.Add(t4, x, _outputs=["final_result"]) @@ -181,13 +181,13 @@ def test_value_naming_with_ir_value_objects(self): # Test with hierarchical context op.builder.push_module("layer1") t2 = op.Mul(t1, y, _outputs=[out2]) - self.assertEqual(t2.name, "layer1/v_layer_output") + self.assertEqual(t2.name, "v_layer1.layer_output") self.assertIs(t2, out2) # Test nested hierarchical context op.builder.push_module("attention") t3 = op.Add(t2, x, _outputs=[out3]) - self.assertEqual(t3.name, "layer1/attention/v_nested_output") + self.assertEqual(t3.name, "v_layer1.attention.nested_output") self.assertIs(t3, out3) def test_default_output_naming_strategy(self): @@ -239,19 +239,19 @@ def test_hierarchical_naming(self): # Test node and value naming with hierarchical context prefix op.builder.push_module("layer1") t3 = op.Add(t2, x) - self.assertEqual(t3.name, "layer1/v_Add_2") + self.assertEqual(t3.name, "v_layer1.Add_2") self.assertEqual(t3.producer().name, "layer1/Add_node_2") # Test nested hierarchical context op.builder.push_module("attention") t4 = op.Mul(t3, y) - self.assertEqual(t4.name, "layer1/attention/v_Mul_3") + self.assertEqual(t4.name, "v_layer1.attention.Mul_3") self.assertEqual(t4.producer().name, "layer1/attention/Mul_node_3") # Pop back to layer1 and verify naming continues correctly op.builder.pop_module() t5 = op.Add(t4, x) - self.assertEqual(t5.name, "layer1/v_Add_4") + self.assertEqual(t5.name, "v_layer1.Add_4") self.assertEqual(t5.producer().name, "layer1/Add_node_4") # Pop back to root context From feb1ff7e1d3dedf9d00d3872e138ce432417b7e4 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 17:54:26 -0800 Subject: [PATCH 15/25] fix: namespace uses 'name: class' pairs, align scope list lengths - namespace: each scope entry is 'name: class_name' joined by '/' e.g. 'layer1: DecoderLayer/self_attn: Attention/Add' - scope_names() and scope_classes() return all entries (no filtering) so class_hierarchy and name_scopes always have matching lengths - _scope_name_parts() filters empty names for initializer/value/node qualifying Signed-off-by: Justin Chu --- onnxscript/_internal/builder.py | 26 ++++++++++++++++---------- onnxscript/_internal/builder_test.py | 9 +++++++-- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 4f5ffba541..d380907338 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -357,18 +357,22 @@ def pop_module(self) -> None: def scope_names(self) -> list[str]: """Return the list of module attribute names in the current scope.""" - return [name for name, _ in self._scope_stack if name] + return [name for name, _ in self._scope_stack] def scope_classes(self) -> list[str]: """Return the list of class names in the current scope.""" - return [cls for _, cls in self._scope_stack if cls] + return [cls for _, cls in self._scope_stack] + + def _scope_name_parts(self) -> list[str]: + """Return non-empty module names for qualifying names.""" + return [name for name, _ in self._scope_stack if name] def _qualify_initializer_name(self, name: str) -> str: """Prepend the current hierarchical context prefix to the given name. Uses ``.`` as separator, appropriate for parameter and initializer names. """ - parts = self.scope_names() + parts = self._scope_name_parts() if parts: return ".".join(parts) + "." + name return name @@ -378,14 +382,14 @@ def _qualify_value_name(self, name: str) -> str: The name is prefixed with ``v_`` to distinguish values from parameters. """ - parts = self.scope_names() + parts = self._scope_name_parts() if parts: return "v_" + ".".join(parts) + "." + name return f"v_{name}" def _qualify_node_name(self, name: str) -> str: """Qualify a node name with the current scope using ``/`` separator.""" - parts = self.scope_names() + parts = self._scope_name_parts() if parts: return "/".join(parts) + "/" + name return name @@ -393,13 +397,15 @@ def _qualify_node_name(self, name: str) -> str: def _build_namespace(self, op_type: str, domain: str = "") -> str: """Build the namespace string for a node. - Format: ``scope1/scope2: domain.op_type`` or ``scope1/scope2: op_type``. + Each scope entry is formatted as ``name: class_name`` joined by ``/``. """ - scope = "/".join(self.scope_names()) + parts = [] + for name, cls in self._scope_stack: + if name or cls: + parts.append(f"{name}: {cls}" if cls else name) op_id = f"{domain}.{op_type}" if domain else op_type - if scope: - return f"{scope}: {op_id}" - return op_id + parts.append(op_id) + return "/".join(parts) class OpBuilder: diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 5fea6a11ed..aef809f768 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -560,13 +560,13 @@ def test_node_metadata_props_namespace(self): # Node inside a module scope op.builder.push_module("layer1", "DecoderLayer") t2 = op.Mul(t1, y) - self.assertEqual(t2.producer().metadata_props["namespace"], "layer1: Mul") + self.assertEqual(t2.producer().metadata_props["namespace"], "layer1: DecoderLayer/Mul") # Nested scope op.builder.push_module("self_attn", "Attention") t3 = op.Add(t2, x) self.assertEqual( - t3.producer().metadata_props["namespace"], "layer1/self_attn: Add" + t3.producer().metadata_props["namespace"], "layer1: DecoderLayer/self_attn: Attention/Add" ) op.builder.pop_module() op.builder.pop_module() @@ -588,6 +588,11 @@ def test_node_metadata_props_class_hierarchy(self): node.metadata_props["pkg.onnxscript.name_scopes"], repr(["layer1", "self_attn"]), ) + # class_hierarchy includes one entry per scope plus the op + self.assertEqual( + len(eval(node.metadata_props["pkg.onnxscript.class_hierarchy"])), + len(eval(node.metadata_props["pkg.onnxscript.name_scopes"])) + 1, + ) op.builder.pop_module() op.builder.pop_module() From 1a1a7a62ef8fa95bf97af1e9d211e24102758352 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 17:58:03 -0800 Subject: [PATCH 16/25] Update names Signed-off-by: Justin Chu --- onnxscript/_internal/builder.py | 14 +++++--------- onnxscript/_internal/builder_test.py | 3 ++- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index d380907338..91e4b92a9a 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -327,14 +327,10 @@ def call_op( # Attach scope metadata to the node node.metadata_props["namespace"] = self._build_namespace(op_type, domain) - class_hierarchy = self.scope_classes() + class_hierarchy = self._scope_classes() op_id = f"{domain}.{op_type}" if domain else op_type - node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr( - class_hierarchy + [op_id] - ) - node.metadata_props["pkg.onnxscript.name_scopes"] = repr( - self.scope_names() - ) + node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr([*class_hierarchy, op_id]) + node.metadata_props["pkg.onnxscript.name_scopes"] = repr(self._scope_names()) self.add_node(node) @@ -355,11 +351,11 @@ def pop_module(self) -> None: raise RuntimeError("Cannot pop_module: no module context has been pushed.") self._scope_stack.pop() - def scope_names(self) -> list[str]: + def _scope_names(self) -> list[str]: """Return the list of module attribute names in the current scope.""" return [name for name, _ in self._scope_stack] - def scope_classes(self) -> list[str]: + def _scope_classes(self) -> list[str]: """Return the list of class names in the current scope.""" return [cls for _, cls in self._scope_stack] diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index aef809f768..0b1e027bf3 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -566,7 +566,8 @@ def test_node_metadata_props_namespace(self): op.builder.push_module("self_attn", "Attention") t3 = op.Add(t2, x) self.assertEqual( - t3.producer().metadata_props["namespace"], "layer1: DecoderLayer/self_attn: Attention/Add" + t3.producer().metadata_props["namespace"], + "layer1: DecoderLayer/self_attn: Attention/Add", ) op.builder.pop_module() op.builder.pop_module() From 9cd8f20965d1f9345ef9cec486007d2647ae684d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 18:02:20 -0800 Subject: [PATCH 17/25] fix: append op_type to name_scopes so lengths match class_hierarchy Signed-off-by: Justin Chu --- onnxscript/_internal/builder.py | 9 ++++++--- onnxscript/_internal/builder_test.py | 6 +++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 91e4b92a9a..f546469654 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -327,10 +327,13 @@ def call_op( # Attach scope metadata to the node node.metadata_props["namespace"] = self._build_namespace(op_type, domain) - class_hierarchy = self._scope_classes() op_id = f"{domain}.{op_type}" if domain else op_type - node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr([*class_hierarchy, op_id]) - node.metadata_props["pkg.onnxscript.name_scopes"] = repr(self._scope_names()) + node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr( + [*self._scope_classes(), op_id] + ) + node.metadata_props["pkg.onnxscript.name_scopes"] = repr( + [*self._scope_names(), op_type] + ) self.add_node(node) diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 0b1e027bf3..6f61c7364b 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -587,12 +587,12 @@ def test_node_metadata_props_class_hierarchy(self): ) self.assertEqual( node.metadata_props["pkg.onnxscript.name_scopes"], - repr(["layer1", "self_attn"]), + repr(["layer1", "self_attn", "MatMul"]), ) - # class_hierarchy includes one entry per scope plus the op + # class_hierarchy and name_scopes have the same length self.assertEqual( len(eval(node.metadata_props["pkg.onnxscript.class_hierarchy"])), - len(eval(node.metadata_props["pkg.onnxscript.name_scopes"])) + 1, + len(eval(node.metadata_props["pkg.onnxscript.name_scopes"])), ) op.builder.pop_module() op.builder.pop_module() From 77e0d6870f65bdb5fcde0831623030ae60600817 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 18:04:22 -0800 Subject: [PATCH 18/25] fix: remove op_type from namespace, class_hierarchy, and name_scopes namespace now only contains module scopes (e.g. 'layer1: DecoderLayer/self_attn: Attention'). class_hierarchy and name_scopes reflect only the module stack, not the op. Signed-off-by: Justin Chu --- onnxscript/_internal/builder.py | 11 ++++------- onnxscript/_internal/builder_test.py | 10 +++++----- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index f546469654..3026ba1b12 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -326,13 +326,12 @@ def call_op( ) # Attach scope metadata to the node - node.metadata_props["namespace"] = self._build_namespace(op_type, domain) - op_id = f"{domain}.{op_type}" if domain else op_type + node.metadata_props["namespace"] = self._build_namespace() node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr( - [*self._scope_classes(), op_id] + self._scope_classes() ) node.metadata_props["pkg.onnxscript.name_scopes"] = repr( - [*self._scope_names(), op_type] + self._scope_names() ) self.add_node(node) @@ -393,7 +392,7 @@ def _qualify_node_name(self, name: str) -> str: return "/".join(parts) + "/" + name return name - def _build_namespace(self, op_type: str, domain: str = "") -> str: + def _build_namespace(self) -> str: """Build the namespace string for a node. Each scope entry is formatted as ``name: class_name`` joined by ``/``. @@ -402,8 +401,6 @@ def _build_namespace(self, op_type: str, domain: str = "") -> str: for name, cls in self._scope_stack: if name or cls: parts.append(f"{name}: {cls}" if cls else name) - op_id = f"{domain}.{op_type}" if domain else op_type - parts.append(op_id) return "/".join(parts) diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 6f61c7364b..e95bb24025 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -555,19 +555,19 @@ def test_node_metadata_props_namespace(self): # Root-level node t1 = op.Add(x, y) - self.assertEqual(t1.producer().metadata_props["namespace"], "Add") + self.assertEqual(t1.producer().metadata_props["namespace"], "") # Node inside a module scope op.builder.push_module("layer1", "DecoderLayer") t2 = op.Mul(t1, y) - self.assertEqual(t2.producer().metadata_props["namespace"], "layer1: DecoderLayer/Mul") + self.assertEqual(t2.producer().metadata_props["namespace"], "layer1: DecoderLayer") # Nested scope op.builder.push_module("self_attn", "Attention") t3 = op.Add(t2, x) self.assertEqual( t3.producer().metadata_props["namespace"], - "layer1: DecoderLayer/self_attn: Attention/Add", + "layer1: DecoderLayer/self_attn: Attention", ) op.builder.pop_module() op.builder.pop_module() @@ -583,11 +583,11 @@ def test_node_metadata_props_class_hierarchy(self): self.assertEqual( node.metadata_props["pkg.onnxscript.class_hierarchy"], - repr(["DecoderLayer", "Attention", "MatMul"]), + repr(["DecoderLayer", "Attention"]), ) self.assertEqual( node.metadata_props["pkg.onnxscript.name_scopes"], - repr(["layer1", "self_attn", "MatMul"]), + repr(["layer1", "self_attn"]), ) # class_hierarchy and name_scopes have the same length self.assertEqual( From 961320026367e244194ab1357985c2029d909ce0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 18:04:35 -0800 Subject: [PATCH 19/25] Format Signed-off-by: Justin Chu --- onnxscript/_internal/builder.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 3026ba1b12..4589c99e59 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -327,12 +327,8 @@ def call_op( # Attach scope metadata to the node node.metadata_props["namespace"] = self._build_namespace() - node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr( - self._scope_classes() - ) - node.metadata_props["pkg.onnxscript.name_scopes"] = repr( - self._scope_names() - ) + node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr(self._scope_classes()) + node.metadata_props["pkg.onnxscript.name_scopes"] = repr(self._scope_names()) self.add_node(node) From 5667be7b69cb03b51fa4754ea5a525bb7517c5c3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 18:08:03 -0800 Subject: [PATCH 20/25] refactor: use ast.literal_eval instead of eval in tests Signed-off-by: Justin Chu --- onnxscript/_internal/builder_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index e95bb24025..209ccb6165 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -3,6 +3,7 @@ from __future__ import annotations +import ast import unittest from typing import Sequence @@ -591,8 +592,8 @@ def test_node_metadata_props_class_hierarchy(self): ) # class_hierarchy and name_scopes have the same length self.assertEqual( - len(eval(node.metadata_props["pkg.onnxscript.class_hierarchy"])), - len(eval(node.metadata_props["pkg.onnxscript.name_scopes"])), + len(ast.literal_eval(node.metadata_props["pkg.onnxscript.class_hierarchy"])), + len(ast.literal_eval(node.metadata_props["pkg.onnxscript.name_scopes"])), ) op.builder.pop_module() op.builder.pop_module() From 6bb9931eef991d5e2b31da915ae927b2770340e9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 20 Feb 2026 18:12:43 -0800 Subject: [PATCH 21/25] Update name check Signed-off-by: Justin Chu --- onnxscript/nn/_parameter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxscript/nn/_parameter.py b/onnxscript/nn/_parameter.py index 8f45d99123..fbf2f63cfc 100644 --- a/onnxscript/nn/_parameter.py +++ b/onnxscript/nn/_parameter.py @@ -56,15 +56,15 @@ def _realize(self, builder: GraphBuilder) -> Parameter: if self._realized: return self - if self.name is None: + self_name = self.name + if not self_name: raise ValueError( "Parameter._realize() called on a Parameter without a name. " "Ensure the Parameter is attached to a Module attribute or otherwise " "initialized with a name before realization." ) - if self.name: - self.name = builder._qualify_initializer_name(self.name) - builder.graph.initializers[self.name] = self # type: ignore[index] + self_name = self.name = builder._qualify_initializer_name(self_name) + builder.graph.initializers[self_name] = self self._realized = True return self From 7741555793924c7b2896eda81874c97540362b6f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 21 Feb 2026 07:56:59 -0800 Subject: [PATCH 22/25] Implement ModuleList Signed-off-by: Justin Chu --- onnxscript/nn/__init__.py | 3 +- onnxscript/nn/_module.py | 6 +- onnxscript/nn/_module_list.py | 96 +++++++++++++++++++ onnxscript/nn/_module_test.py | 167 +++++++++++++++++++++++++++++++++- 4 files changed, 269 insertions(+), 3 deletions(-) create mode 100644 onnxscript/nn/_module_list.py diff --git a/onnxscript/nn/__init__.py b/onnxscript/nn/__init__.py index 181409812a..248c9ef40e 100644 --- a/onnxscript/nn/__init__.py +++ b/onnxscript/nn/__init__.py @@ -4,6 +4,7 @@ """PyTorch-like module interface for building ONNX graphs.""" from onnxscript.nn._module import Module +from onnxscript.nn._module_list import ModuleList from onnxscript.nn._parameter import Parameter -__all__ = ["Module", "Parameter"] +__all__ = ["Module", "ModuleList", "Parameter"] diff --git a/onnxscript/nn/_module.py b/onnxscript/nn/_module.py index 0f6abfdb8d..084703703e 100644 --- a/onnxscript/nn/_module.py +++ b/onnxscript/nn/_module.py @@ -49,6 +49,10 @@ def __init__(self, name: str | None = None) -> None: def name(self) -> str | None: return self._name + def _set_name(self, name: str) -> None: + """Set the module name. Subclasses may override to propagate names to children.""" + object.__setattr__(self, "_name", name) + def __setattr__(self, name: str, value: Any) -> None: if isinstance(value, Parameter): # Auto-register parameters; set default name from attribute name. @@ -60,7 +64,7 @@ def __setattr__(self, name: str, value: Any) -> None: elif isinstance(value, Module): # Auto-register child modules; inherit attribute name if unnamed. if value._name is None: - object.__setattr__(value, "_name", name) + value._set_name(name) self._modules[name] = value object.__setattr__(self, name, value) else: diff --git a/onnxscript/nn/_module_list.py b/onnxscript/nn/_module_list.py new file mode 100644 index 0000000000..e71ab2a855 --- /dev/null +++ b/onnxscript/nn/_module_list.py @@ -0,0 +1,96 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import Any, Iterator, overload + +from onnxscript._internal.builder import OpBuilder +from onnxscript.nn._module import Module + + +class ModuleList(Module): + """Holds child modules in a list, mirroring ``torch.nn.ModuleList``. + + Children are registered with string keys ``"0"``, ``"1"``, etc., so that + hierarchical parameter names use ``.0.``, ``.1.`` separators just like + PyTorch. + + Example:: + + class MyModel(Module): + def __init__(self): + super().__init__("model") + self.layers = ModuleList([Linear(4, 4) for _ in range(3)]) + + def forward(self, op, x): + for layer in self.layers: + x = layer(op, x) + return x + + # Parameters will be named: + # model.layers.0.weight, model.layers.1.weight, model.layers.2.weight + """ + + def __init__(self, modules: list[Module] | None = None) -> None: + super().__init__() + if modules is not None: + for idx, module in enumerate(modules): + self._register_child(str(idx), module) + + def _set_name(self, name: str) -> None: + """Set this container's name and prefix all children's names.""" + object.__setattr__(self, "_name", name) + for key, child in self._modules.items(): + child._set_name(f"{name}.{key}") # pylint: disable=protected-access + + def _register_child(self, key: str, module: Module) -> None: + """Register a child module under the given string key.""" + if module._name is None: # pylint: disable=protected-access + object.__setattr__(module, "_name", key) + self._modules[key] = module + object.__setattr__(self, key, module) + + def append(self, module: Module) -> ModuleList: + """Append a module to the end of the list.""" + key = str(len(self._modules)) + self._register_child(key, module) + return self + + def extend(self, modules: list[Module]) -> ModuleList: + """Append modules from an iterable to the end of the list.""" + for module in modules: + self.append(module) + return self + + @overload + def __getitem__(self, idx: int) -> Module: ... + + @overload + def __getitem__(self, idx: slice) -> ModuleList: ... + + def __getitem__(self, idx: int | slice) -> Module | ModuleList: + if isinstance(idx, slice): + keys = list(self._modules.keys())[idx] + new_list = ModuleList() + for i, key in enumerate(keys): + new_list._register_child(str(i), self._modules[key]) + return new_list + if idx < 0: + idx += len(self._modules) + key = str(idx) + if key not in self._modules: + raise IndexError(f"index {idx} is out of range") + return self._modules[key] + + def __len__(self) -> int: + return len(self._modules) + + def __iter__(self) -> Iterator[Module]: + return iter(self._modules.values()) + + def forward(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError( + "ModuleList is not callable directly. " + "Iterate over its children and call them individually." + ) diff --git a/onnxscript/nn/_module_test.py b/onnxscript/nn/_module_test.py index 172ba16a13..8c68a13e9a 100644 --- a/onnxscript/nn/_module_test.py +++ b/onnxscript/nn/_module_test.py @@ -9,7 +9,7 @@ import onnx_ir as ir from onnxscript._internal.builder import GraphBuilder, OpBuilder -from onnxscript.nn import Module, Parameter +from onnxscript.nn import Module, ModuleList, Parameter def _create_graph_and_op() -> tuple[ir.Graph, OpBuilder]: @@ -663,5 +663,170 @@ def forward(self, op): self.assertIsNone(m.b.const_value) +class ModuleListTest(unittest.TestCase): + def test_basic_indexing(self): + """ModuleList supports integer indexing.""" + + class Layer(Module): + def __init__(self): + super().__init__() + self.w = Parameter([3], name="w") + + def forward(self, op): + pass + + ml = ModuleList([Layer(), Layer(), Layer()]) + self.assertEqual(len(ml), 3) + self.assertIsInstance(ml[0], Layer) + self.assertIsInstance(ml[2], Layer) + self.assertIsInstance(ml[-1], Layer) + self.assertIs(ml[-1], ml[2]) + + def test_index_out_of_range(self): + ml = ModuleList() + with self.assertRaises(IndexError): + _ = ml[0] + + def test_append(self): + class Layer(Module): + def __init__(self): + super().__init__() + + def forward(self, op): + pass + + ml = ModuleList() + ml.append(Layer()) + ml.append(Layer()) + self.assertEqual(len(ml), 2) + + def test_extend(self): + class Layer(Module): + def __init__(self): + super().__init__() + + def forward(self, op): + pass + + ml = ModuleList() + ml.extend([Layer(), Layer()]) + self.assertEqual(len(ml), 2) + + def test_iteration(self): + class Layer(Module): + def __init__(self): + super().__init__() + + def forward(self, op): + pass + + layers = [Layer(), Layer(), Layer()] + ml = ModuleList(layers) + for i, layer in enumerate(ml): + self.assertIs(layer, layers[i]) + + def test_slice(self): + class Layer(Module): + def __init__(self): + super().__init__() + + def forward(self, op): + pass + + ml = ModuleList([Layer(), Layer(), Layer()]) + sub = ml[1:] + self.assertIsInstance(sub, ModuleList) + self.assertEqual(len(sub), 2) + + def test_children_auto_named(self): + """Children get '0', '1', ... names automatically.""" + + class Layer(Module): + def __init__(self): + super().__init__() + + def forward(self, op): + pass + + ml = ModuleList([Layer(), Layer()]) + self.assertEqual(ml[0]._name, "0") # pylint: disable=protected-access + self.assertEqual(ml[1]._name, "1") # pylint: disable=protected-access + + def test_named_parameters_with_numeric_keys(self): + """Parameters within ModuleList children use numeric-indexed names.""" + + class Layer(Module): + def __init__(self): + super().__init__() + self.w = Parameter([3], name="w") + + def forward(self, op): + pass + + ml = ModuleList([Layer(), Layer()]) + named = dict(ml.named_parameters()) + self.assertIn("0.w", named) + self.assertIn("1.w", named) + + def test_same_class_submodules_get_distinct_names_in_graph(self): + """Multiple sub-modules of the same class in a ModuleList get distinct + .0., .1. prefixed initializer and node names in the graph. + """ + + class Linear(Module): + def __init__(self, in_f, out_f): + super().__init__() + self.weight = Parameter([out_f, in_f], name="weight") + + def forward(self, op, x): + w_t = op.Transpose(self.weight, perm=[1, 0]) + return op.MatMul(x, w_t) + + class Model(Module): + def __init__(self): + super().__init__("model") + self.layers = ModuleList([Linear(4, 4), Linear(4, 4), Linear(4, 4)]) + + def forward(self, op, x): + for layer in self.layers: + x = layer(op, x) + return x + + graph, op = _create_graph_and_op() + x = ir.Value( + name="input", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape([1, 4]), + ) + graph.inputs.append(x) + + m = Model() + result = m(op, x) + + self.assertIsInstance(result, ir.Value) + # Each layer's weight must have a distinct hierarchical name + self.assertIn("model.layers.0.weight", graph.initializers) + self.assertIn("model.layers.1.weight", graph.initializers) + self.assertIn("model.layers.2.weight", graph.initializers) + # All three must be different ir.Value objects + self.assertIsNot( + graph.initializers["model.layers.0.weight"], + graph.initializers["model.layers.1.weight"], + ) + self.assertIsNot( + graph.initializers["model.layers.1.weight"], + graph.initializers["model.layers.2.weight"], + ) + # Check that we got 6 nodes: (Transpose + MatMul) * 3 layers + op_types = [node.op_type for node in graph] + self.assertEqual(op_types, ["Transpose", "MatMul"] * 3) + + def test_modulelist_not_directly_callable(self): + ml = ModuleList() + _, op = _create_graph_and_op() + with self.assertRaises(NotImplementedError): + ml(op) + + if __name__ == "__main__": unittest.main() From 248742ea41a1d0079d9bf431e379ad105b5a15d3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 21 Feb 2026 16:45:36 -0800 Subject: [PATCH 23/25] Update Parameter to take dtype from data Signed-off-by: Justin Chu --- onnxscript/nn/_parameter.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/onnxscript/nn/_parameter.py b/onnxscript/nn/_parameter.py index fbf2f63cfc..f295bf3708 100644 --- a/onnxscript/nn/_parameter.py +++ b/onnxscript/nn/_parameter.py @@ -20,7 +20,7 @@ class Parameter(ir.Value): Args: shape: Shape of the parameter tensor. - dtype: Data type of the parameter. Defaults to FLOAT. + dtype: Data type of the parameter. If None and data is not provided, defaults to float32. name: Name for the parameter. If None, the attribute name from the parent Module is used. data: Optional initial tensor data. If provided, the initializer @@ -29,11 +29,21 @@ class Parameter(ir.Value): def __init__( self, - shape: Sequence[int | ir.SymbolicDim | None], - dtype: ir.DataType = ir.DataType.FLOAT, + shape: Sequence[int], + dtype: ir.DataType | None = None, name: str | None = None, data: ir.TensorProtocol | None = None, ) -> None: + if data is not None: + if data.dtype != dtype: + raise ValueError( + f"Data type of provided data ({data.dtype}) does not match the specified dtype ({dtype})." + ) + if dtype is None: + dtype = data.dtype + elif dtype is None: + dtype = ir.DataType.FLOAT + super().__init__( name=name, shape=ir.Shape(shape), From d5e97102ffe6984f59de6ba7edf00f50e5202ced Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 21 Feb 2026 16:46:51 -0800 Subject: [PATCH 24/25] pylint Signed-off-by: Justin Chu --- onnxscript/nn/_parameter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/nn/_parameter.py b/onnxscript/nn/_parameter.py index f295bf3708..8744e20fb1 100644 --- a/onnxscript/nn/_parameter.py +++ b/onnxscript/nn/_parameter.py @@ -73,7 +73,7 @@ def _realize(self, builder: GraphBuilder) -> Parameter: "Ensure the Parameter is attached to a Module attribute or otherwise " "initialized with a name before realization." ) - self_name = self.name = builder._qualify_initializer_name(self_name) + self_name = self.name = builder._qualify_initializer_name(self_name) # pylint: disable=protected-access builder.graph.initializers[self_name] = self self._realized = True return self From eae41ac32699b256398f875b72fdb9f975d54bea Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 21 Feb 2026 16:48:37 -0800 Subject: [PATCH 25/25] fix check error Signed-off-by: Justin Chu --- onnxscript/nn/_parameter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/nn/_parameter.py b/onnxscript/nn/_parameter.py index 8744e20fb1..d319868e89 100644 --- a/onnxscript/nn/_parameter.py +++ b/onnxscript/nn/_parameter.py @@ -35,7 +35,7 @@ def __init__( data: ir.TensorProtocol | None = None, ) -> None: if data is not None: - if data.dtype != dtype: + if dtype is not None and data.dtype != dtype: raise ValueError( f"Data type of provided data ({data.dtype}) does not match the specified dtype ({dtype})." )