Skip to content

Commit 3c382b3

Browse files
mattiapenatiIcxolu
authored andcommitted
Use NumPy C API to implement object to array conversion
1 parent 9123c34 commit 3c382b3

File tree

1 file changed

+26
-23
lines changed

1 file changed

+26
-23
lines changed

src/array_like.rs

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,11 @@ use std::marker::PhantomData;
22
use std::ops::Deref;
33

44
use ndarray::{Array1, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
5-
use pyo3::{
6-
intern,
7-
sync::PyOnceLock,
8-
types::{PyAnyMethods, PyDict},
9-
Borrowed, FromPyObject, Py, PyAny, PyErr, PyResult,
10-
};
5+
use pyo3::{types::PyAnyMethods, Borrowed, FromPyObject, PyAny, PyErr, PyResult};
116

12-
use crate::array::PyArrayMethods;
13-
use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray, PyUntypedArray};
7+
use crate::npyffi::NPY_ARRAY_FORCECAST;
8+
use crate::{array::PyArrayMethods, PY_ARRAY_API};
9+
use crate::{Element, IntoPyArray, PyArray, PyReadonlyArray, PyUntypedArray};
1410

1511
pub trait Coerce: Sealed {
1612
const ALLOW_TYPE_CHANGE: bool;
@@ -166,24 +162,31 @@ where
166162
}
167163
}
168164

169-
static AS_ARRAY: PyOnceLock<Py<PyAny>> = PyOnceLock::new();
170-
171-
let as_array = AS_ARRAY
172-
.get_or_try_init(py, || {
173-
get_array_module(py)?.getattr("asarray").map(Into::into)
174-
})?
175-
.bind(py);
176-
177-
let kwargs = if C::ALLOW_TYPE_CHANGE {
178-
let kwargs = PyDict::new(py);
179-
kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?;
180-
Some(kwargs)
165+
let (dtype, flags) = if C::ALLOW_TYPE_CHANGE {
166+
(Some(T::get_dtype(py)), NPY_ARRAY_FORCECAST)
181167
} else {
182-
None
168+
(None, 0)
169+
};
170+
171+
let newtype = dtype
172+
.map(|dt| dt.into_ptr().cast())
173+
.unwrap_or_else(std::ptr::null_mut);
174+
175+
let array = unsafe {
176+
let ptr = PY_ARRAY_API.PyArray_FromAny(
177+
py,
178+
ob.as_ptr(),
179+
newtype,
180+
0,
181+
0,
182+
flags,
183+
std::ptr::null_mut(),
184+
);
185+
186+
pyo3::Bound::from_owned_ptr_or_err(py, ptr)?
183187
};
184188

185-
let array = as_array.call((ob,), kwargs.as_ref())?.extract()?;
186-
Ok(Self(array, PhantomData))
189+
Ok(Self(array.extract()?, PhantomData))
187190
}
188191
}
189192

0 commit comments

Comments
 (0)