@@ -25,6 +25,16 @@ limitations under the License.
2525#include < pybind11/pybind11.h>
2626#include < pybind11/stl.h>
2727
28+ #if defined(PYBIND11_HAS_NATIVE_ENUM) || \
29+ (defined (PYBIND11_INTERNALS_VERSION) && PYBIND11_INTERNALS_VERSION >= 8 )
30+ #ifndef PYBIND11_HAS_NATIVE_ENUM
31+ #define PYBIND11_HAS_NATIVE_ENUM
32+ #endif
33+ #include < pybind11/native_enum.h>
34+ #else
35+ #undef PYBIND11_HAS_NATIVE_ENUM
36+ #endif
37+
2838namespace optree {
2939
3040py::module_ GetCxxModule (const std::optional<py::module_>& module ) {
@@ -175,6 +185,22 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
175185#define def_method_pos_only (...) def(__VA_ARGS__)
176186#endif
177187
188+ #ifdef PYBIND11_HAS_NATIVE_ENUM
189+ py::native_enum<PyTreeKind>(mod, " PyTreeKind" , " enum.IntEnum" , " The kind of a pytree node." )
190+ .value (" CUSTOM" , PyTreeKind::Custom, " A custom type." )
191+ .value (" LEAF" , PyTreeKind::Leaf, " A opaque leaf node." )
192+ .value (" NONE" , PyTreeKind::None, " None." )
193+ .value (" TUPLE" , PyTreeKind::Tuple, " A tuple." )
194+ .value (" LIST" , PyTreeKind::List, " A list." )
195+ .value (" DICT" , PyTreeKind::Dict, " A dict." )
196+ .value (" NAMEDTUPLE" , PyTreeKind::NamedTuple, " A collections.namedtuple." )
197+ .value (" ORDEREDDICT" , PyTreeKind::OrderedDict, " A collections.OrderedDict." )
198+ .value (" DEFAULTDICT" , PyTreeKind::DefaultDict, " A collections.defaultdict." )
199+ .value (" DEQUE" , PyTreeKind::Deque, " A collections.deque." )
200+ .value (" STRUCTSEQUENCE" , PyTreeKind::StructSequence, " A PyStructSequence." )
201+ .finalize ();
202+ auto PyTreeKindTypeObject = py::getattr (mod, " PyTreeKind" );
203+ #else
178204 auto PyTreeKindTypeObject =
179205 py::enum_<PyTreeKind>(mod, " PyTreeKind" , " The kind of a pytree node." , py::module_local ())
180206 .value (" CUSTOM" , PyTreeKind::Custom, " A custom type." )
@@ -188,6 +214,7 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
188214 .value (" DEFAULTDICT" , PyTreeKind::DefaultDict, " A collections.defaultdict." )
189215 .value (" DEQUE" , PyTreeKind::Deque, " A collections.deque." )
190216 .value (" STRUCTSEQUENCE" , PyTreeKind::StructSequence, " A PyStructSequence." );
217+ #endif
191218 auto * const PyTreeKind_Type = reinterpret_cast <PyTypeObject*>(PyTreeKindTypeObject.ptr ());
192219 PyTreeKind_Type->tp_name = " optree.PyTreeKind" ;
193220 py::setattr (PyTreeKindTypeObject.ptr (), Py_Get_ID (__module__), Py_Get_ID (optree));
@@ -412,25 +439,35 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
412439
413440#undef def_method_pos_only
414441
442+ // Make the PyTreeSpec and PyTreeIter types immutable.
415443#ifdef Py_TPFLAGS_IMMUTABLETYPE
416- PyTreeKind_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
417444 PyTreeSpec_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
418- PyTreeIter_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
419- PyTreeKind_Type->tp_flags &= ~Py_TPFLAGS_READY;
420445 PyTreeSpec_Type->tp_flags &= ~Py_TPFLAGS_READY;
446+ PyTreeIter_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
421447 PyTreeIter_Type->tp_flags &= ~Py_TPFLAGS_READY;
422448#endif
423449
424- if (PyType_Ready (PyTreeKind_Type) < 0 ) [[unlikely]] {
425- INTERNAL_ERROR (" `PyType_Ready(&PyTreeKind_Type)` failed." );
426- }
427450 if (PyType_Ready (PyTreeSpec_Type) < 0 ) [[unlikely]] {
428451 INTERNAL_ERROR (" `PyType_Ready(&PyTreeSpec_Type)` failed." );
429452 }
430453 if (PyType_Ready (PyTreeIter_Type) < 0 ) [[unlikely]] {
431454 INTERNAL_ERROR (" `PyType_Ready(&PyTreeIter_Type)` failed." );
432455 }
433456
457+ #ifdef Py_TPFLAGS_IMMUTABLETYPE
458+ PyTreeKind_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
459+
460+ // Only run `PyType_Ready` for C++ types.
461+ // Re-run `PyType_Ready` for native Python enums will cause unexpected behavior.
462+ #ifndef PYBIND11_HAS_NATIVE_ENUM
463+ PyTreeKind_Type->tp_flags &= ~Py_TPFLAGS_READY;
464+
465+ if (PyType_Ready (PyTreeKind_Type) < 0 ) [[unlikely]] {
466+ INTERNAL_ERROR (" `PyType_Ready(&PyTreeKind_Type)` failed." );
467+ }
468+ #endif
469+ #endif
470+
434471 py::getattr (py::module_::import (" atexit" ),
435472 " register" )(py::cpp_function (&PyTreeTypeRegistry::Clear));
436473}
0 commit comments