Skip to content

Commit dbc841e

Browse files
authored
Update to target ABI v2 (#537)
* Spread usage of NonNull to mark the pointer returned from capsules * Update multiarray API to match ABI v2 * Update ufunc API to match ABI v2 * Fix missing PyUFunc_Type * Update FFI associated types and constants * Check compatibility with chosen ABI/API versions * Documentation improvements * Replace `std::os::raw` with `std::ffi` in `npyffi` * Update changelog
1 parent 47b3da8 commit dbc841e

File tree

12 files changed

+347
-387
lines changed

12 files changed

+347
-387
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# Changelog
22

33
- Unreleased
4-
- Fix PyArray_DTypeMeta definition when Py_LIMITED_API is disabled (#532)
4+
- Fix PyArray_DTypeMeta definition when Py_LIMITED_API is disabled ([#532](https://github.com/PyO3/rust-numpy/pull/532))
5+
- The NumPy C API binding has been updated to target the ABI v2, while maintaining runtime compatibility with NumPy v1 targeting the API v1.15. The higher interface is unchanged. ([#537](https://github.com/PyO3/rust-numpy/pull/537))
56

67
- v0.28.0
78
- Fix mismatched behavior between `PyArrayLike1` and `PyArrayLike2` when used with floats ([#520](https://github.com/PyO3/rust-numpy/pull/520))

src/array.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ unsafe impl<T: Element, D: Dimension> PyTypeInfo for PyArray<T, D> {
129129
const MODULE: Option<&'static str> = Some("numpy");
130130

131131
fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
132-
unsafe { npyffi::PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type) }
132+
unsafe { npyffi::get_type_object(py, npyffi::NpyTypes::PyArray_Type) }
133133
}
134134

135135
fn is_type_of(ob: &Bound<'_, PyAny>) -> bool {
@@ -233,7 +233,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
233233
let mut dims = dims.into_dimension();
234234
let ptr = PY_ARRAY_API.PyArray_NewFromDescr(
235235
py,
236-
PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type),
236+
npyffi::get_type_object(py, npyffi::NpyTypes::PyArray_Type),
237237
T::get_dtype(py).into_dtype_ptr(),
238238
dims.ndim_cint(),
239239
dims.as_dims_ptr(),
@@ -259,7 +259,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
259259
let mut dims = dims.into_dimension();
260260
let ptr = PY_ARRAY_API.PyArray_NewFromDescr(
261261
py,
262-
PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type),
262+
npyffi::get_type_object(py, npyffi::NpyTypes::PyArray_Type),
263263
T::get_dtype(py).into_dtype_ptr(),
264264
dims.ndim_cint(),
265265
dims.as_dims_ptr(),

src/dtype.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use pyo3::{
1717
};
1818

1919
use crate::npyffi::{
20-
NpyTypes, PyArray_Descr, PyDataType_ALIGNMENT, PyDataType_ELSIZE, PyDataType_FIELDS,
20+
self, NpyTypes, PyArray_Descr, PyDataType_ALIGNMENT, PyDataType_ELSIZE, PyDataType_FIELDS,
2121
PyDataType_FLAGS, PyDataType_NAMES, PyDataType_SUBARRAY, NPY_ALIGNED_STRUCT,
2222
NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES, PY_ARRAY_API,
2323
};
@@ -58,7 +58,7 @@ unsafe impl PyTypeInfo for PyArrayDescr {
5858

5959
#[inline]
6060
fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
61-
unsafe { PY_ARRAY_API.get_type_object(py, NpyTypes::PyArrayDescr_Type) }
61+
unsafe { npyffi::get_type_object(py, NpyTypes::PyArrayDescr_Type) }
6262
}
6363
}
6464

src/npyffi/array.rs

Lines changed: 116 additions & 144 deletions
Large diffs are not rendered by default.

src/npyffi/flags.rs

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
use super::npy_uint32;
2-
use std::os::raw::c_int;
1+
use super::{npy_uint32, npy_uint64};
2+
use std::ffi::c_int;
33

44
pub const NPY_ARRAY_C_CONTIGUOUS: c_int = 0x0001;
55
pub const NPY_ARRAY_F_CONTIGUOUS: c_int = 0x0002;
@@ -11,8 +11,8 @@ pub const NPY_ARRAY_ELEMENTSTRIDES: c_int = 0x0080;
1111
pub const NPY_ARRAY_ALIGNED: c_int = 0x0100;
1212
pub const NPY_ARRAY_NOTSWAPPED: c_int = 0x0200;
1313
pub const NPY_ARRAY_WRITEABLE: c_int = 0x0400;
14-
pub const NPY_ARRAY_UPDATEIFCOPY: c_int = 0x1000;
1514
pub const NPY_ARRAY_WRITEBACKIFCOPY: c_int = 0x2000;
15+
pub const NPY_ARRAY_ENSURENOCOPY: c_int = 0x4000;
1616
pub const NPY_ARRAY_BEHAVED: c_int = NPY_ARRAY_ALIGNED | NPY_ARRAY_WRITEABLE;
1717
pub const NPY_ARRAY_BEHAVED_NS: c_int = NPY_ARRAY_BEHAVED | NPY_ARRAY_NOTSWAPPED;
1818
pub const NPY_ARRAY_CARRAY: c_int = NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_BEHAVED;
@@ -22,13 +22,14 @@ pub const NPY_ARRAY_FARRAY_RO: c_int = NPY_ARRAY_F_CONTIGUOUS | NPY_ARRAY_ALIGNE
2222
pub const NPY_ARRAY_DEFAULT: c_int = NPY_ARRAY_CARRAY;
2323
pub const NPY_ARRAY_IN_ARRAY: c_int = NPY_ARRAY_CARRAY_RO;
2424
pub const NPY_ARRAY_OUT_ARRAY: c_int = NPY_ARRAY_CARRAY;
25-
pub const NPY_ARRAY_INOUT_ARRAY: c_int = NPY_ARRAY_CARRAY | NPY_ARRAY_UPDATEIFCOPY;
25+
pub const NPY_ARRAY_INOUT_ARRAY: c_int = NPY_ARRAY_CARRAY;
2626
pub const NPY_ARRAY_INOUT_ARRAY2: c_int = NPY_ARRAY_CARRAY | NPY_ARRAY_WRITEBACKIFCOPY;
2727
pub const NPY_ARRAY_IN_FARRAY: c_int = NPY_ARRAY_FARRAY_RO;
2828
pub const NPY_ARRAY_OUT_FARRAY: c_int = NPY_ARRAY_FARRAY;
29-
pub const NPY_ARRAY_INOUT_FARRAY: c_int = NPY_ARRAY_FARRAY | NPY_ARRAY_UPDATEIFCOPY;
29+
pub const NPY_ARRAY_INOUT_FARRAY: c_int = NPY_ARRAY_FARRAY;
3030
pub const NPY_ARRAY_INOUT_FARRAY2: c_int = NPY_ARRAY_FARRAY | NPY_ARRAY_WRITEBACKIFCOPY;
31-
pub const NPY_ARRAY_UPDATE_ALL: c_int = NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS;
31+
pub const NPY_ARRAY_UPDATE_ALL: c_int =
32+
NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS | NPY_ARRAY_ALIGNED;
3233

3334
pub const NPY_ITER_C_INDEX: npy_uint32 = 0x00000001;
3435
pub const NPY_ITER_F_INDEX: npy_uint32 = 0x00000002;
@@ -63,19 +64,18 @@ pub const NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE: npy_uint32 = 0x40000000;
6364
pub const NPY_ITER_GLOBAL_FLAGS: npy_uint32 = 0x0000ffff;
6465
pub const NPY_ITER_PER_OP_FLAGS: npy_uint32 = 0xffff0000;
6566

66-
pub const NPY_ITEM_REFCOUNT: u64 = 0x01;
67-
pub const NPY_ITEM_HASOBJECT: u64 = 0x01;
68-
pub const NPY_LIST_PICKLE: u64 = 0x02;
69-
pub const NPY_ITEM_IS_POINTER: u64 = 0x04;
70-
pub const NPY_NEEDS_INIT: u64 = 0x08;
71-
pub const NPY_NEEDS_PYAPI: u64 = 0x10;
72-
pub const NPY_USE_GETITEM: u64 = 0x20;
73-
pub const NPY_USE_SETITEM: u64 = 0x40;
74-
#[allow(overflowing_literals)]
75-
pub const NPY_ALIGNED_STRUCT: u64 = 0x80;
76-
pub const NPY_FROM_FIELDS: u64 =
67+
pub const NPY_ITEM_REFCOUNT: npy_uint64 = 0x01;
68+
pub const NPY_ITEM_HASOBJECT: npy_uint64 = 0x01;
69+
pub const NPY_LIST_PICKLE: npy_uint64 = 0x02;
70+
pub const NPY_ITEM_IS_POINTER: npy_uint64 = 0x04;
71+
pub const NPY_NEEDS_INIT: npy_uint64 = 0x08;
72+
pub const NPY_NEEDS_PYAPI: npy_uint64 = 0x10;
73+
pub const NPY_USE_GETITEM: npy_uint64 = 0x20;
74+
pub const NPY_USE_SETITEM: npy_uint64 = 0x40;
75+
pub const NPY_ALIGNED_STRUCT: npy_uint64 = 0x80;
76+
pub const NPY_FROM_FIELDS: npy_uint64 =
7777
NPY_NEEDS_INIT | NPY_LIST_PICKLE | NPY_ITEM_REFCOUNT | NPY_NEEDS_PYAPI;
78-
pub const NPY_OBJECT_DTYPE_FLAGS: u64 = NPY_LIST_PICKLE
78+
pub const NPY_OBJECT_DTYPE_FLAGS: npy_uint64 = NPY_LIST_PICKLE
7979
| NPY_USE_GETITEM
8080
| NPY_ITEM_IS_POINTER
8181
| NPY_ITEM_REFCOUNT

src/npyffi/mod.rs

Lines changed: 83 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
//! Low-Level bindings for NumPy C API.
22
//!
3-
//! <https://numpy.org/doc/stable/reference/c-api>
3+
//! This module provides FFI bindings to [NumPy C API], implementing access to the NumPy array and
4+
//! ufunc functionality. This binding is compatible with ABI v2 and the target API is v1.15 to
5+
//! ensure the compatibility with the older NumPy version. See the official NumPy documentation
6+
//! for more details about [API compatibility].
7+
//!
8+
//! [NumPy's C API]: https://numpy.org/doc/stable/reference/c-api
9+
//! [API compatibility]: https://numpy.org/doc/stable/dev/depending_on_numpy.html
10+
//!
411
#![allow(
512
non_camel_case_types,
613
missing_docs,
@@ -9,46 +16,42 @@
916
clippy::missing_safety_doc
1017
)]
1118

19+
use std::ffi::{c_uint, c_void};
1220
use std::mem::forget;
13-
use std::os::raw::{c_uint, c_void};
21+
use std::ptr::NonNull;
1422

1523
use pyo3::{
24+
ffi::PyTypeObject,
1625
sync::PyOnceLock,
1726
types::{PyAnyMethods, PyCapsule, PyCapsuleMethods, PyModule},
1827
PyResult, Python,
1928
};
2029

21-
pub const API_VERSION_2_0: c_uint = 0x00000012;
22-
2330
static API_VERSION: PyOnceLock<c_uint> = PyOnceLock::new();
2431

2532
fn get_numpy_api<'py>(
2633
py: Python<'py>,
2734
module: &str,
2835
capsule: &str,
29-
) -> PyResult<*const *const c_void> {
36+
) -> PyResult<NonNull<*const c_void>> {
3037
let module = PyModule::import(py, module)?;
3138
let capsule = module.getattr(capsule)?.cast_into::<PyCapsule>()?;
3239

33-
let api = capsule
34-
.pointer_checked(None)?
35-
.cast::<*const c_void>()
36-
.as_ptr()
37-
.cast_const();
40+
let api = capsule.pointer_checked(None)?;
3841

3942
// Intentionally leak a reference to the capsule
4043
// so we can safely cache a pointer into its interior.
4144
forget(capsule);
4245

43-
Ok(api)
46+
Ok(api.cast())
4447
}
4548

4649
/// Returns whether the runtime `numpy` version is 2.0 or greater.
4750
pub fn is_numpy_2<'py>(py: Python<'py>) -> bool {
4851
let api_version = *API_VERSION.get_or_init(py, || unsafe {
4952
PY_ARRAY_API.PyArray_GetNDArrayCFeatureVersion(py)
5053
});
51-
api_version >= API_VERSION_2_0
54+
api_version >= NPY_2_0_API_VERSION
5255
}
5356

5457
// Implements wrappers for NumPy's Array and UFunc API
@@ -57,52 +60,90 @@ macro_rules! impl_api {
5760
[$offset: expr; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
5861
#[allow(non_snake_case)]
5962
pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg : $t), *) $(-> $ret)* {
60-
let fptr = self.get(py, $offset) as *const extern "C" fn ($($arg : $t), *) $(-> $ret)*;
61-
(*fptr)($($arg), *)
63+
let f: extern "C" fn ($($arg : $t), *) $(-> $ret)* = self.get(py, $offset).cast().read();
64+
f($($arg), *)
6265
}
6366
};
67+
}
6468

65-
// API with version constraints, checked at runtime
66-
[$offset: expr; NumPy1; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
67-
#[allow(non_snake_case)]
68-
pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg : $t), *) $(-> $ret)* {
69-
assert!(
70-
!is_numpy_2(py),
71-
"{} requires API < {:08X} (NumPy 1) but the runtime version is API {:08X}",
72-
stringify!($fname),
73-
API_VERSION_2_0,
74-
*API_VERSION.get(py).expect("API_VERSION is initialized"),
75-
);
76-
let fptr = self.get(py, $offset) as *const extern "C" fn ($($arg: $t), *) $(-> $ret)*;
77-
(*fptr)($($arg), *)
78-
}
69+
// Define type objects associated with the NumPy API
70+
macro_rules! impl_array_type {
71+
($(($api:ident [ $offset:expr ] , $tname:ident)),* $(,)?) => {
72+
/// All type objects exported by the NumPy API.
73+
#[allow(non_camel_case_types)]
74+
pub enum NpyTypes { $($tname),* }
7975

80-
};
81-
[$offset: expr; NumPy2; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
82-
#[allow(non_snake_case)]
83-
pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg : $t), *) $(-> $ret)* {
84-
assert!(
85-
is_numpy_2(py),
86-
"{} requires API {:08X} or greater (NumPy 2) but the runtime version is API {:08X}",
87-
stringify!($fname),
88-
API_VERSION_2_0,
89-
*API_VERSION.get(py).expect("API_VERSION is initialized"),
90-
);
91-
let fptr = self.get(py, $offset) as *const extern "C" fn ($($arg: $t), *) $(-> $ret)*;
92-
(*fptr)($($arg), *)
76+
/// Get a pointer of the type object associated with `ty`.
77+
pub unsafe fn get_type_object<'py>(py: Python<'py>, ty: NpyTypes) -> *mut PyTypeObject {
78+
match ty {
79+
$( NpyTypes::$tname => $api.get(py, $offset).read() as _ ),*
80+
}
9381
}
82+
}
83+
}
9484

