Skip to content

Commit b82b62c

Browse files
authored
feat(dataclasses): allow creating dataclass types in the global namespace (#212)
1 parent 0fa140c commit b82b62c

3 files changed

Lines changed: 71 additions & 37 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1818
### Changed
1919

2020
- Enforce naming convention of packages with singular and plural: `optree.{accessor,integration}` -> `optree.{accessors,integrations}` by [@XuehaiPan](https://github.com/XuehaiPan) in [#209](https://github.com/metaopt/optree/pull/209).
21+
- Allow creating dataclass types in the global namespace by [@XuehaiPan](https://github.com/XuehaiPan) in [#212](https://github.com/metaopt/optree/pull/212).
2122

2223
### Fixed
2324

optree/dataclasses.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def decorator(cls: _TypeT) -> _TypeT:
302302
if namespace is not GLOBAL_NAMESPACE and not isinstance(namespace, str):
303303
raise TypeError(f'The namespace must be a string, got {namespace!r}.')
304304
if namespace == '':
305-
raise ValueError('The namespace cannot be an empty string.')
305+
namespace = GLOBAL_NAMESPACE
306306

307307
cls = dataclasses.dataclass(cls, **kwargs) # type: ignore[assignment]
308308

@@ -412,7 +412,7 @@ def make_dataclass( # type: ignore[no-redef] # noqa: C901,D417
412412
if namespace is not GLOBAL_NAMESPACE and not isinstance(namespace, str):
413413
raise TypeError(f'The namespace must be a string, got {namespace!r}.')
414414
if namespace == '':
415-
raise ValueError('The namespace cannot be an empty string.')
415+
namespace = GLOBAL_NAMESPACE
416416

417417
dataclass_kwargs = {
418418
'init': init,

tests/test_dataclasses.py

Lines changed: 68 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -368,30 +368,41 @@ class Foo1:
368368
x: int
369369
y: float
370370

371-
with pytest.raises(ValueError, match=re.escape('The namespace cannot be an empty string.')):
371+
@optree.dataclasses.dataclass(namespace='')
372+
class Foo2:
373+
x: int
374+
y: float
372375

373-
@optree.dataclasses.dataclass(namespace='')
374-
class Foo2:
375-
x: int
376-
y: float
376+
foo = Foo2(1, 2.0)
377+
accessors, leaves, treespec = optree.tree_flatten_with_accessor(foo)
378+
assert optree.tree_unflatten(treespec, leaves) == foo
379+
assert accessors == [
380+
optree.PyTreeAccessor((optree.DataclassEntry('x', Foo2, optree.PyTreeKind.CUSTOM),)),
381+
optree.PyTreeAccessor((optree.DataclassEntry('y', Foo2, optree.PyTreeKind.CUSTOM),)),
382+
]
383+
assert [a(foo) for a in accessors] == [1, 2.0]
384+
assert leaves == [1, 2.0]
385+
assert treespec.namespace == ''
386+
assert treespec.kind == optree.PyTreeKind.CUSTOM
387+
assert treespec.type is Foo2
377388

378389
@optree.dataclasses.dataclass(namespace=GLOBAL_NAMESPACE)
379-
class Foo:
390+
class Foo3:
380391
x: int
381392
y: float
382393

383-
foo = Foo(1, 2.0)
394+
foo = Foo3(1, 2.0)
384395
accessors, leaves, treespec = optree.tree_flatten_with_accessor(foo)
385396
assert optree.tree_unflatten(treespec, leaves) == foo
386397
assert accessors == [
387-
optree.PyTreeAccessor((optree.DataclassEntry('x', Foo, optree.PyTreeKind.CUSTOM),)),
388-
optree.PyTreeAccessor((optree.DataclassEntry('y', Foo, optree.PyTreeKind.CUSTOM),)),
398+
optree.PyTreeAccessor((optree.DataclassEntry('x', Foo3, optree.PyTreeKind.CUSTOM),)),
399+
optree.PyTreeAccessor((optree.DataclassEntry('y', Foo3, optree.PyTreeKind.CUSTOM),)),
389400
]
390401
assert [a(foo) for a in accessors] == [1, 2.0]
391402
assert leaves == [1, 2.0]
392403
assert treespec.namespace == ''
393404
assert treespec.kind == optree.PyTreeKind.CUSTOM
394-
assert treespec.type is Foo
405+
assert treespec.type is Foo3
395406

396407

397408
def test_make_dataclass_future_parameters():
@@ -684,30 +695,41 @@ def test_make_dataclass_with_invalid_namespace():
684695
with pytest.raises(TypeError, match='The namespace must be a string'):
685696
optree.dataclasses.make_dataclass('Foo1', ['x', ('y', int), ('z', float, 0.0)], namespace=1)
686697

687-
with pytest.raises(ValueError, match=re.escape('The namespace cannot be an empty string.')):
688-
optree.dataclasses.make_dataclass(
689-
'Foo2',
690-
['x', ('y', int), ('z', float, 0.0)],
691-
namespace='',
692-
)
698+
Foo2 = optree.dataclasses.make_dataclass( # noqa: N806
699+
'Foo2',
700+
[('x', int), ('y', float)],
701+
namespace='',
702+
)
703+
foo = Foo2(1, 2.0)
704+
accessors, leaves, treespec = optree.tree_flatten_with_accessor(foo)
705+
assert optree.tree_unflatten(treespec, leaves) == foo
706+
assert accessors == [
707+
optree.PyTreeAccessor((optree.DataclassEntry('x', Foo2, optree.PyTreeKind.CUSTOM),)),
708+
optree.PyTreeAccessor((optree.DataclassEntry('y', Foo2, optree.PyTreeKind.CUSTOM),)),
709+
]
710+
assert [a(foo) for a in accessors] == [1, 2.0]
711+
assert leaves == [1, 2.0]
712+
assert treespec.namespace == ''
713+
assert treespec.kind == optree.PyTreeKind.CUSTOM
714+
assert treespec.type is Foo2
693715

694-
Foo = optree.dataclasses.make_dataclass( # noqa: N806
695-
'Foo',
716+
Foo3 = optree.dataclasses.make_dataclass( # noqa: N806
717+
'Foo3',
696718
[('x', int), ('y', float)],
697719
namespace=GLOBAL_NAMESPACE,
698720
)
699-
foo = Foo(1, 2.0)
721+
foo = Foo3(1, 2.0)
700722
accessors, leaves, treespec = optree.tree_flatten_with_accessor(foo)
701723
assert optree.tree_unflatten(treespec, leaves) == foo
702724
assert accessors == [
703-
optree.PyTreeAccessor((optree.DataclassEntry('x', Foo, optree.PyTreeKind.CUSTOM),)),
704-
optree.PyTreeAccessor((optree.DataclassEntry('y', Foo, optree.PyTreeKind.CUSTOM),)),
725+
optree.PyTreeAccessor((optree.DataclassEntry('x', Foo3, optree.PyTreeKind.CUSTOM),)),
726+
optree.PyTreeAccessor((optree.DataclassEntry('y', Foo3, optree.PyTreeKind.CUSTOM),)),
705727
]
706728
assert [a(foo) for a in accessors] == [1, 2.0]
707729
assert leaves == [1, 2.0]
708730
assert treespec.namespace == ''
709731
assert treespec.kind == optree.PyTreeKind.CUSTOM
710-
assert treespec.type is Foo
732+
assert treespec.type is Foo3
711733

712734
with pytest.raises(TypeError, match='The namespace must be a string'):
713735
optree.dataclasses.make_dataclass(
@@ -717,29 +739,40 @@ def test_make_dataclass_with_invalid_namespace():
717739
namespace=None,
718740
)
719741

720-
with pytest.raises(ValueError, match=re.escape('The namespace cannot be an empty string.')):
721-
optree.dataclasses.make_dataclass(
722-
'Foo2',
723-
['x', ('y', int), ('z', float, 0.0)],
724-
ns='',
725-
namespace={},
726-
)
742+
Bar2 = optree.dataclasses.make_dataclass( # noqa: N806
743+
'Bar2',
744+
[('x', int), ('y', float)],
745+
ns='',
746+
namespace={},
747+
)
748+
bar = Bar2(1, 2.0)
749+
accessors, leaves, treespec = optree.tree_flatten_with_accessor(bar)
750+
assert optree.tree_unflatten(treespec, leaves) == bar
751+
assert accessors == [
752+
optree.PyTreeAccessor((optree.DataclassEntry('x', Bar2, optree.PyTreeKind.CUSTOM),)),
753+
optree.PyTreeAccessor((optree.DataclassEntry('y', Bar2, optree.PyTreeKind.CUSTOM),)),
754+
]
755+
assert [a(bar) for a in accessors] == [1, 2.0]
756+
assert leaves == [1, 2.0]
757+
assert treespec.namespace == ''
758+
assert treespec.kind == optree.PyTreeKind.CUSTOM
759+
assert treespec.type is Bar2
727760

728-
Bar = optree.dataclasses.make_dataclass( # noqa: N806
729-
'Bar',
761+
Bar3 = optree.dataclasses.make_dataclass( # noqa: N806
762+
'Bar3',
730763
[('x', int), ('y', float)],
731764
ns=GLOBAL_NAMESPACE,
732765
namespace={},
733766
)
734-
bar = Bar(1, 2.0)
767+
bar = Bar3(1, 2.0)
735768
accessors, leaves, treespec = optree.tree_flatten_with_accessor(bar)
736769
assert optree.tree_unflatten(treespec, leaves) == bar
737770
assert accessors == [
738-
optree.PyTreeAccessor((optree.DataclassEntry('x', Bar, optree.PyTreeKind.CUSTOM),)),
739-
optree.PyTreeAccessor((optree.DataclassEntry('y', Bar, optree.PyTreeKind.CUSTOM),)),
771+
optree.PyTreeAccessor((optree.DataclassEntry('x', Bar3, optree.PyTreeKind.CUSTOM),)),
772+
optree.PyTreeAccessor((optree.DataclassEntry('y', Bar3, optree.PyTreeKind.CUSTOM),)),
740773
]
741774
assert [a(bar) for a in accessors] == [1, 2.0]
742775
assert leaves == [1, 2.0]
743776
assert treespec.namespace == ''
744777
assert treespec.kind == optree.PyTreeKind.CUSTOM
745-
assert treespec.type is Bar
778+
assert treespec.type is Bar3

0 commit comments

Comments
 (0)