Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
3033248
Add onnxscript.nn package with Module and Parameter classes
justinchuby Feb 20, 2026
11c880c
Add children, named_children, state_dict, load_state_dict to Module
justinchuby Feb 20, 2026
b6c084f
Improve nn module test coverage to 100%
justinchuby Feb 20, 2026
48b6a14
Make Parameter.realize() private (_realize())
justinchuby Feb 20, 2026
caa95cc
Update onnxscript/nn/_parameter.py
justinchuby Feb 20, 2026
13c8e6d
Update onnxscript/nn/_parameter.py
justinchuby Feb 20, 2026
bb7fba1
fix: ensure unique output names in GraphBuilder._adapt_outputs
justinchuby Feb 21, 2026
391ab4c
test: update builder tests for v_ output name prefix
justinchuby Feb 21, 2026
7f0e6bf
fix: use / separator for node and value scope names
justinchuby Feb 21, 2026
e00bb6d
fix: conditionally push and pop module in __call__ based on module name
justinchuby Feb 21, 2026
21a5c13
fix: use / throughout node scope prefix, not just at the join point
justinchuby Feb 21, 2026
75b0751
feat: redesign module scope system with node metadata
justinchuby Feb 21, 2026
3135535
refactor: rename qualify_name to _qualify_initializer_name
justinchuby Feb 21, 2026
67ffb42
refactor: use '.' delimiter for value names, '/' for node names
justinchuby Feb 21, 2026
feb1ff7
fix: namespace uses 'name: class' pairs, align scope list lengths
justinchuby Feb 21, 2026
1a1a7a6
Update names
justinchuby Feb 21, 2026
9cd8f20
fix: append op_type to name_scopes so lengths match class_hierarchy
justinchuby Feb 21, 2026
77e0d68
fix: remove op_type from namespace, class_hierarchy, and name_scopes
justinchuby Feb 21, 2026
9613200
Format
justinchuby Feb 21, 2026
5667be7
refactor: use ast.literal_eval instead of eval in tests
justinchuby Feb 21, 2026
6bb9931
Update name check
justinchuby Feb 21, 2026
7741555
Implement ModuleList
justinchuby Feb 21, 2026
248742e
Update Parameter to take dtype from data
justinchuby Feb 22, 2026
d5e9710
pylint
justinchuby Feb 22, 2026
eae41ac
fix check error
justinchuby Feb 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion onnxscript/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"script",
"graph",
"ir",
"nn",
"optimizer",
"rewriter",
"version_converter",
Expand Down Expand Up @@ -127,7 +128,7 @@

# isort: on

from . import ir, optimizer, rewriter, version_converter
from . import ir, nn, optimizer, rewriter, version_converter
from ._internal.builder import GraphBuilder
from ._internal.utils import external_tensor
from ._internal.values import OnnxFunction, TracedOnnxFunction
Expand Down
109 changes: 79 additions & 30 deletions onnxscript/_internal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ def __init__(self, graph: ir.Graph) -> None:

self._op_builder = self.opset("", opset_version)

# Context stack to manage hierarchical naming. Each module/layer can push a new context, and pop it when done.
# The current context is used as a prefix for naming values and nodes.
# This allows us to generate names like "layer1.attention.query"
self._context_stack: list[str] = [""]
# Module scope stack. Each entry is (name, class_name) where name is
Comment thread
gramalingam marked this conversation as resolved.
# the module attribute name (e.g. "layers.0", "self_attn") and
# class_name is the qualified class name (e.g. "Gemma3DecoderLayer").
self._scope_stack: list[tuple[str, str]] = []

