@@ -368,30 +368,41 @@ class Foo1:
368368 x : int
369369 y : float
370370
371- with pytest .raises (ValueError , match = re .escape ('The namespace cannot be an empty string.' )):
371+ @optree .dataclasses .dataclass (namespace = '' )
372+ class Foo2 :
373+ x : int
374+ y : float
372375
373- @optree .dataclasses .dataclass (namespace = '' )
374- class Foo2 :
375- x : int
376- y : float
376+ foo = Foo2 (1 , 2.0 )
377+ accessors , leaves , treespec = optree .tree_flatten_with_accessor (foo )
378+ assert optree .tree_unflatten (treespec , leaves ) == foo
379+ assert accessors == [
380+ optree .PyTreeAccessor ((optree .DataclassEntry ('x' , Foo2 , optree .PyTreeKind .CUSTOM ),)),
381+ optree .PyTreeAccessor ((optree .DataclassEntry ('y' , Foo2 , optree .PyTreeKind .CUSTOM ),)),
382+ ]
383+ assert [a (foo ) for a in accessors ] == [1 , 2.0 ]
384+ assert leaves == [1 , 2.0 ]
385+ assert treespec .namespace == ''
386+ assert treespec .kind == optree .PyTreeKind .CUSTOM
387+ assert treespec .type is Foo2
377388
378389 @optree .dataclasses .dataclass (namespace = GLOBAL_NAMESPACE )
379- class Foo :
390+ class Foo3 :
380391 x : int
381392 y : float
382393
383- foo = Foo (1 , 2.0 )
394+ foo = Foo3 (1 , 2.0 )
384395 accessors , leaves , treespec = optree .tree_flatten_with_accessor (foo )
385396 assert optree .tree_unflatten (treespec , leaves ) == foo
386397 assert accessors == [
387- optree .PyTreeAccessor ((optree .DataclassEntry ('x' , Foo , optree .PyTreeKind .CUSTOM ),)),
388- optree .PyTreeAccessor ((optree .DataclassEntry ('y' , Foo , optree .PyTreeKind .CUSTOM ),)),
398+ optree .PyTreeAccessor ((optree .DataclassEntry ('x' , Foo3 , optree .PyTreeKind .CUSTOM ),)),
399+ optree .PyTreeAccessor ((optree .DataclassEntry ('y' , Foo3 , optree .PyTreeKind .CUSTOM ),)),
389400 ]
390401 assert [a (foo ) for a in accessors ] == [1 , 2.0 ]
391402 assert leaves == [1 , 2.0 ]
392403 assert treespec .namespace == ''
393404 assert treespec .kind == optree .PyTreeKind .CUSTOM
394- assert treespec .type is Foo
405+ assert treespec .type is Foo3
395406
396407
397408def test_make_dataclass_future_parameters ():
@@ -684,30 +695,41 @@ def test_make_dataclass_with_invalid_namespace():
684695 with pytest .raises (TypeError , match = 'The namespace must be a string' ):
685696 optree .dataclasses .make_dataclass ('Foo1' , ['x' , ('y' , int ), ('z' , float , 0.0 )], namespace = 1 )
686697
687- with pytest .raises (ValueError , match = re .escape ('The namespace cannot be an empty string.' )):
688- optree .dataclasses .make_dataclass (
689- 'Foo2' ,
690- ['x' , ('y' , int ), ('z' , float , 0.0 )],
691- namespace = '' ,
692- )
698+ Foo2 = optree .dataclasses .make_dataclass ( # noqa: N806
699+ 'Foo2' ,
700+ [('x' , int ), ('y' , float )],
701+ namespace = '' ,
702+ )
703+ foo = Foo2 (1 , 2.0 )
704+ accessors , leaves , treespec = optree .tree_flatten_with_accessor (foo )
705+ assert optree .tree_unflatten (treespec , leaves ) == foo
706+ assert accessors == [
707+ optree .PyTreeAccessor ((optree .DataclassEntry ('x' , Foo2 , optree .PyTreeKind .CUSTOM ),)),
708+ optree .PyTreeAccessor ((optree .DataclassEntry ('y' , Foo2 , optree .PyTreeKind .CUSTOM ),)),
709+ ]
710+ assert [a (foo ) for a in accessors ] == [1 , 2.0 ]
711+ assert leaves == [1 , 2.0 ]
712+ assert treespec .namespace == ''
713+ assert treespec .kind == optree .PyTreeKind .CUSTOM
714+ assert treespec .type is Foo2
693715
694- Foo = optree .dataclasses .make_dataclass ( # noqa: N806
695- 'Foo ' ,
716+ Foo3 = optree .dataclasses .make_dataclass ( # noqa: N806
717+ 'Foo3 ' ,
696718 [('x' , int ), ('y' , float )],
697719 namespace = GLOBAL_NAMESPACE ,
698720 )
699- foo = Foo (1 , 2.0 )
721+ foo = Foo3 (1 , 2.0 )
700722 accessors , leaves , treespec = optree .tree_flatten_with_accessor (foo )
701723 assert optree .tree_unflatten (treespec , leaves ) == foo
702724 assert accessors == [
703- optree .PyTreeAccessor ((optree .DataclassEntry ('x' , Foo , optree .PyTreeKind .CUSTOM ),)),
704- optree .PyTreeAccessor ((optree .DataclassEntry ('y' , Foo , optree .PyTreeKind .CUSTOM ),)),
725+ optree .PyTreeAccessor ((optree .DataclassEntry ('x' , Foo3 , optree .PyTreeKind .CUSTOM ),)),
726+ optree .PyTreeAccessor ((optree .DataclassEntry ('y' , Foo3 , optree .PyTreeKind .CUSTOM ),)),
705727 ]
706728 assert [a (foo ) for a in accessors ] == [1 , 2.0 ]
707729 assert leaves == [1 , 2.0 ]
708730 assert treespec .namespace == ''
709731 assert treespec .kind == optree .PyTreeKind .CUSTOM
710- assert treespec .type is Foo
732+ assert treespec .type is Foo3
711733
712734 with pytest .raises (TypeError , match = 'The namespace must be a string' ):
713735 optree .dataclasses .make_dataclass (
@@ -717,29 +739,40 @@ def test_make_dataclass_with_invalid_namespace():
717739 namespace = None ,
718740 )
719741
720- with pytest .raises (ValueError , match = re .escape ('The namespace cannot be an empty string.' )):
721- optree .dataclasses .make_dataclass (
722- 'Foo2' ,
723- ['x' , ('y' , int ), ('z' , float , 0.0 )],
724- ns = '' ,
725- namespace = {},
726- )
742+ Bar2 = optree .dataclasses .make_dataclass ( # noqa: N806
743+ 'Bar2' ,
744+ [('x' , int ), ('y' , float )],
745+ ns = '' ,
746+ namespace = {},
747+ )
748+ bar = Bar2 (1 , 2.0 )
749+ accessors , leaves , treespec = optree .tree_flatten_with_accessor (bar )
750+ assert optree .tree_unflatten (treespec , leaves ) == bar
751+ assert accessors == [
752+ optree .PyTreeAccessor ((optree .DataclassEntry ('x' , Bar2 , optree .PyTreeKind .CUSTOM ),)),
753+ optree .PyTreeAccessor ((optree .DataclassEntry ('y' , Bar2 , optree .PyTreeKind .CUSTOM ),)),
754+ ]
755+ assert [a (bar ) for a in accessors ] == [1 , 2.0 ]
756+ assert leaves == [1 , 2.0 ]
757+ assert treespec .namespace == ''
758+ assert treespec .kind == optree .PyTreeKind .CUSTOM
759+ assert treespec .type is Bar2
727760
728- Bar = optree .dataclasses .make_dataclass ( # noqa: N806
729- 'Bar ' ,
761+ Bar3 = optree .dataclasses .make_dataclass ( # noqa: N806
762+ 'Bar3 ' ,
730763 [('x' , int ), ('y' , float )],
731764 ns = GLOBAL_NAMESPACE ,
732765 namespace = {},
733766 )
734- bar = Bar (1 , 2.0 )
767+ bar = Bar3 (1 , 2.0 )
735768 accessors , leaves , treespec = optree .tree_flatten_with_accessor (bar )
736769 assert optree .tree_unflatten (treespec , leaves ) == bar
737770 assert accessors == [
738- optree .PyTreeAccessor ((optree .DataclassEntry ('x' , Bar , optree .PyTreeKind .CUSTOM ),)),
739- optree .PyTreeAccessor ((optree .DataclassEntry ('y' , Bar , optree .PyTreeKind .CUSTOM ),)),
771+ optree .PyTreeAccessor ((optree .DataclassEntry ('x' , Bar3 , optree .PyTreeKind .CUSTOM ),)),
772+ optree .PyTreeAccessor ((optree .DataclassEntry ('y' , Bar3 , optree .PyTreeKind .CUSTOM ),)),
740773 ]
741774 assert [a (bar ) for a in accessors ] == [1 , 2.0 ]
742775 assert leaves == [1 , 2.0 ]
743776 assert treespec .namespace == ''
744777 assert treespec .kind == optree .PyTreeKind .CUSTOM
745- assert treespec .type is Bar
778+ assert treespec .type is Bar3
0 commit comments