Skip to content

Commit 986f75e

Browse files
committed
Fix missing PyUFunc_Type
1 parent b735453 commit 986f75e

6 files changed

Lines changed: 77 additions & 75 deletions

File tree

src/array.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ unsafe impl<T: Element, D: Dimension> PyTypeInfo for PyArray<T, D> {
129129
const MODULE: Option<&'static str> = Some("numpy");
130130

131131
fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
132-
unsafe { npyffi::PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type) }
132+
unsafe { npyffi::get_type_object(py, npyffi::NpyTypes::PyArray_Type) }
133133
}
134134

135135
fn is_type_of(ob: &Bound<'_, PyAny>) -> bool {
@@ -233,7 +233,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
233233
let mut dims = dims.into_dimension();
234234
let ptr = PY_ARRAY_API.PyArray_NewFromDescr(
235235
py,
236-
PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type),
236+
npyffi::get_type_object(py, npyffi::NpyTypes::PyArray_Type),
237237
T::get_dtype(py).into_dtype_ptr(),
238238
dims.ndim_cint(),
239239
dims.as_dims_ptr(),
@@ -259,7 +259,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
259259
let mut dims = dims.into_dimension();
260260
let ptr = PY_ARRAY_API.PyArray_NewFromDescr(
261261
py,
262-
PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type),
262+
npyffi::get_type_object(py, npyffi::NpyTypes::PyArray_Type),
263263
T::get_dtype(py).into_dtype_ptr(),
264264
dims.ndim_cint(),
265265
dims.as_dims_ptr(),

src/dtype.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use pyo3::{
1717
};
1818

1919
use crate::npyffi::{
20-
NpyTypes, PyArray_Descr, PyDataType_ALIGNMENT, PyDataType_ELSIZE, PyDataType_FIELDS,
20+
self, NpyTypes, PyArray_Descr, PyDataType_ALIGNMENT, PyDataType_ELSIZE, PyDataType_FIELDS,
2121
PyDataType_FLAGS, PyDataType_NAMES, PyDataType_SUBARRAY, NPY_ALIGNED_STRUCT,
2222
NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES, PY_ARRAY_API,
2323
};
@@ -58,7 +58,7 @@ unsafe impl PyTypeInfo for PyArrayDescr {
5858

5959
#[inline]
6060
fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
61-
unsafe { PY_ARRAY_API.get_type_object(py, NpyTypes::PyArrayDescr_Type) }
61+
unsafe { npyffi::get_type_object(py, NpyTypes::PyArrayDescr_Type) }
6262
}
6363
}
6464

src/npyffi/array.rs

Lines changed: 3 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ unsafe impl Send for PyArrayAPI {}
8282
unsafe impl Sync for PyArrayAPI {}
8383

