11//! Low-Level bindings for NumPy C API.
22//!
3- //! <https://numpy.org/doc/stable/reference/c-api>
3+ //! This module provides FFI bindings to [NumPy C API], implementing access to the NumPy array and
4+ //! ufunc functionality. This binding is compatible with ABI v2 and the target API is v1.15 to
5+ //! ensure the compatibility with the older NumPy version. See the official NumPy documentation
6+ //! for more details about [API compatibility].
7+ //!
8+ //! [NumPy's C API]: https://numpy.org/doc/stable/reference/c-api
9+ //! [API compatibility]: https://numpy.org/doc/stable/dev/depending_on_numpy.html
10+ //!
411#![ allow(
512 non_camel_case_types,
613 missing_docs,
916 clippy:: missing_safety_doc
1017) ]
1118
19+ use std:: ffi:: { c_uint, c_void} ;
1220use std:: mem:: forget;
13- use std:: os :: raw :: { c_uint , c_void } ;
21+ use std:: ptr :: NonNull ;
1422
1523use pyo3:: {
24+ ffi:: PyTypeObject ,
1625 sync:: PyOnceLock ,
1726 types:: { PyAnyMethods , PyCapsule , PyCapsuleMethods , PyModule } ,
1827 PyResult , Python ,
1928} ;
2029
21- pub const API_VERSION_2_0 : c_uint = 0x00000012 ;
22-
2330static API_VERSION : PyOnceLock < c_uint > = PyOnceLock :: new ( ) ;
2431
2532fn get_numpy_api < ' py > (
2633 py : Python < ' py > ,
2734 module : & str ,
2835 capsule : & str ,
29- ) -> PyResult < * const * const c_void > {
36+ ) -> PyResult < NonNull < * const c_void > > {
3037 let module = PyModule :: import ( py, module) ?;
3138 let capsule = module. getattr ( capsule) ?. cast_into :: < PyCapsule > ( ) ?;
3239
33- let api = capsule
34- . pointer_checked ( None ) ?
35- . cast :: < * const c_void > ( )
36- . as_ptr ( )
37- . cast_const ( ) ;
40+ let api = capsule. pointer_checked ( None ) ?;
3841
3942 // Intentionally leak a reference to the capsule
4043 // so we can safely cache a pointer into its interior.
4144 forget ( capsule) ;
4245
43- Ok ( api)
46+ Ok ( api. cast ( ) )
4447}
4548
4649/// Returns whether the runtime `numpy` version is 2.0 or greater.
4750pub fn is_numpy_2 < ' py > ( py : Python < ' py > ) -> bool {
4851 let api_version = * API_VERSION . get_or_init ( py, || unsafe {
4952 PY_ARRAY_API . PyArray_GetNDArrayCFeatureVersion ( py)
5053 } ) ;
51- api_version >= API_VERSION_2_0
54+ api_version >= NPY_2_0_API_VERSION
5255}
5356
5457// Implements wrappers for NumPy's Array and UFunc API
@@ -57,52 +60,90 @@ macro_rules! impl_api {
5760 [ $offset: expr; $fname: ident ( $( $arg: ident: $t: ty) ,* $( , ) ?) $( -> $ret: ty) ?] => {
5861 #[ allow( non_snake_case) ]
5962 pub unsafe fn $fname<' py>( & self , py: Python <' py>, $( $arg : $t) , * ) $( -> $ret) * {
60- let fptr = self . get ( py , $offset ) as * const extern "C" fn ( $( $arg : $t) , * ) $( -> $ret) * ;
61- ( * fptr ) ( $( $arg) , * )
63+ let f : extern "C" fn ( $( $arg : $t) , * ) $( -> $ret) * = self . get ( py , $offset ) . cast ( ) . read ( ) ;
64+ f ( $( $arg) , * )
6265 }
6366 } ;
67+ }
6468
65- // API with version constraints, checked at runtime
66- [ $offset: expr; NumPy1 ; $fname: ident ( $( $arg: ident: $t: ty) ,* $( , ) ?) $( -> $ret: ty) ?] => {
67- #[ allow( non_snake_case) ]
68- pub unsafe fn $fname<' py>( & self , py: Python <' py>, $( $arg : $t) , * ) $( -> $ret) * {
69- assert!(
70- !is_numpy_2( py) ,
71- "{} requires API < {:08X} (NumPy 1) but the runtime version is API {:08X}" ,
72- stringify!( $fname) ,
73- API_VERSION_2_0 ,
74- * API_VERSION . get( py) . expect( "API_VERSION is initialized" ) ,
75- ) ;
76- let fptr = self . get( py, $offset) as * const extern "C" fn ( $( $arg: $t) , * ) $( -> $ret) * ;
77- ( * fptr) ( $( $arg) , * )
78- }
69+ // Define type objects associated with the NumPy API
70+ macro_rules! impl_array_type {
71+ ( $( ( $api: ident [ $offset: expr ] , $tname: ident) ) ,* $( , ) ?) => {
72+ /// All type objects exported by the NumPy API.
73+ #[ allow( non_camel_case_types) ]
74+ pub enum NpyTypes { $( $tname) ,* }
7975
80- } ;
81- [ $offset: expr; NumPy2 ; $fname: ident ( $( $arg: ident: $t: ty) ,* $( , ) ?) $( -> $ret: ty) ?] => {
82- #[ allow( non_snake_case) ]
83- pub unsafe fn $fname<' py>( & self , py: Python <' py>, $( $arg : $t) , * ) $( -> $ret) * {
84- assert!(
85- is_numpy_2( py) ,
86- "{} requires API {:08X} or greater (NumPy 2) but the runtime version is API {:08X}" ,
87- stringify!( $fname) ,
88- API_VERSION_2_0 ,
89- * API_VERSION . get( py) . expect( "API_VERSION is initialized" ) ,
90- ) ;
91- let fptr = self . get( py, $offset) as * const extern "C" fn ( $( $arg: $t) , * ) $( -> $ret) * ;
92- ( * fptr) ( $( $arg) , * )
76+ /// Get a pointer of the type object associated with `ty`.
77+ pub unsafe fn get_type_object<' py>( py: Python <' py>, ty: NpyTypes ) -> * mut PyTypeObject {
78+ match ty {
79+ $( NpyTypes :: $tname => $api. get( py, $offset) . read( ) as _ ) ,*
80+ }
9381 }
82+ }
83+ }
9484
95- } ;
85+ impl_array_type ! {
86+ // Multiarray API
87+ // Slot 1 was never meaningfully used by NumPy
88+ ( PY_ARRAY_API [ 2 ] , PyArray_Type ) ,
89+ ( PY_ARRAY_API [ 3 ] , PyArrayDescr_Type ) ,
90+ // Unused slot 4, was `PyArrayFlags_Type`
91+ ( PY_ARRAY_API [ 5 ] , PyArrayIter_Type ) ,
92+ ( PY_ARRAY_API [ 6 ] , PyArrayMultiIter_Type ) ,
93+ // (PY_ARRAY_API[7], NPY_NUMUSERTYPES) -> c_int,
94+ ( PY_ARRAY_API [ 8 ] , PyBoolArrType_Type ) ,
95+ // (PY_ARRAY_API[9], _PyArrayScalar_BoolValues) -> *mut PyBoolScalarObject,
96+ ( PY_ARRAY_API [ 10 ] , PyGenericArrType_Type ) ,
97+ ( PY_ARRAY_API [ 11 ] , PyNumberArrType_Type ) ,
98+ ( PY_ARRAY_API [ 12 ] , PyIntegerArrType_Type ) ,
99+ ( PY_ARRAY_API [ 13 ] , PySignedIntegerArrType_Type ) ,
100+ ( PY_ARRAY_API [ 14 ] , PyUnsignedIntegerArrType_Type ) ,
101+ ( PY_ARRAY_API [ 15 ] , PyInexactArrType_Type ) ,
102+ ( PY_ARRAY_API [ 16 ] , PyFloatingArrType_Type ) ,
103+ ( PY_ARRAY_API [ 17 ] , PyComplexFloatingArrType_Type ) ,
104+ ( PY_ARRAY_API [ 18 ] , PyFlexibleArrType_Type ) ,
105+ ( PY_ARRAY_API [ 19 ] , PyCharacterArrType_Type ) ,
106+ ( PY_ARRAY_API [ 20 ] , PyByteArrType_Type ) ,
107+ ( PY_ARRAY_API [ 21 ] , PyShortArrType_Type ) ,
108+ ( PY_ARRAY_API [ 22 ] , PyIntArrType_Type ) ,
109+ ( PY_ARRAY_API [ 23 ] , PyLongArrType_Type ) ,
110+ ( PY_ARRAY_API [ 24 ] , PyLongLongArrType_Type ) ,
111+ ( PY_ARRAY_API [ 25 ] , PyUByteArrType_Type ) ,
112+ ( PY_ARRAY_API [ 26 ] , PyUShortArrType_Type ) ,
113+ ( PY_ARRAY_API [ 27 ] , PyUIntArrType_Type ) ,
114+ ( PY_ARRAY_API [ 28 ] , PyULongArrType_Type ) ,
115+ ( PY_ARRAY_API [ 29 ] , PyULongLongArrType_Type ) ,
116+ ( PY_ARRAY_API [ 30 ] , PyFloatArrType_Type ) ,
117+ ( PY_ARRAY_API [ 31 ] , PyDoubleArrType_Type ) ,
118+ ( PY_ARRAY_API [ 32 ] , PyLongDoubleArrType_Type ) ,
119+ ( PY_ARRAY_API [ 33 ] , PyCFloatArrType_Type ) ,
120+ ( PY_ARRAY_API [ 34 ] , PyCDoubleArrType_Type ) ,
121+ ( PY_ARRAY_API [ 35 ] , PyCLongDoubleArrType_Type ) ,
122+ ( PY_ARRAY_API [ 36 ] , PyObjectArrType_Type ) ,
123+ ( PY_ARRAY_API [ 37 ] , PyStringArrType_Type ) ,
124+ ( PY_ARRAY_API [ 38 ] , PyUnicodeArrType_Type ) ,
125+ ( PY_ARRAY_API [ 39 ] , PyVoidArrType_Type ) ,
126+ ( PY_ARRAY_API [ 214 ] , PyTimeIntegerArrType_Type ) ,
127+ ( PY_ARRAY_API [ 215 ] , PyDatetimeArrType_Type ) ,
128+ ( PY_ARRAY_API [ 216 ] , PyTimedeltaArrType_Type ) ,
129+ ( PY_ARRAY_API [ 217 ] , PyHalfArrType_Type ) ,
130+ ( PY_ARRAY_API [ 218 ] , NpyIter_Type ) ,
131+ // UFunc API
132+ ( PY_UFUNC_API [ 0 ] , PyUFunc_Type ) ,
96133}
97134
98135pub mod array;
99136pub mod flags;
137+ mod npy_common;
138+ mod numpyconfig;
100139pub mod objects;
101140pub mod types;
102141pub mod ufunc;
103142
104143pub use self :: array:: * ;
105144pub use self :: flags:: * ;
145+ pub use self :: npy_common:: * ;
146+ pub use self :: numpyconfig:: * ;
106147pub use self :: objects:: * ;
107148pub use self :: types:: * ;
108149pub use self :: ufunc:: * ;
0 commit comments