Skip to content

Commit feb1ff7

Browse files
committed
fix: namespace uses 'name: class' pairs, align scope list lengths
- namespace: each scope entry is 'name: class_name' joined by '/' e.g. 'layer1: DecoderLayer/self_attn: Attention/Add' - scope_names() and scope_classes() return all entries (no filtering) so class_hierarchy and name_scopes always have matching lengths - _scope_name_parts() filters empty names for initializer/value/node qualifying Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 67ffb42 commit feb1ff7

2 files changed

Lines changed: 23 additions & 12 deletions

File tree

onnxscript/_internal/builder.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -357,18 +357,22 @@ def pop_module(self) -> None:
357357

358358
def scope_names(self) -> list[str]:
359359
"""Return the list of module attribute names in the current scope."""
360-
return [name for name, _ in self._scope_stack if name]
360+
return [name for name, _ in self._scope_stack]
361361

362362
def scope_classes(self) -> list[str]:
363363
"""Return the list of class names in the current scope."""
364-
return [cls for _, cls in self._scope_stack if cls]
364+
return [cls for _, cls in self._scope_stack]
365+
366+
def _scope_name_parts(self) -> list[str]:
367+
"""Return non-empty module names for qualifying names."""
368+
return [name for name, _ in self._scope_stack if name]
365369

366370
def _qualify_initializer_name(self, name: str) -> str:
367371
"""Prepend the current hierarchical context prefix to the given name.
368372
369373
Uses ``.`` as separator, appropriate for parameter and initializer names.
370374
"""
371-
parts = self.scope_names()
375+
parts = self._scope_name_parts()
372376
if parts:
373377
return ".".join(parts) + "." + name
374378
return name
@@ -378,28 +382,30 @@ def _qualify_value_name(self, name: str) -> str:
378382
379383
The name is prefixed with ``v_`` to distinguish values from parameters.
380384
"""
381-
parts = self.scope_names()
385+
parts = self._scope_name_parts()
382386
if parts:
383387
return "v_" + ".".join(parts) + "." + name
384388
return f"v_{name}"
385389

386390
def _qualify_node_name(self, name: str) -> str:
387391
"""Qualify a node name with the current scope using ``/`` separator."""
388-
parts = self.scope_names()
392+
parts = self._scope_name_parts()
389393
if parts:
390394
return "/".join(parts) + "/" + name
391395
return name
392396

393397
def _build_namespace(self, op_type: str, domain: str = "") -> str:
394398
"""Build the namespace string for a node.
395399
396-
Format: ``scope1/scope2: domain.op_type`` or ``scope1/scope2: op_type``.
400+
Each scope entry is formatted as ``name: class_name`` joined by ``/``.
397401
"""
398-
scope = "/".join(self.scope_names())
402+
parts = []
403+
for name, cls in self._scope_stack:
404+
if name or cls:
405+
parts.append(f"{name}: {cls}" if cls else name)
399406
op_id = f"{domain}.{op_type}" if domain else op_type
400-
if scope:
401-
return f"{scope}: {op_id}"
402-
return op_id
407+
parts.append(op_id)
408+
return "/".join(parts)
403409

404410

405411
class OpBuilder:

onnxscript/_internal/builder_test.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -560,13 +560,13 @@ def test_node_metadata_props_namespace(self):
560560
# Node inside a module scope
561561
op.builder.push_module("layer1", "DecoderLayer")
562562
t2 = op.Mul(t1, y)
563-
self.assertEqual(t2.producer().metadata_props["namespace"], "layer1: Mul")
563+
self.assertEqual(t2.producer().metadata_props["namespace"], "layer1: DecoderLayer/Mul")
564564

565565
# Nested scope
566566
op.builder.push_module("self_attn", "Attention")
567567
t3 = op.Add(t2, x)
568568
self.assertEqual(
569-
t3.producer().metadata_props["namespace"], "layer1/self_attn: Add"
569+
t3.producer().metadata_props["namespace"], "layer1: DecoderLayer/self_attn: Attention/Add"
570570
)
571571
op.builder.pop_module()
572572
op.builder.pop_module()
@@ -588,6 +588,11 @@ def test_node_metadata_props_class_hierarchy(self):
588588
node.metadata_props["pkg.onnxscript.name_scopes"],
589589
repr(["layer1", "self_attn"]),
590590
)
591+
# class_hierarchy includes one entry per scope plus the op
592+
self.assertEqual(
593+
len(eval(node.metadata_props["pkg.onnxscript.class_hierarchy"])),
594+
len(eval(node.metadata_props["pkg.onnxscript.name_scopes"])) + 1,
595+
)
591596
op.builder.pop_module()
592597
op.builder.pop_module()
593598

0 commit comments

Comments
 (0)