Skip to content

Commit e247928

Browse files
committed
Add ownership-moving variants of readonly and readwrite
The additional `PyArray` methods `try_into_readonly` and `try_into_readwrite` allow directly moving the ownership of the `Bound` pointer backing a `PyArray` into the relevant view type. This both a) avoids reference counting overhead and b) allows methods on `PyReadwriteArray` (like `resize`) that _require_ unique pointer referencing, not just unique active borrows to function without the user having to manually drop the base guard.
1 parent 4149c5d commit e247928

3 files changed

Lines changed: 73 additions & 7 deletions

File tree

src/array.rs

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,7 @@ unsafe fn clone_elements<T: Element>(py: Python<'_>, elems: &[T], data_ptr: &mut
714714

715715
/// Implementation of functionality for [`PyArray<T, D>`].
716716
#[doc(alias = "PyArray")]
717-
pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> {
717+
pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> + Sized {
718718
/// Access an untyped representation of this array.
719719
fn as_untyped(&self) -> &Bound<'py, PyUntypedArray>;
720720

@@ -956,12 +956,33 @@ pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> {
956956
T: Element,
957957
D: Dimension;
958958

959+
/// Consume `self` into an immutable borrow of the NumPy array
960+
fn try_into_readonly(self) -> Result<PyReadonlyArray<'py, T, D>, BorrowError>
961+
where
962+
T: Element,
963+
D: Dimension;
964+
959965
/// Get an immutable borrow of the NumPy array
960966
fn try_readonly(&self) -> Result<PyReadonlyArray<'py, T, D>, BorrowError>
961967
where
962968
T: Element,
963969
D: Dimension;
964970

971+
/// Consume `self` into an immutable borrow of the NumPy array
972+
///
973+
/// # Panics
974+
///
975+
/// Panics if the allocation backing the array is currently mutably borrowed.
976+
///
977+
/// For a non-panicking variant, use [`try_readonly`][Self::try_into_readonly].
978+
fn into_readonly(self) -> PyReadonlyArray<'py, T, D>
979+
where
980+
T: Element,
981+
D: Dimension,
982+
{
983+
self.try_into_readonly().unwrap()
984+
}
985+
965986
/// Get an immutable borrow of the NumPy array
966987
///
967988
/// # Panics
@@ -977,12 +998,36 @@ pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> {
977998
self.try_readonly().unwrap()
978999
}
9791000

1001+
/// Consume `self` into an mutable borrow of the NumPy array
1002+
fn try_into_readwrite(self) -> Result<PyReadwriteArray<'py, T, D>, BorrowError>
1003+
where
1004+
T: Element,
1005+
D: Dimension;
1006+
9801007
/// Get a mutable borrow of the NumPy array
9811008
fn try_readwrite(&self) -> Result<PyReadwriteArray<'py, T, D>, BorrowError>
9821009
where
9831010
T: Element,
9841011
D: Dimension;
9851012

1013+
/// Consume `self` into an mutable borrow of the NumPy array
1014+
///
1015+
/// # Panics
1016+
///
1017+
/// Panics if the allocation backing the array is currently borrowed or
1018+
/// if the array is [flagged as][flags] not writeable.
1019+
///
1020+
/// For a non-panicking variant, use [`try_readwrite`][Self::try_readwrite].
1021+
///
1022+
/// [flags]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html
1023+
fn into_readwrite(self) -> PyReadwriteArray<'py, T, D>
1024+
where
1025+
T: Element,
1026+
D: Dimension,
1027+
{
1028+
self.try_into_readwrite().unwrap()
1029+
}
1030+
9861031
/// Get a mutable borrow of the NumPy array
9871032
///
9881033
/// # Panics
@@ -1467,20 +1512,36 @@ impl<'py, T, D> PyArrayMethods<'py, T, D> for Bound<'py, PyArray<T, D>> {
14671512
slice.map(|slc| T::vec_from_slice(self.py(), slc))
14681513
}
14691514

1515+
fn try_into_readonly(self) -> Result<PyReadonlyArray<'py, T, D>, BorrowError>
1516+
where
1517+
T: Element,
1518+
D: Dimension,
1519+
{
1520+
PyReadonlyArray::try_new(self)
1521+
}
1522+
14701523
fn try_readonly(&self) -> Result<PyReadonlyArray<'py, T, D>, BorrowError>
14711524
where
14721525
T: Element,
14731526
D: Dimension,
14741527
{
1475-
PyReadonlyArray::try_new(self.clone())
1528+
self.clone().try_into_readonly()
1529+
}
1530+
1531+
fn try_into_readwrite(self) -> Result<PyReadwriteArray<'py, T, D>, BorrowError>
1532+
where
1533+
T: Element,
1534+
D: Dimension,
1535+
{
1536+
PyReadwriteArray::try_new(self)
14761537
}
14771538

14781539
fn try_readwrite(&self) -> Result<PyReadwriteArray<'py, T, D>, BorrowError>
14791540
where
14801541
T: Element,
14811542
D: Dimension,
14821543
{
1483-
PyReadwriteArray::try_new(self.clone())
1544+
self.clone().try_into_readwrite()
14841545
}
14851546

14861547
unsafe fn as_array(&self) -> ArrayView<'_, T, D>

src/borrow/mod.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,11 @@ where
606606
///
607607
/// Safe wrapper for [`PyArray::resize`].
608608
///
609+
/// Note that as this mutates a pointed-to object, you cannot hold multiple
610+
/// pointers to the array simultaneously; if you begin with a [`PyArray`],
611+
/// you will need to use [`PyArrayMethods::into_readwrite`] instead of
612+
/// the shared-reference variant.
613+
///
609614
/// # Example
610615
///
611616
/// ```
@@ -616,7 +621,7 @@ where
616621
/// let pyarray = PyArray::arange(py, 0, 10, 1);
617622
/// assert_eq!(pyarray.len(), 10);
618623
///
619-
/// let pyarray = pyarray.readwrite();
624+
/// let pyarray = pyarray.into_readwrite();
620625
/// let pyarray = pyarray.resize(100).unwrap();
621626
/// assert_eq!(pyarray.len(), 100);
622627
/// });
@@ -722,7 +727,7 @@ mod tests {
722727
.cast_into::<PyArray1<f64>>()
723728
.unwrap();
724729

725-
let exclusive = array.readwrite();
730+
let exclusive = array.into_readwrite();
726731
assert!(exclusive.resize(100).is_err());
727732
});
728733
}
@@ -732,7 +737,7 @@ mod tests {
732737
Python::attach(|py| {
733738
let array = PyArray::<f64, _>::zeros(py, 10, false);
734739

735-
let exclusive = array.readwrite();
740+
let exclusive = array.into_readwrite();
736741
assert!(exclusive.resize(10).is_ok());
737742
});
738743
}

tests/borrow.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ fn resize_using_exclusive_borrow() {
341341
let array = PyArray::<f64, _>::zeros(py, 3, false);
342342
assert_eq!(array.shape(), [3]);
343343

344-
let mut array = array.readwrite();
344+
let mut array = array.into_readwrite();
345345
assert_eq!(array.as_slice_mut().unwrap(), &[0.0; 3]);
346346

347347
let mut array = array.resize(5).unwrap();

0 commit comments

Comments
 (0)