Skip to content

Commit 9c5dee0

Browse files
committed
feat(PyTreeKind): use pybind11::native_enum to create enum class PyTreeKind
1 parent b82b62c commit 9c5dee0

2 files changed

Lines changed: 50 additions & 9 deletions

File tree

include/optree/pymacros.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,16 @@ limitations under the License.
2121

2222
#include <pybind11/pybind11.h>
2323

24-
namespace py = pybind11;
25-
26-
#if PY_VERSION_HEX < 0x03090000 // Python 3.9
24+
#if !(defined(PY_VERSION_HEX) && PY_VERSION_HEX >= 0x03090000) // Python 3.9
2725
#error "Python 3.9 or newer is required."
2826
#endif
2927

28+
#if !(defined(PYBIND11_VERSION_HEX) && PYBIND11_VERSION_HEX >= 0x020C00F0) // pybind11 2.12.0
29+
#error "pybind11 2.12.0 or newer is required."
30+
#endif
31+
32+
namespace py = pybind11;
33+
3034
#ifndef Py_ALWAYS_INLINE
3135
#define Py_ALWAYS_INLINE
3236
#endif

src/optree.cpp

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,16 @@ limitations under the License.
2525
#include <pybind11/pybind11.h>
2626
#include <pybind11/stl.h>
2727

28+
#if defined(PYBIND11_HAS_NATIVE_ENUM) || \
29+
(defined(PYBIND11_INTERNALS_VERSION) && PYBIND11_INTERNALS_VERSION >= 8)
30+
#ifndef PYBIND11_HAS_NATIVE_ENUM
31+
#define PYBIND11_HAS_NATIVE_ENUM
32+
#endif
33+
#include <pybind11/native_enum.h>
34+
#else
35+
#undef PYBIND11_HAS_NATIVE_ENUM
36+
#endif
37+
2838
namespace optree {
2939

3040
py::module_ GetCxxModule(const std::optional<py::module_>& module) {
@@ -175,6 +185,22 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
175185
#define def_method_pos_only(...) def(__VA_ARGS__)
176186
#endif
177187

188+
#ifdef PYBIND11_HAS_NATIVE_ENUM
189+
py::native_enum<PyTreeKind>(mod, "PyTreeKind", "enum.IntEnum", "The kind of a pytree node.")
190+
.value("CUSTOM", PyTreeKind::Custom, "A custom type.")
191+
.value("LEAF", PyTreeKind::Leaf, "A opaque leaf node.")
192+
.value("NONE", PyTreeKind::None, "None.")
193+
.value("TUPLE", PyTreeKind::Tuple, "A tuple.")
194+
.value("LIST", PyTreeKind::List, "A list.")
195+
.value("DICT", PyTreeKind::Dict, "A dict.")
196+
.value("NAMEDTUPLE", PyTreeKind::NamedTuple, "A collections.namedtuple.")
197+
.value("ORDEREDDICT", PyTreeKind::OrderedDict, "A collections.OrderedDict.")
198+
.value("DEFAULTDICT", PyTreeKind::DefaultDict, "A collections.defaultdict.")
199+
.value("DEQUE", PyTreeKind::Deque, "A collections.deque.")
200+
.value("STRUCTSEQUENCE", PyTreeKind::StructSequence, "A PyStructSequence.")
201+
.finalize();
202+
auto PyTreeKindTypeObject = py::getattr(mod, "PyTreeKind");
203+
#else
178204
auto PyTreeKindTypeObject =
179205
py::enum_<PyTreeKind>(mod, "PyTreeKind", "The kind of a pytree node.", py::module_local())
180206
.value("CUSTOM", PyTreeKind::Custom, "A custom type.")
@@ -188,6 +214,7 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
188214
.value("DEFAULTDICT", PyTreeKind::DefaultDict, "A collections.defaultdict.")
189215
.value("DEQUE", PyTreeKind::Deque, "A collections.deque.")
190216
.value("STRUCTSEQUENCE", PyTreeKind::StructSequence, "A PyStructSequence.");
217+
#endif
191218
auto* const PyTreeKind_Type = reinterpret_cast<PyTypeObject*>(PyTreeKindTypeObject.ptr());
192219
PyTreeKind_Type->tp_name = "optree.PyTreeKind";
193220
py::setattr(PyTreeKindTypeObject.ptr(), Py_Get_ID(__module__), Py_Get_ID(optree));
@@ -412,25 +439,35 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
412439

413440
#undef def_method_pos_only
414441

442+
// Make the PyTreeSpec and PyTreeIter types immutable.
415443
#ifdef Py_TPFLAGS_IMMUTABLETYPE
416-
PyTreeKind_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
417444
PyTreeSpec_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
418-
PyTreeIter_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
419-
PyTreeKind_Type->tp_flags &= ~Py_TPFLAGS_READY;
420445
PyTreeSpec_Type->tp_flags &= ~Py_TPFLAGS_READY;
446+
PyTreeIter_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
421447
PyTreeIter_Type->tp_flags &= ~Py_TPFLAGS_READY;
422448
#endif
423449

424-
if (PyType_Ready(PyTreeKind_Type) < 0) [[unlikely]] {
425-
INTERNAL_ERROR("`PyType_Ready(&PyTreeKind_Type)` failed.");
426-
}
427450
if (PyType_Ready(PyTreeSpec_Type) < 0) [[unlikely]] {
428451
INTERNAL_ERROR("`PyType_Ready(&PyTreeSpec_Type)` failed.");
429452
}
430453
if (PyType_Ready(PyTreeIter_Type) < 0) [[unlikely]] {
431454
INTERNAL_ERROR("`PyType_Ready(&PyTreeIter_Type)` failed.");
432455
}
433456

457+
#ifdef Py_TPFLAGS_IMMUTABLETYPE
458+
PyTreeKind_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
459+
460+
// Only run `PyType_Ready` for C++ types.
461+
// Re-run `PyType_Ready` for native Python enums will cause unexpected behavior.
462+
#ifndef PYBIND11_HAS_NATIVE_ENUM
463+
PyTreeKind_Type->tp_flags &= ~Py_TPFLAGS_READY;
464+
465+
if (PyType_Ready(PyTreeKind_Type) < 0) [[unlikely]] {
466+
INTERNAL_ERROR("`PyType_Ready(&PyTreeKind_Type)` failed.");
467+
}
468+
#endif
469+
#endif
470+
434471
py::getattr(py::module_::import("atexit"),
435472
"register")(py::cpp_function(&PyTreeTypeRegistry::Clear));
436473
}

0 commit comments

Comments
 (0)