# Cache for constant initializers (scalars and sequences), keyed by (value, dtype).
# This avoids creating duplicate initializers for the same constant
Expand All @@ -109,7 +109,7 @@ def initializer(
if name is None:
name = tensor.name
if qualify:
name = self.qualify_name(name)
name = self._qualify_initializer_name(name)
shape = ir.Shape(tensor.shape)
value = ir.Value(
name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor
Expand Down Expand Up @@ -180,24 +180,26 @@ def _adapt_outputs(
self, outputs: int | Sequence[str | ir.Value], op_type: str = ""
) -> Sequence[ir.Value]:
if isinstance(outputs, int):
count = self.graph.num_nodes()
if outputs < 0:
raise ValueError(f"Number of outputs must be non-negative, got {outputs}")
if outputs == 1:
name = f"{op_type}_output" if op_type else "output"
return [ir.Value(name=self.qualify_name(name))]
name = f"{op_type}_{count}" if op_type else f"{count}"
return [ir.Value(name=self._qualify_value_name(name))]
else:
names = [
f"{op_type}_output{i}" if op_type else f"output{i}" for i in range(outputs)
(f"{op_type}_{count}_{i}" if op_type else f"{count}_{i}")
for i in range(outputs)
]
return [ir.Value(name=self.qualify_name(n)) for n in names]
return [ir.Value(name=self._qualify_value_name(n)) for n in names]
adapted_outputs = []
for output in outputs:
if isinstance(output, ir.Value):
if output.name:
output.name = self.qualify_name(output.name)
output.name = self._qualify_value_name(output.name)
adapted_outputs.append(output)
elif isinstance(output, str):
adapted_outputs.append(ir.Value(name=self.qualify_name(output)))
adapted_outputs.append(ir.Value(name=self._qualify_value_name(output)))
else:
raise TypeError("Output type not supported.")
return adapted_outputs
Expand Down Expand Up @@ -304,7 +306,7 @@ def call_op(
outputs = kwargs.pop("_outputs", 1)

count = self.graph.num_nodes()
node_name = self.qualify_name(f"{op_type}_node_{count}")
node_name = self._qualify_node_name(f"{op_type}_node_{count}")

output_values = self._adapt_outputs(outputs, op_type)

Expand All @@ -322,33 +324,80 @@ def call_op(
version=version,
name=node_name,
)

# Attach scope metadata to the node
node.metadata_props["namespace"] = self._build_namespace()
node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr(self._scope_classes())
Comment thread
gramalingam marked this conversation as resolved.
node.metadata_props["pkg.onnxscript.name_scopes"] = repr(self._scope_names())

self.add_node(node)

return node.outputs if len(node.outputs) > 1 else node.outputs[0]

def push_module(self, module: str) -> None:
"""Push a new naming context onto the stack (e.g. a layer or module name)."""
current = self.context_name()
if module:
new_context = f"{current}.{module}" if current else module
else:
new_context = current
self._context_stack.append(new_context)
def push_module(self, module: str, class_name: str = "") -> None:
Comment thread
gramalingam marked this conversation as resolved.
"""Push a new module scope onto the stack.

Args:
module: The attribute name of the module (e.g. ``"layers.0"``).
class_name: The qualified class name (e.g. ``"Gemma3DecoderLayer"``).
"""
self._scope_stack.append((module, class_name))

def pop_module(self) -> None:
"""Pop the most recent naming context off the stack."""
if len(self._context_stack) <= 1:
"""Pop the most recent module scope off the stack."""
if not self._scope_stack:
raise RuntimeError("Cannot pop_module: no module context has been pushed.")
self._context_stack.pop()
self._scope_stack.pop()

def _scope_names(self) -> list[str]:
"""Return the list of module attribute names in the current scope."""
return [name for name, _ in self._scope_stack]

def _scope_classes(self) -> list[str]:
"""Return the list of class names in the current scope."""
return [cls for _, cls in self._scope_stack]

def context_name(self) -> str:
"""Return the current dot-separated naming context prefix."""
return self._context_stack[-1] if self._context_stack else ""
def _scope_name_parts(self) -> list[str]:
"""Return non-empty module names for qualifying names."""
return [name for name, _ in self._scope_stack if name]

def qualify_name(self, name: str) -> str:
"""Prepend the current hierarchical context prefix to the given name."""
prefix = self.context_name()
return f"{prefix}.{name}" if prefix else name
def _qualify_initializer_name(self, name: str) -> str:
"""Prepend the current hierarchical context prefix to the given name.

Uses ``.`` as separator, appropriate for parameter and initializer names.
"""
parts = self._scope_name_parts()
if parts:
return ".".join(parts) + "." + name
return name

def _qualify_value_name(self, name: str) -> str:
"""Qualify a value name with the current scope using ``.`` separator.

The name is prefixed with ``v_`` to distinguish values from parameters.
"""
parts = self._scope_name_parts()
if parts:
return "v_" + ".".join(parts) + "." + name
return f"v_{name}"

def _qualify_node_name(self, name: str) -> str:
"""Qualify a node name with the current scope using ``/`` separator."""
parts = self._scope_name_parts()
if parts:
return "/".join(parts) + "/" + name
return name

def _build_namespace(self) -> str:
"""Build the namespace string for a node.

Each scope entry is formatted as ``name: class_name`` joined by ``/``.
"""
parts = []
for name, cls in self._scope_stack:
if name or cls:
parts.append(f"{name}: {cls}" if cls else name)
return "/".join(parts)


class OpBuilder:
Expand Down
Loading
Loading