Skip to content

Commit fd3a63d

Browse files
committed
Use NumPy C API to implement object to array conversion
1 parent 072d05e commit fd3a63d

File tree

1 file changed

+27
-23
lines changed

1 file changed

+27
-23
lines changed

src/array_like.rs

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,11 @@ use std::marker::PhantomData;
22
use std::ops::Deref;
33

44
use ndarray::{Array1, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
5-
use pyo3::{
6-
intern,
7-
sync::PyOnceLock,
8-
types::{PyAnyMethods, PyDict},
9-
Borrowed, FromPyObject, Py, PyAny, PyErr, PyResult,
10-
};
5+
use pyo3::{types::PyAnyMethods, Borrowed, FromPyObject, PyAny, PyErr, PyResult};
116

12-
use crate::array::PyArrayMethods;
13-
use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray, PyUntypedArray};
7+
use crate::npyffi::NPY_ARRAY_FORCECAST;
8+
use crate::{array::PyArrayMethods, PY_ARRAY_API};
9+
use crate::{Element, IntoPyArray, PyArray, PyReadonlyArray, PyUntypedArray};
1410

1511
pub trait Coerce: Sealed {
1612
const ALLOW_TYPE_CHANGE: bool;
@@ -166,24 +162,32 @@ where
166162
}
167163
}
168164

169-
static AS_ARRAY: PyOnceLock<Py<PyAny>> = PyOnceLock::new();
170-
171-
let as_array = AS_ARRAY
172-
.get_or_try_init(py, || {
173-
get_array_module(py)?.getattr("asarray").map(Into::into)
174-
})?
175-
.bind(py);
176-
177-
let kwargs = if C::ALLOW_TYPE_CHANGE {
178-
let kwargs = PyDict::new(py);
179-
kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?;
180-
Some(kwargs)
165+
let (dtype, flags) = if C::ALLOW_TYPE_CHANGE {
166+
(Some(T::get_dtype(py)), NPY_ARRAY_FORCECAST)
181167
} else {
182-
None
168+
(None, 0)
169+
};
170+
171+
let newtype = dtype
172+
.as_ref()
173+
.map(|dt| dt.as_ptr() as *mut _)
174+
.unwrap_or_else(std::ptr::null_mut);
175+
176+
let array = unsafe {
177+
let ptr = PY_ARRAY_API.PyArray_FromAny(
178+
py,
179+
ob.as_ptr(),
180+
newtype,
181+
0,
182+
0,
183+
flags,
184+
std::ptr::null_mut(),
185+
);
186+
187+
pyo3::Bound::from_owned_ptr_or_err(py, ptr)?
183188
};
184189

185-
let array = as_array.call((ob,), kwargs.as_ref())?.extract()?;
186-
Ok(Self(array, PhantomData))
190+
Ok(Self(array.extract()?, PhantomData))
187191
}
188192
}
189193

0 commit comments

Comments
 (0)