Skip to content

Commit 5d8a1f3

Browse files
committed
Speed up py::numpy_scalar<> type caster
1 parent 4734944 commit 5d8a1f3

1 file changed

Lines changed: 12 additions & 14 deletions

File tree

include/pybind11/numpy.h

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -482,28 +482,26 @@ struct type_caster<numpy_scalar<T>> {
482482

483483
PYBIND11_TYPE_CASTER(numpy_scalar<T>, type_info::name);
484484

485-
static object target_dtype() {
486-
auto& api = npy_api::get();
487-
return reinterpret_steal<object>(api.PyArray_DescrFromType_(type_info::typenum));
485+
static handle& target_type() {
486+
static handle tp = npy_api::get().PyArray_TypeObjectFromType_(type_info::typenum);
487+
return tp;
488+
}
489+
490+
static handle& target_dtype() {
491+
static handle tp = npy_api::get().PyArray_DescrFromType_(type_info::typenum);
492+
return tp;
488493
}
489494

490495
bool load(handle src, bool) {
491-
auto& api = npy_api::get();
492-
auto target = target_dtype();
493-
if (auto descr = reinterpret_steal<object>(api.PyArray_DescrFromScalar_(src.ptr()))) {
494-
if (api.PyArray_EquivTypes_(descr.ptr(), target.ptr())) {
495-
api.PyArray_ScalarAsCtype_(src.ptr(), &value.value);
496-
return true;
497-
}
496+
if (isinstance(src, target_type())) {
497+
npy_api::get().PyArray_ScalarAsCtype_(src.ptr(), &value.value);
498+
return true;
498499
}
499500
return false;
500501
}
501502

502503
static handle cast(numpy_scalar<T> src, return_value_policy, handle) {
503-
auto& api = npy_api::get();
504-
auto target = target_dtype();
505-
auto size = reinterpret_steal<object>(PyLong_FromLong(sizeof(value_type)));
506-
return api.PyArray_Scalar_(&src.value, target.ptr(), size.ptr());
504+
return npy_api::get().PyArray_Scalar_(&src.value, target_dtype().ptr(), nullptr);
507505
}
508506
};
509507

0 commit comments

Comments
 (0)