95-
};
85+
impl_array_type! {
86+
// Multiarray API
87+
// Slot 1 was never meaningfully used by NumPy
88+
(PY_ARRAY_API[2], PyArray_Type),
89+
(PY_ARRAY_API[3], PyArrayDescr_Type),
90+
// Unused slot 4, was `PyArrayFlags_Type`
91+
(PY_ARRAY_API[5], PyArrayIter_Type),
92+
(PY_ARRAY_API[6], PyArrayMultiIter_Type),
93+
// (PY_ARRAY_API[7], NPY_NUMUSERTYPES) -> c_int,
94+
(PY_ARRAY_API[8], PyBoolArrType_Type),
95+
// (PY_ARRAY_API[9], _PyArrayScalar_BoolValues) -> *mut PyBoolScalarObject,
96+
(PY_ARRAY_API[10], PyGenericArrType_Type),
97+
(PY_ARRAY_API[11], PyNumberArrType_Type),
98+
(PY_ARRAY_API[12], PyIntegerArrType_Type),
99+
(PY_ARRAY_API[13], PySignedIntegerArrType_Type),
100+
(PY_ARRAY_API[14], PyUnsignedIntegerArrType_Type),
101+
(PY_ARRAY_API[15], PyInexactArrType_Type),
102+
(PY_ARRAY_API[16], PyFloatingArrType_Type),
103+
(PY_ARRAY_API[17], PyComplexFloatingArrType_Type),
104+
(PY_ARRAY_API[18], PyFlexibleArrType_Type),
105+
(PY_ARRAY_API[19], PyCharacterArrType_Type),
106+
(PY_ARRAY_API[20], PyByteArrType_Type),
107+
(PY_ARRAY_API[21], PyShortArrType_Type),
108+
(PY_ARRAY_API[22], PyIntArrType_Type),
109+
(PY_ARRAY_API[23], PyLongArrType_Type),
110+
(PY_ARRAY_API[24], PyLongLongArrType_Type),
111+
(PY_ARRAY_API[25], PyUByteArrType_Type),
112+
(PY_ARRAY_API[26], PyUShortArrType_Type),
113+
(PY_ARRAY_API[27], PyUIntArrType_Type),
114+
(PY_ARRAY_API[28], PyULongArrType_Type),
115+
(PY_ARRAY_API[29], PyULongLongArrType_Type),
116+
(PY_ARRAY_API[30], PyFloatArrType_Type),
117+
(PY_ARRAY_API[31], PyDoubleArrType_Type),
118+
(PY_ARRAY_API[32], PyLongDoubleArrType_Type),
119+
(PY_ARRAY_API[33], PyCFloatArrType_Type),
120+
(PY_ARRAY_API[34], PyCDoubleArrType_Type),
121+
(PY_ARRAY_API[35], PyCLongDoubleArrType_Type),
122+
(PY_ARRAY_API[36], PyObjectArrType_Type),
123+
(PY_ARRAY_API[37], PyStringArrType_Type),
124+
(PY_ARRAY_API[38], PyUnicodeArrType_Type),
125+
(PY_ARRAY_API[39], PyVoidArrType_Type),
126+
(PY_ARRAY_API[214], PyTimeIntegerArrType_Type),
127+
(PY_ARRAY_API[215], PyDatetimeArrType_Type),
128+
(PY_ARRAY_API[216], PyTimedeltaArrType_Type),
129+
(PY_ARRAY_API[217], PyHalfArrType_Type),
130+
(PY_ARRAY_API[218], NpyIter_Type),
131+
// UFunc API
132+
(PY_UFUNC_API[0], PyUFunc_Type),
96133
}
97134