8484
impl PyArrayAPI {
85-
unsafe fn get<'py>(&self, py: Python<'py>, offset: isize) -> NonNull<*const c_void> {
85+
pub(super) unsafe fn get<'py>(&self, py: Python<'py>, offset: isize) -> NonNull<*const c_void> {
8686
let api = self
8787
.0
8888
.get_or_try_init(py, || get_numpy_api(py, mod_name(py)?, CAPSULE_NAME))
@@ -374,81 +374,16 @@ impl PyArrayAPI {
374374
// Min v2.0 impl_api![365; _PyDataType_GetArrFuncs(descr: *const PyArray_Descr) -> *mut PyArray_ArrFuncs];
375375
}
376376

377-
// Define type objects associated with the NumPy API
378-
macro_rules! impl_array_type {
379-
($(($offset:expr, $tname:ident)),* $(,)?) => {
380-
/// All type objects exported by the NumPy API.
381-
#[allow(non_camel_case_types)]
382-
pub enum NpyTypes { $($tname),* }
383-
384-
impl PyArrayAPI {
385-
/// Get a pointer of the type object associated with `ty`.
386-
pub unsafe fn get_type_object<'py>(&self, py: Python<'py>, ty: NpyTypes) -> *mut PyTypeObject {
387-
match ty {
388-
$( NpyTypes::$tname => self.get(py, $offset).read() as _ ),*
389-
}
390-
}
391-
}
392-
}
393-
}
394-
395-
impl_array_type! {
396-
// Slot 1 was never meaningfully used by NumPy
397-
(2, PyArray_Type),
398-
(3, PyArrayDescr_Type),
399-
// Unused slot 4, was `PyArrayFlags_Type`
400-
(5, PyArrayIter_Type),
401-
(6, PyArrayMultiIter_Type),
402-
// (7, NPY_NUMUSERTYPES) -> c_int,
403-
(8, PyBoolArrType_Type),
404-
// (9, _PyArrayScalar_BoolValues) -> *mut PyBoolScalarObject,
405-
(10, PyGenericArrType_Type),
406-
(11, PyNumberArrType_Type),
407-
(12, PyIntegerArrType_Type),
408-
(13, PySignedIntegerArrType_Type),
409-
(14, PyUnsignedIntegerArrType_Type),
410-
(15, PyInexactArrType_Type),
411-
(16, PyFloatingArrType_Type),
412-
(17, PyComplexFloatingArrType_Type),
413-
(18, PyFlexibleArrType_Type),
414-
(19, PyCharacterArrType_Type),
415-
(20, PyByteArrType_Type),
416-
(21, PyShortArrType_Type),
417-
(22, PyIntArrType_Type),
418-
(23, PyLongArrType_Type),
419-
(24, PyLongLongArrType_Type),
420-
(25, PyUByteArrType_Type),
421-
(26, PyUShortArrType_Type),
422-
(27, PyUIntArrType_Type),
423-
(28, PyULongArrType_Type),
424-
(29, PyULongLongArrType_Type),
425-
(30, PyFloatArrType_Type),
426-
(31, PyDoubleArrType_Type),
427-
(32, PyLongDoubleArrType_Type),
428-
(33, PyCFloatArrType_Type),
429-
(34, PyCDoubleArrType_Type),
430-
(35, PyCLongDoubleArrType_Type),
431-
(36, PyObjectArrType_Type),
432-
(37, PyStringArrType_Type),
433-
(38, PyUnicodeArrType_Type),
434-
(39, PyVoidArrType_Type),
435-
(214, PyTimeIntegerArrType_Type),
436-
(215, PyDatetimeArrType_Type),
437-
(216, PyTimedeltaArrType_Type),
438-
(217, PyHalfArrType_Type),
439-
(218, NpyIter_Type),
440-
}
441-
442377
/// Checks that `op` is an instance of `PyArray` or not.
443378
#[allow(non_snake_case)]
444379
pub unsafe fn PyArray_Check<'py>(py: Python<'py>, op: *mut PyObject) -> c_int {
445-
ffi::PyObject_TypeCheck(op, PY_ARRAY_API.get_type_object(py, NpyTypes::PyArray_Type))
380+
ffi::PyObject_TypeCheck(op, super::get_type_object(py, NpyTypes::PyArray_Type))
446381
}
447382

448383
/// Checks that `op` is an exact instance of `PyArray` or not.
449384
#[allow(non_snake_case)]
450385
pub unsafe fn PyArray_CheckExact<'py>(py: Python<'py>, op: *mut PyObject) -> c_int {
451-
(ffi::Py_TYPE(op) == PY_ARRAY_API.get_type_object(py, NpyTypes::PyArray_Type)) as _
386+
(ffi::Py_TYPE(op) == super::get_type_object(py, NpyTypes::PyArray_Type)) as _
452387
}
453388

454389
#[cfg(test)]

src/npyffi/mod.rs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use std::os::raw::{c_uint, c_void};
1414
use std::ptr::NonNull;
1515

1616
use pyo3::{
17+
ffi::PyTypeObject,
1718
sync::PyOnceLock,
1819
types::{PyAnyMethods, PyCapsule, PyCapsuleMethods, PyModule},
1920
PyResult, Python,
@@ -60,6 +61,72 @@ macro_rules! impl_api {
6061
};
6162
}
6263

