diff --git a/onnxscript/__init__.py b/onnxscript/__init__.py
index 34ea9a0778..9783e37c24 100644
--- a/onnxscript/__init__.py
+++ b/onnxscript/__init__.py
@@ -5,6 +5,7 @@
     "script",
     "graph",
     "ir",
+    "nn",
     "optimizer",
     "rewriter",
     "version_converter",
@@ -127,7 +128,7 @@
 # isort: on
 
-from . import ir, optimizer, rewriter, version_converter
+from . import ir, nn, optimizer, rewriter, version_converter
 from ._internal.builder import GraphBuilder
 from ._internal.utils import external_tensor
 from ._internal.values import OnnxFunction, TracedOnnxFunction
diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py
index 0b626bd100..4589c99e59 100644
--- a/onnxscript/_internal/builder.py
+++ b/onnxscript/_internal/builder.py
@@ -80,10 +80,10 @@ def __init__(self, graph: ir.Graph) -> None:
 
         self._op_builder = self.opset("", opset_version)
 
-        # Context stack to manage hierarchical naming. Each module/layer can push a new context, and pop it when done.
-        # The current context is used as a prefix for naming values and nodes.
-        # This allows us to generate names like "layer1.attention.query"
-        self._context_stack: list[str] = [""]
+        # Module scope stack. Each entry is (name, class_name) where name is
+        # the module attribute name (e.g. "layers.0", "self_attn") and
+        # class_name is the qualified class name (e.g. "Gemma3DecoderLayer").
+        self._scope_stack: list[tuple[str, str]] = []
 
         # Cache for constant initializers (scalars and sequences), keyed by (value, dtype).
         # This avoids creating duplicate initializers for the same constant
@@ -109,7 +109,7 @@ def initializer(
         if name is None:
             name = tensor.name
         if qualify:
-            name = self.qualify_name(name)
+            name = self._qualify_initializer_name(name)
         shape = ir.Shape(tensor.shape)
         value = ir.Value(
             name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor
@@ -180,24 +180,26 @@ def _adapt_outputs(
         self, outputs: int | Sequence[str | ir.Value], op_type: str = ""
     ) -> Sequence[ir.Value]:
         if isinstance(outputs, int):
+            count = self.graph.num_nodes()
             if outputs < 0:
                 raise ValueError(f"Number of outputs must be non-negative, got {outputs}")
             if outputs == 1:
-                name = f"{op_type}_output" if op_type else "output"
-                return [ir.Value(name=self.qualify_name(name))]
+                name = f"{op_type}_{count}" if op_type else f"{count}"
+                return [ir.Value(name=self._qualify_value_name(name))]
             else:
                 names = [
-                    f"{op_type}_output{i}" if op_type else f"output{i}" for i in range(outputs)
+                    (f"{op_type}_{count}_{i}" if op_type else f"{count}_{i}")
+                    for i in range(outputs)
                 ]
-                return [ir.Value(name=self.qualify_name(n)) for n in names]
+                return [ir.Value(name=self._qualify_value_name(n)) for n in names]
         adapted_outputs = []
         for output in outputs:
             if isinstance(output, ir.Value):
                 if output.name:
-                    output.name = self.qualify_name(output.name)
+                    output.name = self._qualify_value_name(output.name)
                 adapted_outputs.append(output)
             elif isinstance(output, str):
-                adapted_outputs.append(ir.Value(name=self.qualify_name(output)))
+                adapted_outputs.append(ir.Value(name=self._qualify_value_name(output)))
             else:
                 raise TypeError("Output type not supported.")
         return adapted_outputs
@@ -304,7 +306,7 @@ def call_op(
         outputs = kwargs.pop("_outputs", 1)
 
         count = self.graph.num_nodes()
-        node_name = self.qualify_name(f"{op_type}_node_{count}")
+        node_name = self._qualify_node_name(f"{op_type}_node_{count}")
 
         output_values = self._adapt_outputs(outputs, op_type)
 
@@ -322,33 +324,80 @@ def call_op(
             version=version,
             name=node_name,
         )
+
+        # Attach scope metadata to the node
node.metadata_props["namespace"] = self._build_namespace() + node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr(self._scope_classes()) + node.metadata_props["pkg.onnxscript.name_scopes"] = repr(self._scope_names()) + self.add_node(node) return node.outputs if len(node.outputs) > 1 else node.outputs[0] - def push_module(self, module: str) -> None: - """Push a new naming context onto the stack (e.g. a layer or module name).""" - current = self.context_name() - if module: - new_context = f"{current}.{module}" if current else module - else: - new_context = current - self._context_stack.append(new_context) + def push_module(self, module: str, class_name: str = "") -> None: + """Push a new module scope onto the stack. + + Args: + module: The attribute name of the module (e.g. ``"layers.0"``). + class_name: The qualified class name (e.g. ``"Gemma3DecoderLayer"``). + """ + self._scope_stack.append((module, class_name)) def pop_module(self) -> None: - """Pop the most recent naming context off the stack.""" - if len(self._context_stack) <= 1: + """Pop the most recent module scope off the stack.""" + if not self._scope_stack: raise RuntimeError("Cannot pop_module: no module context has been pushed.") - self._context_stack.pop() + self._scope_stack.pop() + + def _scope_names(self) -> list[str]: + """Return the list of module attribute names in the current scope.""" + return [name for name, _ in self._scope_stack] + + def _scope_classes(self) -> list[str]: + """Return the list of class names in the current scope.""" + return [cls for _, cls in self._scope_stack] - def context_name(self) -> str: - """Return the current dot-separated naming context prefix.""" - return self._context_stack[-1] if self._context_stack else "" + def _scope_name_parts(self) -> list[str]: + """Return non-empty module names for qualifying names.""" + return [name for name, _ in self._scope_stack if name] - def qualify_name(self, name: str) -> str: - """Prepend the current hierarchical context prefix to the given name.""" - prefix = self.context_name() - return f"{prefix}.{name}" if prefix else name + def _qualify_initializer_name(self, name: str) -> str: + """Prepend the current hierarchical context prefix to the given name. + + Uses ``.`` as separator, appropriate for parameter and initializer names. + """ + parts = self._scope_name_parts() + if parts: + return ".".join(parts) + "." + name + return name + + def _qualify_value_name(self, name: str) -> str: + """Qualify a value name with the current scope using ``.`` separator. + + The name is prefixed with ``v_`` to distinguish values from parameters. + """ + parts = self._scope_name_parts() + if parts: + return "v_" + ".".join(parts) + "." + name + return f"v_{name}" + + def _qualify_node_name(self, name: str) -> str: + """Qualify a node name with the current scope using ``/`` separator.""" + parts = self._scope_name_parts() + if parts: + return "/".join(parts) + "/" + name + return name + + def _build_namespace(self) -> str: + """Build the namespace string for a node. + + Each scope entry is formatted as ``name: class_name`` joined by ``/``. 
+ """ + parts = [] + for name, cls in self._scope_stack: + if name or cls: + parts.append(f"{name}: {cls}" if cls else name) + return "/".join(parts) class OpBuilder: diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 9edd0b68b4..209ccb6165 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -3,6 +3,7 @@ from __future__ import annotations +import ast import unittest from typing import Sequence @@ -128,14 +129,14 @@ def _add_with_custom_names( nodes = list(graph) self.assertEqual(len(nodes), 3) - # Check output names - self.assertEqual(nodes[0].outputs[0].name, "add_result") - self.assertEqual(nodes[1].outputs[0].name, "mul_result") - self.assertEqual(nodes[2].outputs[0].name, "final_add") + # Check output names (v_ prefix is added to all value names) + self.assertEqual(nodes[0].outputs[0].name, "v_add_result") + self.assertEqual(nodes[1].outputs[0].name, "v_mul_result") + self.assertEqual(nodes[2].outputs[0].name, "v_final_add") # Verify the final output has the correct name self.assertEqual(len(graph.outputs), 1) - self.assertEqual(graph.outputs[0].name, "final_add") + self.assertEqual(graph.outputs[0].name, "v_final_add") def test_value_naming_with_hierarchy(self): """Test that hierarchical naming works with user-specified output names.""" @@ -143,26 +144,26 @@ def test_value_naming_with_hierarchy(self): # Test custom names at root level t1 = op.Add(x, y, _outputs=["my_add"]) - self.assertEqual(t1.name, "my_add") + self.assertEqual(t1.name, "v_my_add") # Test custom names with hierarchical context op.builder.push_module("layer1") t2 = op.Mul(t1, y, _outputs=["my_mul"]) - self.assertEqual(t2.name, "layer1.my_mul") + self.assertEqual(t2.name, "v_layer1.my_mul") # Test nested hierarchical context with custom names op.builder.push_module("attention") t3 = op.Add(t2, x, _outputs=["my_nested_add"]) - self.assertEqual(t3.name, "layer1.attention.my_nested_add") + self.assertEqual(t3.name, "v_layer1.attention.my_nested_add") # Pop back and verify prefix is applied correctly op.builder.pop_module() t4 = op.Mul(t3, y, _outputs=["another_mul"]) - self.assertEqual(t4.name, "layer1.another_mul") + self.assertEqual(t4.name, "v_layer1.another_mul") op.builder.pop_module() t5 = op.Add(t4, x, _outputs=["final_result"]) - self.assertEqual(t5.name, "final_result") + self.assertEqual(t5.name, "v_final_result") def test_value_naming_with_ir_value_objects(self): """Test that hierarchical naming works when passing ir.Value objects as _outputs.""" @@ -175,19 +176,19 @@ def test_value_naming_with_ir_value_objects(self): # Test at root level t1 = op.Add(x, y, _outputs=[out1]) - self.assertEqual(t1.name, "my_output") + self.assertEqual(t1.name, "v_my_output") self.assertIs(t1, out1) # Test with hierarchical context op.builder.push_module("layer1") t2 = op.Mul(t1, y, _outputs=[out2]) - self.assertEqual(t2.name, "layer1.layer_output") + self.assertEqual(t2.name, "v_layer1.layer_output") self.assertIs(t2, out2) # Test nested hierarchical context op.builder.push_module("attention") t3 = op.Add(t2, x, _outputs=[out3]) - self.assertEqual(t3.name, "layer1.attention.nested_output") + self.assertEqual(t3.name, "v_layer1.attention.nested_output") self.assertIs(t3, out3) def test_default_output_naming_strategy(self): @@ -214,14 +215,14 @@ def _ops_with_default_names( nodes = list(graph) self.assertEqual(len(nodes), 3) - # Check output names follow the {op_type}_output pattern for single outputs - self.assertEqual(nodes[0].outputs[0].name, 
"Add_output") - self.assertEqual(nodes[1].outputs[0].name, "Mul_output") - self.assertEqual(nodes[2].outputs[0].name, "Add_output") + # Check output names follow the v_{op_type}_{count} pattern for single outputs + self.assertEqual(nodes[0].outputs[0].name, "v_Add_0") + self.assertEqual(nodes[1].outputs[0].name, "v_Mul_1") + self.assertEqual(nodes[2].outputs[0].name, "v_Add_2") # Verify the final output has the correct name self.assertEqual(len(graph.outputs), 1) - self.assertEqual(graph.outputs[0].name, "Add_output") + self.assertEqual(graph.outputs[0].name, "v_Add_2") def test_hierarchical_naming(self): """Test the hierarchical naming strategy (for value and node names).""" @@ -229,35 +230,35 @@ def test_hierarchical_naming(self): # Test node and value naming at root level t1 = op.Add(x, y) - self.assertEqual(t1.name, "Add_output") + self.assertEqual(t1.name, "v_Add_0") self.assertEqual(t1.producer().name, "Add_node_0") t2 = op.Mul(t1, y) - self.assertEqual(t2.name, "Mul_output") + self.assertEqual(t2.name, "v_Mul_1") self.assertEqual(t2.producer().name, "Mul_node_1") # Test node and value naming with hierarchical context prefix op.builder.push_module("layer1") t3 = op.Add(t2, x) - self.assertEqual(t3.name, "layer1.Add_output") - self.assertEqual(t3.producer().name, "layer1.Add_node_2") + self.assertEqual(t3.name, "v_layer1.Add_2") + self.assertEqual(t3.producer().name, "layer1/Add_node_2") # Test nested hierarchical context op.builder.push_module("attention") t4 = op.Mul(t3, y) - self.assertEqual(t4.name, "layer1.attention.Mul_output") - self.assertEqual(t4.producer().name, "layer1.attention.Mul_node_3") + self.assertEqual(t4.name, "v_layer1.attention.Mul_3") + self.assertEqual(t4.producer().name, "layer1/attention/Mul_node_3") # Pop back to layer1 and verify naming continues correctly op.builder.pop_module() t5 = op.Add(t4, x) - self.assertEqual(t5.name, "layer1.Add_output") - self.assertEqual(t5.producer().name, "layer1.Add_node_4") + self.assertEqual(t5.name, "v_layer1.Add_4") + self.assertEqual(t5.producer().name, "layer1/Add_node_4") # Pop back to root context op.builder.pop_module() t6 = op.Mul(t5, y) - self.assertEqual(t6.name, "Mul_output") + self.assertEqual(t6.name, "v_Mul_5") self.assertEqual(t6.producer().name, "Mul_node_5") def test_shape_inference_add(self): @@ -311,7 +312,7 @@ def test_custom_domain_with_version(self): # Verify output value is created self.assertIsNotNone(result) - self.assertEqual(result.name, "MicrosoftOp_output") + self.assertEqual(result.name, "v_MicrosoftOp_0") def test_multiple_custom_domain_operations(self): """Test mixing operations from multiple domains.""" @@ -519,6 +520,84 @@ def test_pop_module_raises_on_empty_stack(self): with self.assertRaises(RuntimeError): op.builder.pop_module() + def test_output_names_are_unique_for_same_op_type(self): + """Test that repeated calls to the same op produce unique output names.""" + op, x, y = _create_builder_with_inputs() + + t1 = op.Add(x, y) + t2 = op.Add(x, y) + t3 = op.Add(x, y) + + # Each Add output should have a unique name via the node count suffix + self.assertEqual(t1.name, "v_Add_0") + self.assertEqual(t2.name, "v_Add_1") + self.assertEqual(t3.name, "v_Add_2") + + # Verify all names are distinct + names = [t1.name, t2.name, t3.name] + self.assertEqual(len(set(names)), 3) + + def test_multi_output_names_are_unique(self): + """Test that multi-output ops produce unique names with counter suffix.""" + op, x, y = _create_builder_with_inputs() + + # First multi-output call + out1_a, out1_b = op.TopK(x, 
diff --git a/onnxscript/nn/__init__.py b/onnxscript/nn/__init__.py
new file mode 100644
index 0000000000..248c9ef40e
--- /dev/null
+++ b/onnxscript/nn/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""PyTorch-like module interface for building ONNX graphs."""
+
+from onnxscript.nn._module import Module
+from onnxscript.nn._module_list import ModuleList
+from onnxscript.nn._parameter import Parameter
+
+__all__ = ["Module", "ModuleList", "Parameter"]
diff --git a/onnxscript/nn/_module.py b/onnxscript/nn/_module.py
new file mode 100644
index 0000000000..084703703e
--- /dev/null
+++ b/onnxscript/nn/_module.py
@@ -0,0 +1,211 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+from typing import Any, Iterator
+
+import onnx_ir as ir
+
+from onnxscript._internal.builder import GraphBuilder, OpBuilder
+from onnxscript.nn._parameter import Parameter
+
+
+class Module:
+    """Base class for all onnxscript modules, mirroring PyTorch's nn.Module.
+
+    Subclasses define ``forward()`` to build ONNX subgraphs. Child modules
+    and parameters are registered automatically via ``__setattr__``.
+    Because ``Parameter`` subclasses ``ir.Value``, parameters like
+    ``self.weight`` can be passed directly to ONNX ops.
+
+    Example::
+
+        class Linear(onnxscript.nn.Module):
+            def __init__(self, in_features, out_features, bias=True, name=None):
+                super().__init__(name)
+                self.weight = Parameter([out_features, in_features], name="weight")
+                if bias:
+                    self.bias = Parameter([out_features], name="bias")
+                else:
+                    self.bias = None
+
+            def forward(self, op, x):
+                w_t = op.Transpose(self.weight, perm=[1, 0])
+                result = op.MatMul(x, w_t)
+                if self.bias is not None:
+                    result = op.Add(result, self.bias)
+                return result
+    """
+
+    def __init__(self, name: str | None = None) -> None:
+        # Use object.__setattr__ to avoid triggering our __setattr__ override
+        # before the _parameters and _modules dicts exist.
+        object.__setattr__(self, "_name", name)
+        object.__setattr__(self, "_parameters", {})
+        object.__setattr__(self, "_modules", {})
+
+    @property
+    def name(self) -> str | None:
+        return self._name
+
+    def _set_name(self, name: str) -> None:
+        """Set the module name. Subclasses may override to propagate names to children."""
+        object.__setattr__(self, "_name", name)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        if isinstance(value, Parameter):
+            # Auto-register parameters; set default name from attribute name.
+            if value.name is None:
+                value.name = name
+            self._parameters[name] = value
+            # Also store on the instance so getattr works outside forward()
+            object.__setattr__(self, name, value)
+        elif isinstance(value, Module):
+            # Auto-register child modules; inherit attribute name if unnamed.
+            if value._name is None:
+                value._set_name(name)
+            self._modules[name] = value
+            object.__setattr__(self, name, value)
+        else:
+            object.__setattr__(self, name, value)
+
+    def __call__(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any:
+        builder: GraphBuilder = op.builder
+        module_name = self._name or ""
+        class_name = type(self).__qualname__
+        builder.push_module(module_name, class_name)
+        try:
+            # Realize parameters: qualify names and register as graph initializers.
+            for param in self._parameters.values():
+                param._realize(builder)  # pylint: disable=protected-access
+
+            result = self.forward(op, *args, **kwargs)
+        finally:
+            builder.pop_module()
+        return result
+
+    def forward(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any:
+        """Define the computation performed by this module.
+
+        Must be overridden by subclasses. Receives an ``OpBuilder`` as the
+        first argument so that ONNX ops can be called as ``op.MatMul(x, w)``.
+        """
+        raise NotImplementedError(f"{type(self).__name__} must implement forward()")
+ """ + raise NotImplementedError(f"{type(self).__name__} must implement forward()") + + # ------------------------------------------------------------------ + # Iterators + # ------------------------------------------------------------------ + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + """Return an iterator over module parameters.""" + yield from self._parameters.values() + if recurse: + for module in self._modules.values(): + yield from module.parameters(recurse=True) + + def named_parameters( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[tuple[str, Parameter]]: + """Return an iterator over module parameters, yielding (name, Parameter) pairs.""" + for name, param in self._parameters.items(): + full_name = f"{prefix}.{name}" if prefix else name + yield full_name, param + if recurse: + for mod_name, module in self._modules.items(): + sub_prefix = f"{prefix}.{mod_name}" if prefix else mod_name + yield from module.named_parameters(prefix=sub_prefix, recurse=True) + + def children(self) -> Iterator[Module]: + """Return an iterator over immediate child modules.""" + yield from self._modules.values() + + def named_children(self) -> Iterator[tuple[str, Module]]: + """Return an iterator over immediate child modules, yielding (name, Module) pairs.""" + yield from self._modules.items() + + def modules(self) -> Iterator[Module]: + """Return an iterator over all modules in the tree (including self).""" + yield self + for module in self._modules.values(): + yield from module.modules() + + def named_modules(self, prefix: str = "") -> Iterator[tuple[str, Module]]: + """Return an iterator over all modules, yielding (name, Module) pairs.""" + yield prefix, self + for name, module in self._modules.items(): + sub_prefix = f"{prefix}.{name}" if prefix else name + yield from module.named_modules(prefix=sub_prefix) + + # ------------------------------------------------------------------ + # State dict + # ------------------------------------------------------------------ + + def state_dict(self, prefix: str = "") -> dict[str, ir.TensorProtocol | None]: + """Return a dictionary mapping parameter names to their tensor data. + + Mirrors ``torch.nn.Module.state_dict()``. Keys use dot-separated + hierarchical names (e.g. ``"layer1.weight"``). Values are the + ``const_value`` of each parameter (``None`` if uninitialized). + """ + result: dict[str, ir.TensorProtocol | None] = {} + for name, param in self._parameters.items(): + full_name = f"{prefix}.{name}" if prefix else name + result[full_name] = param.const_value + for mod_name, module in self._modules.items(): + sub_prefix = f"{prefix}.{mod_name}" if prefix else mod_name + result.update(module.state_dict(prefix=sub_prefix)) + return result + + def load_state_dict( + self, + state_dict: dict[str, ir.TensorProtocol], + strict: bool = True, + ) -> None: + """Load parameter data from a state dictionary. + + Mirrors ``torch.nn.Module.load_state_dict()``. Sets ``const_value`` + on each matching parameter. + + Args: + state_dict: Mapping of parameter names to tensor data. + strict: If ``True`` (default), raises ``KeyError`` for missing + keys and ``ValueError`` for unexpected keys. 
+ """ + self._load_state_dict_recursive(state_dict, prefix="", strict=strict) + + def _load_state_dict_recursive( + self, + state_dict: dict[str, ir.TensorProtocol], + prefix: str, + strict: bool, + ) -> set[str]: + """Recursively load state and return the set of consumed keys.""" + consumed: set[str] = set() + for name, param in self._parameters.items(): + full_name = f"{prefix}.{name}" if prefix else name + if full_name in state_dict: + param.const_value = state_dict[full_name] + consumed.add(full_name) + elif strict: + raise KeyError(f"Missing key in state_dict: {full_name!r}") + for mod_name, module in self._modules.items(): + sub_prefix = f"{prefix}.{mod_name}" if prefix else mod_name + consumed |= module._load_state_dict_recursive( # pylint: disable=protected-access + state_dict, prefix=sub_prefix, strict=strict + ) + if strict and prefix == "": + unexpected = set(state_dict.keys()) - consumed + if unexpected: + raise ValueError(f"Unexpected keys in state_dict: {unexpected}") + return consumed + + def __repr__(self) -> str: + lines = [f"{type(self).__name__}("] + for name, module in self._modules.items(): + mod_repr = repr(module).replace("\n", "\n ") + lines.append(f" ({name}): {mod_repr}") + for name, param in self._parameters.items(): + lines.append(f" ({name}): {param!r}") + lines.append(")") + return "\n".join(lines) diff --git a/onnxscript/nn/_module_list.py b/onnxscript/nn/_module_list.py new file mode 100644 index 0000000000..e71ab2a855 --- /dev/null +++ b/onnxscript/nn/_module_list.py @@ -0,0 +1,96 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import Any, Iterator, overload + +from onnxscript._internal.builder import OpBuilder +from onnxscript.nn._module import Module + + +class ModuleList(Module): + """Holds child modules in a list, mirroring ``torch.nn.ModuleList``. + + Children are registered with string keys ``"0"``, ``"1"``, etc., so that + hierarchical parameter names use ``.0.``, ``.1.`` separators just like + PyTorch. 
diff --git a/onnxscript/nn/_module_list.py b/onnxscript/nn/_module_list.py
new file mode 100644
index 0000000000..e71ab2a855
--- /dev/null
+++ b/onnxscript/nn/_module_list.py
@@ -0,0 +1,96 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+from typing import Any, Iterator, overload
+
+from onnxscript._internal.builder import OpBuilder
+from onnxscript.nn._module import Module
+
+
+class ModuleList(Module):
+    """Holds child modules in a list, mirroring ``torch.nn.ModuleList``.
+
+    Children are registered with string keys ``"0"``, ``"1"``, etc., so that
+    hierarchical parameter names use ``.0.``, ``.1.`` separators just like
+    PyTorch.
+
+    Example::
+
+        class MyModel(Module):
+            def __init__(self):
+                super().__init__("model")
+                self.layers = ModuleList([Linear(4, 4) for _ in range(3)])
+
+            def forward(self, op, x):
+                for layer in self.layers:
+                    x = layer(op, x)
+                return x
+
+        # Parameters will be named:
+        # model.layers.0.weight, model.layers.1.weight, model.layers.2.weight
+    """
+
+    def __init__(self, modules: list[Module] | None = None) -> None:
+        super().__init__()
+        if modules is not None:
+            for idx, module in enumerate(modules):
+                self._register_child(str(idx), module)
+
+    def _set_name(self, name: str) -> None:
+        """Set this container's name and prefix all children's names."""
+        object.__setattr__(self, "_name", name)
+        for key, child in self._modules.items():
+            child._set_name(f"{name}.{key}")  # pylint: disable=protected-access
+
+    def _register_child(self, key: str, module: Module) -> None:
+        """Register a child module under the given string key."""
+        if module._name is None:  # pylint: disable=protected-access
+            object.__setattr__(module, "_name", key)
+        self._modules[key] = module
+        object.__setattr__(self, key, module)
+
+    def append(self, module: Module) -> ModuleList:
+        """Append a module to the end of the list."""
+        key = str(len(self._modules))
+        self._register_child(key, module)
+        return self
+
+    def extend(self, modules: list[Module]) -> ModuleList:
+        """Append modules from an iterable to the end of the list."""
+        for module in modules:
+            self.append(module)
+        return self
+
+    @overload
+    def __getitem__(self, idx: int) -> Module: ...
+
+    @overload
+    def __getitem__(self, idx: slice) -> ModuleList: ...
+
+    def __getitem__(self, idx: int | slice) -> Module | ModuleList:
+        if isinstance(idx, slice):
+            keys = list(self._modules.keys())[idx]
+            new_list = ModuleList()
+            for i, key in enumerate(keys):
+                new_list._register_child(str(i), self._modules[key])
+            return new_list
+        if idx < 0:
+            idx += len(self._modules)
+        key = str(idx)
+        if key not in self._modules:
+            raise IndexError(f"index {idx} is out of range")
+        return self._modules[key]
+
+    def __len__(self) -> int:
+        return len(self._modules)
+
+    def __iter__(self) -> Iterator[Module]:
+        return iter(self._modules.values())
+
+    def forward(self, op: OpBuilder, *args: Any, **kwargs: Any) -> Any:
+        raise NotImplementedError(
+            "ModuleList is not callable directly. "
+            "Iterate over its children and call them individually."
+        )
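# --- Reviewer note: illustrative sketch, not part of the patch ---------------
# How container naming propagates: assigning a ModuleList to a module attribute
# triggers _set_name, which renames each child to "<attr>.<index>", so
# named_parameters() yields torch-style numeric path segments. The Layer and
# Net classes here are made up.
from onnxscript.nn import Module, ModuleList, Parameter


class Layer(Module):
    def __init__(self):
        super().__init__()
        self.w = Parameter([2], name="w")

    def forward(self, op, x):
        return x


class Net(Module):
    def __init__(self):
        super().__init__("net")
        self.blocks = ModuleList([Layer(), Layer()])

    def forward(self, op, x):
        return x


assert [name for name, _ in Net().named_parameters()] == ["blocks.0.w", "blocks.1.w"]
# --- end note -----------------------------------------------------------------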
diff --git a/onnxscript/nn/_module_test.py b/onnxscript/nn/_module_test.py
new file mode 100644
index 0000000000..8c68a13e9a
--- /dev/null
+++ b/onnxscript/nn/_module_test.py
@@ -0,0 +1,832 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+import unittest
+from typing import Any
+
+import onnx_ir as ir
+
+from onnxscript._internal.builder import GraphBuilder, OpBuilder
+from onnxscript.nn import Module, ModuleList, Parameter
+
+
+def _create_graph_and_op() -> tuple[ir.Graph, OpBuilder]:
+    """Create an empty graph and its default OpBuilder."""
+    graph = ir.Graph(
+        name="test_model",
+        inputs=[],
+        outputs=[],
+        nodes=[],
+        opset_imports={"": 23},
+    )
+    builder = GraphBuilder(graph)
+    return graph, builder.op
+
+
+class ParameterTest(unittest.TestCase):
+    def test_parameter_repr(self):
+        p = Parameter([3, 4], dtype=ir.DataType.FLOAT, name="weight")
+        self.assertIn("weight", repr(p))
+        self.assertIn("3, 4", repr(p))
+
+    def test_realize_creates_initializer(self):
+        graph, op = _create_graph_and_op()
+        p = Parameter([3, 4], dtype=ir.DataType.FLOAT, name="weight")
+        value = p._realize(op.builder)  # pylint: disable=protected-access
+
+        self.assertIs(value, p)  # _realize returns self
+        self.assertIsInstance(value, ir.Value)
+        self.assertEqual(value.name, "weight")
+        self.assertEqual(value.type.dtype, ir.DataType.FLOAT)
+        self.assertEqual(list(value.shape), [3, 4])
+        # Should be registered as initializer
+        self.assertIn("weight", graph.initializers)
+
+    def test_realize_is_idempotent(self):
+        _, op = _create_graph_and_op()
+        p = Parameter([3, 4], name="weight")
+        v1 = p._realize(op.builder)  # pylint: disable=protected-access
+        v2 = p._realize(op.builder)  # pylint: disable=protected-access
+        self.assertIs(v1, v2)
+
+    def test_realize_with_data(self):
+        graph, op = _create_graph_and_op()
+        tensor = ir.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=ir.DataType.FLOAT)
+        p = Parameter([2, 2], name="weight", data=tensor)
+        value = p._realize(op.builder)  # pylint: disable=protected-access
+
+        self.assertIs(value.const_value, tensor)
+        self.assertIn("weight", graph.initializers)
+        # The initializer in the graph IS the parameter itself
+        self.assertIs(graph.initializers["weight"], p)
+
+    def test_realize_qualifies_name(self):
+        graph, op = _create_graph_and_op()
+        op.builder.push_module("layer1")
+        p = Parameter([3], name="bias")
+        value = p._realize(op.builder)  # pylint: disable=protected-access
+        op.builder.pop_module()
+
+        self.assertEqual(value.name, "layer1.bias")
+        self.assertIn("layer1.bias", graph.initializers)
+
+
+class ModuleBasicTest(unittest.TestCase):
+    def test_parameter_auto_registration(self):
+        """Parameters assigned as attributes are automatically registered."""
+
+        class MyModule(Module):
+            def __init__(self):
+                super().__init__("my_mod")
+                self.weight = Parameter([3, 4], name="weight")
+                self.bias = Parameter([3], name="bias")
+
+            def forward(self, op, x):
+                pass
+
+        m = MyModule()
+        param_names = list(m._parameters.keys())  # pylint: disable=protected-access
+        self.assertEqual(param_names, ["weight", "bias"])
+
+    def test_module_auto_registration(self):
+        """Child modules assigned as attributes are automatically registered."""
+
+        class Child(Module):
+            def __init__(self):
+                super().__init__("child")
+
+            def forward(self, op):
+                pass
+
+        class Parent(Module):
+            def __init__(self):
+                super().__init__("parent")
+                self.child = Child()
+
+            def forward(self, op):
+                pass
+
+        p = Parent()
+        self.assertIn("child", p._modules)  # pylint: disable=protected-access
+        self.assertIs(p._modules["child"], p.child)  # pylint: disable=protected-access
+
+    def test_module_inherits_attribute_name(self):
+        """Child module with no explicit name inherits the attribute name."""
+
+        class Child(Module):
+            def __init__(self):
+                super().__init__()  # name=None
+
+            def forward(self, op):
+                pass
+
+        class Parent(Module):
+            def __init__(self):
+                super().__init__("parent")
+                self.my_layer = Child()
+
+            def forward(self, op):
+                pass
+
+        p = Parent()
+        self.assertEqual(p.my_layer._name, "my_layer")  # pylint: disable=protected-access
+
+    def test_parameter_inherits_attribute_name(self):
+        """Parameter with no explicit name inherits the attribute name."""
+
+        class MyModule(Module):
+            def __init__(self):
+                super().__init__("mod")
+                self.weight = Parameter([3, 4])  # name=None
+
+            def forward(self, op, x):
+                pass
+
+        m = MyModule()
+        self.assertEqual(m.weight.name, "weight")
+
+    def test_name_property(self):
+        """The name property returns the module name."""
+
+        class MyModule(Module):
+            def __init__(self):
+                super().__init__("my_mod")
+
+            def forward(self, op):
+                pass
+
+        m = MyModule()
+        self.assertEqual(m.name, "my_mod")
+
+    def test_setattr_plain_attribute(self):
+        """Non-Parameter, non-Module attributes are stored normally."""
+
+        class MyModule(Module):
+            def __init__(self):
+                super().__init__("mod")
+                self.hidden_size = 128
+
+            def forward(self, op):
+                pass
+
+        m = MyModule()
+        self.assertEqual(m.hidden_size, 128)
+
+    def test_forward_not_implemented(self):
+        """Calling forward() on base Module raises NotImplementedError."""
+        m = Module("base")
+        _, op = _create_graph_and_op()
+        with self.assertRaises(NotImplementedError):
+            m(op)
+
+
+class ModuleForwardTest(unittest.TestCase):
+    def test_parameter_is_ir_value_in_forward(self):
+        """Parameters are ir.Value instances usable directly in forward()."""
+        captured: dict[str, Any] = {}
+
+        class MyModule(Module):
+            def __init__(self):
+                super().__init__("mod")
+                self.weight = Parameter([3, 4], name="weight")
+
+            def forward(self, op):
+                # self.weight IS a Parameter (which IS an ir.Value)
+                captured["is_ir_value"] = isinstance(self.weight, ir.Value)
+                captured["is_parameter"] = isinstance(self.weight, Parameter)
+                captured["weight_name"] = self.weight.name
+
+        _, op = _create_graph_and_op()
+        m = MyModule()
+        m(op)
+
+        self.assertTrue(captured["is_ir_value"])
+        self.assertTrue(captured["is_parameter"])
+        self.assertEqual(captured["weight_name"], "mod.weight")
+
+        # After forward, self.weight is still the same Parameter
+        self.assertIsInstance(m.weight, Parameter)
+
+    def test_parameter_naming_with_hierarchy(self):
+        """Parameters get hierarchically qualified names."""
+        captured_names: list[str] = []
+
+        class Inner(Module):
+            def __init__(self):
+                super().__init__()
+                self.w = Parameter([2, 2], name="w")
+
+            def forward(self, op):
+                captured_names.append(self.w.name)
+
+        class Outer(Module):
+            def __init__(self):
+                super().__init__("outer")
+                self.layer = Inner()
+
+            def forward(self, op):
+                self.layer(op)
+
+        graph, op = _create_graph_and_op()
+        m = Outer()
+        m(op)
+
+        self.assertEqual(captured_names, ["outer.layer.w"])
+        self.assertIn("outer.layer.w", graph.initializers)
+
+    def test_forward_with_ops(self):
+        """Module forward can use OpBuilder to create ONNX nodes."""
+
+        class AddBias(Module):
+            def __init__(self, size):
+                super().__init__("add_bias")
+                self.bias = Parameter([size], name="bias")
+
+            def forward(self, op, x):
+                return op.Add(x, self.bias)
+
+        graph, op = _create_graph_and_op()
+        x = ir.Value(
+            name="input",
+            type=ir.TensorType(ir.DataType.FLOAT),
+            shape=ir.Shape([3]),
+        )
+        graph.inputs.append(x)
+
+        m = AddBias(3)
+        result = m(op, x)
+
+        self.assertIsInstance(result, ir.Value)
+        nodes = list(graph)
+        self.assertEqual(len(nodes), 1)
+        self.assertEqual(nodes[0].op_type, "Add")
+
+    def test_composition_multiple_submodules(self):
+        """Multiple submodules compose correctly with independent parameters."""
+
+        class Linear(Module):
+            def __init__(self, in_f, out_f, name=None):
+                super().__init__(name)
+                self.weight = Parameter([out_f, in_f], name="weight")
+
+            def forward(self, op, x):
+                w_t = op.Transpose(self.weight, perm=[1, 0])
+                return op.MatMul(x, w_t)
+
+        class TwoLinear(Module):
+            def __init__(self):
+                super().__init__("model")
+                self.fc1 = Linear(4, 3, name="fc1")
+                self.fc2 = Linear(3, 2, name="fc2")
+
+            def forward(self, op, x):
+                h = self.fc1(op, x)
+                return self.fc2(op, h)
+
+        graph, op = _create_graph_and_op()
+        x = ir.Value(
+            name="input",
+            type=ir.TensorType(ir.DataType.FLOAT),
+            shape=ir.Shape([1, 4]),
+        )
+        graph.inputs.append(x)
+
+        m = TwoLinear()
+        result = m(op, x)
+
+        self.assertIsInstance(result, ir.Value)
+        # Check initializer names are hierarchical
+        self.assertIn("model.fc1.weight", graph.initializers)
+        self.assertIn("model.fc2.weight", graph.initializers)
+        # Check nodes: Transpose, MatMul, Transpose, MatMul
+        op_types = [node.op_type for node in graph]
+        self.assertEqual(op_types, ["Transpose", "MatMul", "Transpose", "MatMul"])
+
+
+class ModuleIteratorTest(unittest.TestCase):
+    def test_parameters_iterator(self):
+        class MyModule(Module):
+            def __init__(self):
+                super().__init__("mod")
+                self.w1 = Parameter([3], name="w1")
+                self.w2 = Parameter([4], name="w2")
+
+            def forward(self, op):
+                pass
+
+        m = MyModule()
+        params = list(m.parameters())
+        self.assertEqual(len(params), 2)
+
+    def test_parameters_recursive(self):
+        class Child(Module):
+            def __init__(self):
+                super().__init__("child")
+                self.w = Parameter([3], name="w")
+
+            def forward(self, op):
+                pass
+
+        class Parent(Module):
+            def __init__(self):
+                super().__init__("parent")
+                self.p = Parameter([2], name="p")
+                self.child = Child()
+
+            def forward(self, op):
+                pass
+
+        m = Parent()
+        params = list(m.parameters(recurse=True))
+        self.assertEqual(len(params), 2)
+        params_no_recurse = list(m.parameters(recurse=False))
+        self.assertEqual(len(params_no_recurse), 1)
+
+    def test_named_parameters_recursive(self):
+        class Child(Module):
+            def __init__(self):
+                super().__init__("child")
+                self.w = Parameter([3], name="w")
+
+            def forward(self, op):
+                pass
+
+        class Parent(Module):
+            def __init__(self):
+                super().__init__("parent")
+                self.p = Parameter([2], name="p")
+                self.child = Child()
+
+            def forward(self, op):
+                pass
+
+        m = Parent()
+        named = dict(m.named_parameters())
+        self.assertIn("p", named)
+        self.assertIn("child.w", named)
+
+    def test_modules_iterator(self):
+        class A(Module):
+            def __init__(self):
+                super().__init__("a")
+
+            def forward(self, op):
+                pass
+
+        class B(Module):
+            def __init__(self):
+                super().__init__("b")
+                self.a = A()
+
+            def forward(self, op):
+                pass
+
+        m = B()
+        mods = list(m.modules())
+        self.assertEqual(len(mods), 2)
+        self.assertIs(mods[0], m)
+        self.assertIs(mods[1], m.a)
+
+    def test_named_modules_iterator(self):
+        class A(Module):
+            def __init__(self):
+                super().__init__("a")
+
+            def forward(self, op):
+                pass
+
+        class B(Module):
+            def __init__(self):
+                super().__init__("b")
+                self.a = A()
+
+            def forward(self, op):
+                pass
+
+        m = B()
+        named = dict(m.named_modules())
+        self.assertIn("", named)  # self
+        self.assertIn("a", named)
name="w") + + def forward(self, op): + pass + + class Parent(Module): + def __init__(self): + super().__init__("parent") + self.child = Child() + + def forward(self, op): + pass + + m = Parent() + r = repr(m) + self.assertIn("Parent", r) + self.assertIn("child", r) + self.assertIn("Child", r) + + +class ModuleChildrenTest(unittest.TestCase): + def test_children(self): + class A(Module): + def __init__(self): + super().__init__("a") + + def forward(self, op): + pass + + class B(Module): + def __init__(self): + super().__init__("b") + + def forward(self, op): + pass + + class Parent(Module): + def __init__(self): + super().__init__("parent") + self.a = A() + self.b = B() + + def forward(self, op): + pass + + m = Parent() + kids = list(m.children()) + self.assertEqual(len(kids), 2) + self.assertIs(kids[0], m.a) + self.assertIs(kids[1], m.b) + + def test_named_children(self): + class A(Module): + def __init__(self): + super().__init__("a") + + def forward(self, op): + pass + + class Parent(Module): + def __init__(self): + super().__init__("parent") + self.layer = A() + + def forward(self, op): + pass + + m = Parent() + named = dict(m.named_children()) + self.assertIn("layer", named) + self.assertIs(named["layer"], m.layer) + + def test_children_does_not_recurse(self): + """children() only yields immediate children, not grandchildren.""" + + class Grandchild(Module): + def __init__(self): + super().__init__("gc") + + def forward(self, op): + pass + + class Child(Module): + def __init__(self): + super().__init__("child") + self.gc = Grandchild() + + def forward(self, op): + pass + + class Parent(Module): + def __init__(self): + super().__init__("parent") + self.child = Child() + + def forward(self, op): + pass + + m = Parent() + kids = list(m.children()) + self.assertEqual(len(kids), 1) + self.assertIs(kids[0], m.child) + + +class ModuleStateDictTest(unittest.TestCase): + def test_state_dict_flat(self): + class MyModule(Module): + def __init__(self): + super().__init__("mod") + self.w = Parameter([3], name="w") + self.b = Parameter([3], name="b") + + def forward(self, op): + pass + + m = MyModule() + sd = m.state_dict() + self.assertIn("w", sd) + self.assertIn("b", sd) + # Uninitialized parameters have None data + self.assertIsNone(sd["w"]) + self.assertIsNone(sd["b"]) + + def test_state_dict_hierarchical(self): + class Child(Module): + def __init__(self): + super().__init__("child") + self.w = Parameter([3], name="w") + + def forward(self, op): + pass + + class Parent(Module): + def __init__(self): + super().__init__("parent") + self.p = Parameter([2], name="p") + self.child = Child() + + def forward(self, op): + pass + + m = Parent() + sd = m.state_dict() + self.assertIn("p", sd) + self.assertIn("child.w", sd) + + def test_state_dict_with_data(self): + tensor = ir.tensor([1.0, 2.0, 3.0], dtype=ir.DataType.FLOAT) + p = Parameter([3], name="w", data=tensor) + + class MyModule(Module): + def __init__(self): + super().__init__("mod") + self.w = p + + def forward(self, op): + pass + + m = MyModule() + sd = m.state_dict() + self.assertIs(sd["w"], tensor) + + def test_load_state_dict(self): + class MyModule(Module): + def __init__(self): + super().__init__("mod") + self.w = Parameter([3], name="w") + self.b = Parameter([3], name="b") + + def forward(self, op): + pass + + m = MyModule() + w_data = ir.tensor([1.0, 2.0, 3.0], dtype=ir.DataType.FLOAT) + b_data = ir.tensor([0.1, 0.2, 0.3], dtype=ir.DataType.FLOAT) + m.load_state_dict({"w": w_data, "b": b_data}) + + self.assertIs(m.w.const_value, w_data) + 
+        self.assertIs(m.b.const_value, b_data)
+
+    def test_load_state_dict_hierarchical(self):
+        class Child(Module):
+            def __init__(self):
+                super().__init__("child")
+                self.w = Parameter([3], name="w")
+
+            def forward(self, op):
+                pass
+
+        class Parent(Module):
+            def __init__(self):
+                super().__init__("parent")
+                self.child = Child()
+
+            def forward(self, op):
+                pass
+
+        m = Parent()
+        w_data = ir.tensor([1.0, 2.0, 3.0], dtype=ir.DataType.FLOAT)
+        m.load_state_dict({"child.w": w_data})
+        self.assertIs(m.child.w.const_value, w_data)
+
+    def test_load_state_dict_strict_missing_key(self):
+        class MyModule(Module):
+            def __init__(self):
+                super().__init__("mod")
+                self.w = Parameter([3], name="w")
+
+            def forward(self, op):
+                pass
+
+        m = MyModule()
+        with self.assertRaises(KeyError):
+            m.load_state_dict({})
+
+    def test_load_state_dict_strict_unexpected_key(self):
+        class MyModule(Module):
+            def __init__(self):
+                super().__init__("mod")
+                self.w = Parameter([3], name="w")
+
+            def forward(self, op):
+                pass
+
+        m = MyModule()
+        w_data = ir.tensor([1.0, 2.0, 3.0], dtype=ir.DataType.FLOAT)
+        with self.assertRaises(ValueError):
+            m.load_state_dict({"w": w_data, "extra": w_data})
+
+    def test_load_state_dict_non_strict(self):
+        class MyModule(Module):
+            def __init__(self):
+                super().__init__("mod")
+                self.w = Parameter([3], name="w")
+                self.b = Parameter([3], name="b")
+
+            def forward(self, op):
+                pass
+
+        m = MyModule()
+        w_data = ir.tensor([1.0, 2.0, 3.0], dtype=ir.DataType.FLOAT)
+        # Only load w, skip b — no error because strict=False
+        m.load_state_dict({"w": w_data}, strict=False)
+        self.assertIs(m.w.const_value, w_data)
+        self.assertIsNone(m.b.const_value)
+
+
+class ModuleListTest(unittest.TestCase):
+    def test_basic_indexing(self):
+        """ModuleList supports integer indexing."""
+
+        class Layer(Module):
+            def __init__(self):
+                super().__init__()
+                self.w = Parameter([3], name="w")
+
+            def forward(self, op):
+                pass
+
+        ml = ModuleList([Layer(), Layer(), Layer()])
+        self.assertEqual(len(ml), 3)
+        self.assertIsInstance(ml[0], Layer)
+        self.assertIsInstance(ml[2], Layer)
+        self.assertIsInstance(ml[-1], Layer)
+        self.assertIs(ml[-1], ml[2])
+
+    def test_index_out_of_range(self):
+        ml = ModuleList()
+        with self.assertRaises(IndexError):
+            _ = ml[0]
+
+    def test_append(self):
+        class Layer(Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, op):
+                pass
+
+        ml = ModuleList()
+        ml.append(Layer())
+        ml.append(Layer())
+        self.assertEqual(len(ml), 2)
+
+    def test_extend(self):
+        class Layer(Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, op):
+                pass
+
+        ml = ModuleList()
+        ml.extend([Layer(), Layer()])
+        self.assertEqual(len(ml), 2)
+
+    def test_iteration(self):
+        class Layer(Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, op):
+                pass
+
+        layers = [Layer(), Layer(), Layer()]
+        ml = ModuleList(layers)
+        for i, layer in enumerate(ml):
+            self.assertIs(layer, layers[i])
+
+    def test_slice(self):
+        class Layer(Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, op):
+                pass
+
+        ml = ModuleList([Layer(), Layer(), Layer()])
+        sub = ml[1:]
+        self.assertIsInstance(sub, ModuleList)
+        self.assertEqual(len(sub), 2)
+
+    def test_children_auto_named(self):
+        """Children get '0', '1', ... names automatically."""
names automatically.""" + + class Layer(Module): + def __init__(self): + super().__init__() + + def forward(self, op): + pass + + ml = ModuleList([Layer(), Layer()]) + self.assertEqual(ml[0]._name, "0") # pylint: disable=protected-access + self.assertEqual(ml[1]._name, "1") # pylint: disable=protected-access + + def test_named_parameters_with_numeric_keys(self): + """Parameters within ModuleList children use numeric-indexed names.""" + + class Layer(Module): + def __init__(self): + super().__init__() + self.w = Parameter([3], name="w") + + def forward(self, op): + pass + + ml = ModuleList([Layer(), Layer()]) + named = dict(ml.named_parameters()) + self.assertIn("0.w", named) + self.assertIn("1.w", named) + + def test_same_class_submodules_get_distinct_names_in_graph(self): + """Multiple sub-modules of the same class in a ModuleList get distinct + .0., .1. prefixed initializer and node names in the graph. + """ + + class Linear(Module): + def __init__(self, in_f, out_f): + super().__init__() + self.weight = Parameter([out_f, in_f], name="weight") + + def forward(self, op, x): + w_t = op.Transpose(self.weight, perm=[1, 0]) + return op.MatMul(x, w_t) + + class Model(Module): + def __init__(self): + super().__init__("model") + self.layers = ModuleList([Linear(4, 4), Linear(4, 4), Linear(4, 4)]) + + def forward(self, op, x): + for layer in self.layers: + x = layer(op, x) + return x + + graph, op = _create_graph_and_op() + x = ir.Value( + name="input", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape([1, 4]), + ) + graph.inputs.append(x) + + m = Model() + result = m(op, x) + + self.assertIsInstance(result, ir.Value) + # Each layer's weight must have a distinct hierarchical name + self.assertIn("model.layers.0.weight", graph.initializers) + self.assertIn("model.layers.1.weight", graph.initializers) + self.assertIn("model.layers.2.weight", graph.initializers) + # All three must be different ir.Value objects + self.assertIsNot( + graph.initializers["model.layers.0.weight"], + graph.initializers["model.layers.1.weight"], + ) + self.assertIsNot( + graph.initializers["model.layers.1.weight"], + graph.initializers["model.layers.2.weight"], + ) + # Check that we got 6 nodes: (Transpose + MatMul) * 3 layers + op_types = [node.op_type for node in graph] + self.assertEqual(op_types, ["Transpose", "MatMul"] * 3) + + def test_modulelist_not_directly_callable(self): + ml = ModuleList() + _, op = _create_graph_and_op() + with self.assertRaises(NotImplementedError): + ml(op) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/nn/_parameter.py b/onnxscript/nn/_parameter.py new file mode 100644 index 0000000000..d319868e89 --- /dev/null +++ b/onnxscript/nn/_parameter.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import Sequence + +import onnx_ir as ir + +from onnxscript._internal.builder import GraphBuilder + + +class Parameter(ir.Value): + """A module parameter that is also an ``ir.Value``. + + Since ``Parameter`` subclasses ``ir.Value``, it can be passed directly + to ONNX ops inside ``Module.forward()`` without any conversion. + Calling :meth:`_realize` qualifies the name with the current module + context and registers the parameter as a graph initializer. + + Args: + shape: Shape of the parameter tensor. + dtype: Data type of the parameter. If None and data is not provided, defaults to float32. + name: Name for the parameter. 
diff --git a/onnxscript/nn/_parameter.py b/onnxscript/nn/_parameter.py
new file mode 100644
index 0000000000..d319868e89
--- /dev/null
+++ b/onnxscript/nn/_parameter.py
@@ -0,0 +1,82 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+from typing import Sequence
+
+import onnx_ir as ir
+
+from onnxscript._internal.builder import GraphBuilder
+
+
+class Parameter(ir.Value):
+    """A module parameter that is also an ``ir.Value``.
+
+    Since ``Parameter`` subclasses ``ir.Value``, it can be passed directly
+    to ONNX ops inside ``Module.forward()`` without any conversion.
+    Calling :meth:`_realize` qualifies the name with the current module
+    context and registers the parameter as a graph initializer.
+
+    Args:
+        shape: Shape of the parameter tensor.
+        dtype: Data type of the parameter. If None and ``data`` is not
+            provided, defaults to float32.
+        name: Name for the parameter. If None, the attribute name from
+            the parent Module is used.
+        data: Optional initial tensor data. If provided, the initializer
+            will carry this as its ``const_value``.
+    """
+
+    def __init__(
+        self,
+        shape: Sequence[int],
+        dtype: ir.DataType | None = None,
+        name: str | None = None,
+        data: ir.TensorProtocol | None = None,
+    ) -> None:
+        if data is not None:
+            if dtype is not None and data.dtype != dtype:
+                raise ValueError(
+                    f"Data type of provided data ({data.dtype}) does not match the specified dtype ({dtype})."
+                )
+            if dtype is None:
+                dtype = data.dtype
+        elif dtype is None:
+            dtype = ir.DataType.FLOAT
+
+        super().__init__(
+            name=name,
+            shape=ir.Shape(shape),
+            type=ir.TensorType(dtype),
+            const_value=data,
+        )
+        self._realized = False
+
+    @property
+    def dtype(self) -> ir.DataType | None:  # type: ignore[override]
+        """Return the element data type of this parameter."""
+        return self.type.dtype if self.type is not None else None
+
+    def _realize(self, builder: GraphBuilder) -> Parameter:
+        """Qualify the name and register as a graph initializer.
+
+        Uses direct assignment to ``graph.initializers[...]`` to skip the
+        const_value check. Idempotent: subsequent calls are no-ops.
+        """
+        if self._realized:
+            return self
+
+        if not self.name:
+            raise ValueError(
+                "Parameter._realize() called on a Parameter without a name. "
+                "Ensure the Parameter is attached to a Module attribute or otherwise "
+                "initialized with a name before realization."
+            )
+        name = builder._qualify_initializer_name(self.name)  # pylint: disable=protected-access
+        self.name = name
+        builder.graph.initializers[name] = self
+        self._realized = True
+        return self
+
+    def __repr__(self) -> str:
+        return f"Parameter(shape={list(self.shape)}, dtype={self.dtype}, name={self.name!r})"
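# --- Reviewer note: illustrative sketch, not part of the patch ---------------
# The pieces end to end: build a two-layer model, run it through a
# GraphBuilder, and observe the hierarchical initializer names. This mirrors
# test_same_class_submodules_get_distinct_names_in_graph above; the "demo"
# and "input" names are arbitrary.
import onnx_ir as ir

from onnxscript._internal.builder import GraphBuilder
from onnxscript.nn import Module, ModuleList, Parameter


class Linear(Module):
    def __init__(self, in_f, out_f):
        super().__init__()
        self.weight = Parameter([out_f, in_f], name="weight")

    def forward(self, op, x):
        return op.MatMul(x, op.Transpose(self.weight, perm=[1, 0]))


class Model(Module):
    def __init__(self):
        super().__init__("model")
        self.layers = ModuleList([Linear(4, 4) for _ in range(2)])

    def forward(self, op, x):
        for layer in self.layers:
            x = layer(op, x)
        return x


graph = ir.Graph(name="demo", inputs=[], outputs=[], nodes=[], opset_imports={"": 23})
x = ir.Value(name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape([1, 4]))
graph.inputs.append(x)
out = Model()(GraphBuilder(graph).op, x)
graph.outputs.append(out)
assert "model.layers.0.weight" in graph.initializers
assert "model.layers.1.weight" in graph.initializers
# --- end note -----------------------------------------------------------------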