98135
pub mod array;
99136
pub mod flags;
137+
mod npy_common;
138+
mod numpyconfig;
100139
pub mod objects;
101140
pub mod types;
102141
pub mod ufunc;
103142

104143
pub use self::array::*;
105144
pub use self::flags::*;
145+
pub use self::npy_common::*;
146+
pub use self::numpyconfig::*;
106147
pub use self::objects::*;
107148
pub use self::types::*;
108149
pub use self::ufunc::*;

src/npyffi/npy_common.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
use std::ffi::c_int;
2+
3+
/// Unknown CPU endianness.
4+
pub const NPY_CPU_UNKNOWN_ENDIAN: c_int = 0;
5+
/// CPU is little-endian.
6+
pub const NPY_CPU_LITTLE: c_int = 1;
7+
/// CPU is big-endian.
8+
pub const NPY_CPU_BIG: c_int = 2;

src/npyffi/numpyconfig.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// This file matches the numpyconfig.h header.
2+
3+
use std::ffi::c_uint;
4+
5+
/// The current target ABI version
6+
const NPY_ABI_VERSION: c_uint = 0x02000000;
7+
8+
/// The current target API version (v1.15)
9+
const NPY_API_VERSION: c_uint = 0x0000000c;
10+
11+
pub(super) const NPY_2_0_API_VERSION: c_uint = 0x00000012;
12+
13+
/// The current version of the `ndarray` object (ABI version).
14+
pub const NPY_VERSION: c_uint = NPY_ABI_VERSION;
15+
/// The current version of C API.
16+
pub const NPY_FEATURE_VERSION: c_uint = NPY_API_VERSION;
17+
/// The string representation of current version C API.
18+
pub const NPY_FEATURE_VERSION_STRING: &str = "1.15";

0 commit comments

Comments
 (0)