64+
// Define type objects associated with the NumPy API
65+
macro_rules! impl_array_type {
66+
($(($api:ident [ $offset:expr ] , $tname:ident)),* $(,)?) => {
67+
/// All type objects exported by the NumPy API.
68+
#[allow(non_camel_case_types)]
69+
pub enum NpyTypes { $($tname),* }
70+
71+
/// Get a pointer of the type object associated with `ty`.
72+
pub unsafe fn get_type_object<'py>(py: Python<'py>, ty: NpyTypes) -> *mut PyTypeObject {
73+
match ty {
74+
$( NpyTypes::$tname => $api.get(py, $offset).read() as _ ),*
75+
}
76+
}
77+
}
78+
}
79+
80+
impl_array_type! {
81+
// Multiarray API
82+
// Slot 1 was never meaningfully used by NumPy
83+
(PY_ARRAY_API[2], PyArray_Type),
84+
(PY_ARRAY_API[3], PyArrayDescr_Type),
85+
// Unused slot 4, was `PyArrayFlags_Type`
86+
(PY_ARRAY_API[5], PyArrayIter_Type),
87+
(PY_ARRAY_API[6], PyArrayMultiIter_Type),
88+
// (PY_ARRAY_API[7], NPY_NUMUSERTYPES) -> c_int,
89+
(PY_ARRAY_API[8], PyBoolArrType_Type),
90+
// (PY_ARRAY_API[9], _PyArrayScalar_BoolValues) -> *mut PyBoolScalarObject,
91+
(PY_ARRAY_API[10], PyGenericArrType_Type),
92+
(PY_ARRAY_API[11], PyNumberArrType_Type),
93+
(PY_ARRAY_API[12], PyIntegerArrType_Type),
94+
(PY_ARRAY_API[13], PySignedIntegerArrType_Type),
95+
(PY_ARRAY_API[14], PyUnsignedIntegerArrType_Type),
96+
(PY_ARRAY_API[15], PyInexactArrType_Type),
97+
(PY_ARRAY_API[16], PyFloatingArrType_Type),
98+
(PY_ARRAY_API[17], PyComplexFloatingArrType_Type),
99+
(PY_ARRAY_API[18], PyFlexibleArrType_Type),
100+
(PY_ARRAY_API[19], PyCharacterArrType_Type),
101+
(PY_ARRAY_API[20], PyByteArrType_Type),
102+
(PY_ARRAY_API[21], PyShortArrType_Type),
103+
(PY_ARRAY_API[22], PyIntArrType_Type),
104+
(PY_ARRAY_API[23], PyLongArrType_Type),
105+
(PY_ARRAY_API[24], PyLongLongArrType_Type),
106+
(PY_ARRAY_API[25], PyUByteArrType_Type),
107+
(PY_ARRAY_API[26], PyUShortArrType_Type),
108+
(PY_ARRAY_API[27], PyUIntArrType_Type),
109+
(PY_ARRAY_API[28], PyULongArrType_Type),
110+
(PY_ARRAY_API[29], PyULongLongArrType_Type),
111+
(PY_ARRAY_API[30], PyFloatArrType_Type),
112+
(PY_ARRAY_API[31], PyDoubleArrType_Type),
113+
(PY_ARRAY_API[32], PyLongDoubleArrType_Type),
114+
(PY_ARRAY_API[33], PyCFloatArrType_Type),
115+
(PY_ARRAY_API[34], PyCDoubleArrType_Type),
116+
(PY_ARRAY_API[35], PyCLongDoubleArrType_Type),
117+
(PY_ARRAY_API[36], PyObjectArrType_Type),
118+
(PY_ARRAY_API[37], PyStringArrType_Type),
119+
(PY_ARRAY_API[38], PyUnicodeArrType_Type),
120+
(PY_ARRAY_API[39], PyVoidArrType_Type),
121+
(PY_ARRAY_API[214], PyTimeIntegerArrType_Type),
122+
(PY_ARRAY_API[215], PyDatetimeArrType_Type),
123+
(PY_ARRAY_API[216], PyTimedeltaArrType_Type),
124+
(PY_ARRAY_API[217], PyHalfArrType_Type),
125+
(PY_ARRAY_API[218], NpyIter_Type),
126+
// UFunc API
127+
(PY_UFUNC_API[0], PyUFunc_Type),
128+
}
129+
63130
pub mod array;
64131
pub mod flags;
65132
pub mod objects;

src/npyffi/ufunc.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ unsafe impl Send for PyUFuncAPI {}
2929
unsafe impl Sync for PyUFuncAPI {}
3030

3131
impl PyUFuncAPI {
32-
unsafe fn get<'py>(&self, py: Python<'py>, offset: isize) -> NonNull<*const c_void> {
32+
pub(super) unsafe fn get<'py>(&self, py: Python<'py>, offset: isize) -> NonNull<*const c_void> {
3333
let api = self
3434
.0
3535
.get_or_try_init(py, || get_numpy_api(py, mod_name(py)?, CAPSULE_NAME))

src/untyped_array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ unsafe impl PyTypeInfo for PyUntypedArray {
6464
const MODULE: Option<&'static str> = Some("numpy");
6565

6666
fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
67-
unsafe { npyffi::PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type) }
67+
unsafe { npyffi::get_type_object(py, npyffi::NpyTypes::PyArray_Type) }
6868
}
6969

7070
fn is_type_of(ob: &Bound<'_, PyAny>) -> bool {

0 commit comments

Comments
 (0)