1717
1818import builtins
1919import enum
20+ import platform
2021from collections .abc import Callable , Collection , Iterable , Iterator
21- from typing import Any
22+ from typing import Any , Final , final
2223from typing_extensions import Self
2324
2425from optree .typing import (
@@ -32,66 +33,23 @@ from optree.typing import (
3233 UnflattenFunc ,
3334)
3435
35- class InternalError (RuntimeError ): ...
36-
37- MAX_RECURSION_DEPTH : int
38-
3936# Set if the type allows subclassing (see CPython's Include/object.h)
40- Py_TPFLAGS_BASETYPE : int # (1UL << 10)
37+ Py_TPFLAGS_BASETYPE : Final [ int ] # (1UL << 10)
4138
42- GLIBCXX_USE_CXX11_ABI : bool
39+ # Meta information during build
40+ PY_VERSION : Final [str ]
41+ PY_VERSION_HEX : Final [int ]
42+ if platform .python_implementation () == 'PyPy' : # noqa: PYI002
43+ PYPY_VERSION : Final [str ]
44+ PYPY_VERSION_NUM : Final [int ]
45+ PYBIND11_VERSION_HEX : Final [int ]
46+ PYBIND11_INTERNALS_VERSION : Final [int ]
47+ GLIBCXX_USE_CXX11_ABI : Final [bool ]
4348
44- def flatten (
45- tree : PyTree [T ],
46- / ,
47- leaf_predicate : Callable [[T ], bool ] | None = None ,
48- none_is_leaf : bool = False ,
49- namespace : str = '' ,
50- ) -> tuple [list [T ], PyTreeSpec ]: ...
51- def flatten_with_path (
52- tree : PyTree [T ],
53- / ,
54- leaf_predicate : Callable [[T ], bool ] | None = None ,
55- none_is_leaf : bool = False ,
56- namespace : str = '' ,
57- ) -> tuple [list [tuple [Any , ...]], list [T ], PyTreeSpec ]: ...
58- def make_leaf (
59- none_is_leaf : bool = False ,
60- namespace : str = '' , # unused
61- ) -> PyTreeSpec : ...
62- def make_none (
63- none_is_leaf : bool = False ,
64- namespace : str = '' , # unused
65- ) -> PyTreeSpec : ...
66- def make_from_collection (
67- collection : Collection [PyTreeSpec ],
68- / ,
69- none_is_leaf : bool = False ,
70- namespace : str = '' ,
71- ) -> PyTreeSpec : ...
72- def is_leaf (
73- obj : T ,
74- / ,
75- leaf_predicate : Callable [[T ], bool ] | None = None ,
76- none_is_leaf : bool = False ,
77- namespace : str = '' ,
78- ) -> bool : ...
79- def all_leaves (
80- iterable : Iterable [T ],
81- / ,
82- leaf_predicate : Callable [[T ], bool ] | None = None ,
83- none_is_leaf : bool = False ,
84- namespace : str = '' ,
85- ) -> bool : ...
86- def is_namedtuple (obj : object | type , / ) -> bool : ...
87- def is_namedtuple_instance (obj : object , / ) -> bool : ...
88- def is_namedtuple_class (cls : type , / ) -> bool : ...
89- def namedtuple_fields (obj : tuple | type [tuple ], / ) -> tuple [str , ...]: ...
90- def is_structseq (obj : object | type , / ) -> bool : ...
91- def is_structseq_instance (obj : object , / ) -> bool : ...
92- def is_structseq_class (cls : type , / ) -> bool : ...
93- def structseq_fields (obj : tuple | type [tuple ], / ) -> tuple [str , ...]: ...
49+ @final
50+ class InternalError (RuntimeError ): ...
9451
52+ @final
9553class PyTreeKind (enum .IntEnum ):
9654 CUSTOM = 0 # a custom type
9755 LEAF = enum .auto () # an opaque leaf node
@@ -105,6 +63,9 @@ class PyTreeKind(enum.IntEnum):
10563 DEQUE = enum .auto () # a collections.deque
10664 STRUCTSEQUENCE = enum .auto () # a PyStructSequence
10765
66+ MAX_RECURSION_DEPTH : Final [int ]
67+
68+ @final
10869class PyTreeSpec :
10970 num_nodes : int
11071 num_leaves : int
@@ -157,6 +118,7 @@ class PyTreeSpec:
157118 def __hash__ (self , / ) -> int : ...
158119 def __len__ (self , / ) -> int : ...
159120
121+ @final
160122class PyTreeIter (Iterator [T ]):
161123 def __init__ (
162124 self ,
@@ -169,6 +131,63 @@ class PyTreeIter(Iterator[T]):
169131 def __iter__ (self , / ) -> Self : ...
170132 def __next__ (self , / ) -> T : ...
171133
134+ # Functions
135+ def flatten (
136+ tree : PyTree [T ],
137+ / ,
138+ leaf_predicate : Callable [[T ], bool ] | None = None ,
139+ none_is_leaf : bool = False ,
140+ namespace : str = '' ,
141+ ) -> tuple [list [T ], PyTreeSpec ]: ...
142+ def flatten_with_path (
143+ tree : PyTree [T ],
144+ / ,
145+ leaf_predicate : Callable [[T ], bool ] | None = None ,
146+ none_is_leaf : bool = False ,
147+ namespace : str = '' ,
148+ ) -> tuple [list [tuple [Any , ...]], list [T ], PyTreeSpec ]: ...
149+
150+ # Constructors
151+ def make_leaf (
152+ none_is_leaf : bool = False ,
153+ namespace : str = '' , # unused
154+ ) -> PyTreeSpec : ...
155+ def make_none (
156+ none_is_leaf : bool = False ,
157+ namespace : str = '' , # unused
158+ ) -> PyTreeSpec : ...
159+ def make_from_collection (
160+ collection : Collection [PyTreeSpec ],
161+ / ,
162+ none_is_leaf : bool = False ,
163+ namespace : str = '' ,
164+ ) -> PyTreeSpec : ...
165+
166+ # Utility functions
167+ def is_leaf (
168+ obj : T ,
169+ / ,
170+ leaf_predicate : Callable [[T ], bool ] | None = None ,
171+ none_is_leaf : bool = False ,
172+ namespace : str = '' ,
173+ ) -> bool : ...
174+ def all_leaves (
175+ iterable : Iterable [T ],
176+ / ,
177+ leaf_predicate : Callable [[T ], bool ] | None = None ,
178+ none_is_leaf : bool = False ,
179+ namespace : str = '' ,
180+ ) -> bool : ...
181+ def is_namedtuple (obj : object | type , / ) -> bool : ...
182+ def is_namedtuple_instance (obj : object , / ) -> bool : ...
183+ def is_namedtuple_class (cls : type , / ) -> bool : ...
184+ def namedtuple_fields (obj : tuple | type [tuple ], / ) -> tuple [str , ...]: ...
185+ def is_structseq (obj : object | type , / ) -> bool : ...
186+ def is_structseq_instance (obj : object , / ) -> bool : ...
187+ def is_structseq_class (cls : type , / ) -> bool : ...
188+ def structseq_fields (obj : tuple | type [tuple ], / ) -> tuple [str , ...]: ...
189+
190+ # Registration functions
172191def register_node (
173192 cls : type [Collection [T ]],
174193 / ,
0 commit comments