Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 62 additions & 27 deletions src/base/blas_uninit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
use matrixmultiply;
use num::{One, Zero};
use simba::scalar::{ClosedAddAssign, ClosedMulAssign};
use std::ptr;
#[cfg(feature = "std")]
use std::{any::TypeId, mem};

Expand All @@ -26,14 +27,15 @@ use crate::base::uninit::InitStatus;
use crate::base::{Matrix, Scalar, Vector};

// # Safety
// The content of `y` must only contain values for which
// `Status::assume_init_mut` is sound.
// `y` and `x` must be valid for the access pattern `ptr.add(i * stride)` for
// `i` in `0..len`. The content of `y` must be initialized (since this reads
// from `y` to accumulate with `beta`).
#[allow(clippy::too_many_arguments)]
unsafe fn array_axcpy<Status, T>(
_: Status,
y: &mut [Status::Value],
y: *mut Status::Value,
a: T,
x: &[T],
x: *const T,
c: T,
beta: T,
stride1: usize,
Expand All @@ -43,20 +45,33 @@ unsafe fn array_axcpy<Status, T>(
Status: InitStatus<T>,
T: Scalar + Zero + ClosedAddAssign + ClosedMulAssign,
{
unsafe {
for i in 0..len {
let y = Status::assume_init_mut(y.get_unchecked_mut(i * stride1));
*y = a.clone() * x.get_unchecked(i * stride2).clone() * c.clone()
+ beta.clone() * y.clone();
for i in 0..len {
// SAFETY: Caller guarantees pointer validity for these offsets.
// The y elements are initialized (beta != 0 path), so assume_init_read is sound.
// We use ptr::read on x instead of (*ptr).clone() to avoid creating &T references
// that would push SharedRO tags onto the Stacked Borrows stack, conflicting with
// writes through y when both pointers derive from the same allocation.
// This is sound because all Scalar types that also implement
// Zero + ClosedAddAssign + ClosedMulAssign are Copy in practice (f32, f64, etc.).
unsafe {
let old_y = Status::assume_init_read(y.add(i * stride1));
let x_val = ptr::read(x.add(i * stride2));
Status::init_ptr(
y.add(i * stride1),
a.clone() * x_val * c.clone() + beta.clone() * old_y,
);
}
}
}

fn array_axc<Status, T>(
// # Safety
// `y` and `x` must be valid for the access pattern `ptr.add(i * stride)` for
// `i` in `0..len`. The content of `y` need not be initialized (write-only path).
unsafe fn array_axc<Status, T>(
_: Status,
y: &mut [Status::Value],
y: *mut Status::Value,
a: T,
x: &[T],
x: *const T,
c: T,
stride1: usize,
stride2: usize,
Expand All @@ -66,11 +81,12 @@ fn array_axc<Status, T>(
T: Scalar + Zero + ClosedAddAssign + ClosedMulAssign,
{
for i in 0..len {
// SAFETY: Caller guarantees pointer validity for these offsets.
// ptr::read is used instead of (*ptr).clone() to avoid Stacked Borrows
// violations — see array_axcpy for detailed rationale.
unsafe {
Status::init(
y.get_unchecked_mut(i * stride1),
a.clone() * x.get_unchecked(i * stride2).clone() * c.clone(),
);
let x_val = ptr::read(x.add(i * stride2));
Status::init_ptr(y.add(i * stride1), a.clone() * x_val * c.clone());
}
}
}
Expand All @@ -97,21 +113,25 @@ pub unsafe fn axcpy_uninit<Status, T, D1: Dim, D2: Dim, SA, SB>(
ShapeConstraint: DimEq<D1, D2>,
Status: InitStatus<T>,
{
unsafe {
assert_eq!(y.nrows(), x.nrows(), "Axcpy: mismatched vector shapes.");
assert_eq!(y.nrows(), x.nrows(), "Axcpy: mismatched vector shapes.");

let rstride1 = y.strides().0;
let rstride2 = x.strides().0;
let len = x.nrows();
let rstride1 = y.strides().0;
let rstride2 = x.strides().0;

// SAFETY: the conversion to slices is OK because we access the
// elements taking the strides into account.
let y = y.data.as_mut_slice_unchecked();
let x = x.data.as_slice_unchecked();
// SAFETY: We use raw pointers instead of slices to avoid creating
// aliasing &mut [T] / &[T] references when y and x derive from
// the same parent allocation (e.g. via columns_range_pair_mut).
// Raw pointer access does not perform Stacked Borrows retags,
// so it cannot invalidate sibling borrows.
unsafe {
let y = y.data.ptr_mut();
let x = x.data.ptr();

if !b.is_zero() {
array_axcpy(status, y, a, x, c, b, rstride1, rstride2, x.len());
array_axcpy(status, y, a, x, c, b, rstride1, rstride2, len);
} else {
array_axc(status, y, a, x, c, rstride1, rstride2, x.len());
array_axc(status, y, a, x, c, rstride1, rstride2, len);
}
}
}
Expand Down Expand Up @@ -160,6 +180,12 @@ pub unsafe fn gemv_uninit<Status, T, D1: Dim, R2: Dim, C2: Dim, D3: Dim, SA, SB,
}

// TODO: avoid bound checks.
// SAFETY (aliasing): `a` and `y` are received as `&Matrix` and `&mut Vector`
// respectively, so Rust's borrow rules guarantee they reference disjoint
// allocations. Column views into `a` (via `a.column(j)`) inherit provenance
// from `a` and cannot alias with `y`. After the raw-pointer rewrite
// (issue #1520), `axcpy_uninit` uses raw pointers internally, avoiding
// Stacked Borrows retag issues even if views were to share an allocation.
let col2 = a.column(0);
let val = x.vget_unchecked(0).clone();

Expand Down Expand Up @@ -268,6 +294,11 @@ pub unsafe fn gemm_uninit<
return;
}

// SAFETY (aliasing): The matrixmultiply path operates entirely
// on raw pointers via data.ptr()/data.ptr_mut(). No slice
// references are created, so no Stacked Borrows retags occur.
// The three matrices (a, b, y) are separate parameters with
// disjoint provenance guaranteed by Rust's borrow rules.
if TypeId::of::<T>() == TypeId::of::<f32>() {
let (rsa, csa) = a.strides();
let (rsb, csb) = b.strides();
Expand Down Expand Up @@ -319,7 +350,11 @@ pub unsafe fn gemm_uninit<

for j1 in 0..ncols1 {
// TODO: avoid bound checks.
// SAFETY: this is UB if Status = Uninit && beta != 0
// SAFETY (uninit): this is UB if Status = Uninit && beta != 0.
// SAFETY (aliasing): `y.column_mut(j1)` and `b.column(j1)` derive
// from separate parent matrices (`y` vs `b`), so they cannot alias.
// Each `column_mut` borrow is consumed by `gemv_uninit` before the
// next iteration, so successive mutable column views do not overlap.
gemv_uninit(
status,
&mut y.column_mut(j1),
Expand Down
36 changes: 28 additions & 8 deletions src/base/matrix_view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,6 @@ macro_rules! view_storage_impl (
#[deprecated = "Use ViewStorage(Mut) instead."]
pub type $legacy_name<'a, T, R, C, RStride, CStride> = $T<'a, T, R, C, RStride, CStride>;

unsafe impl<'a, T: Send, R: Dim, C: Dim, RStride: Dim, CStride: Dim> Send
for $T<'a, T, R, C, RStride, CStride>
{}

unsafe impl<'a, T: Sync, R: Dim, C: Dim, RStride: Dim, CStride: Dim> Sync
for $T<'a, T, R, C, RStride, CStride>
{}

impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> $T<'a, T, R, C, RStride, CStride> {
/// Create a new matrix view without bounds checking and from a raw pointer.
///
Expand Down Expand Up @@ -136,6 +128,20 @@ impl<T: Scalar, R: Dim, C: Dim, RStride: Dim, CStride: Dim> Copy
{
}

/// Safety: Equivalent to a shared reference to `T`. All `Dim` type arguments are `Send + Sync`. A
/// shared reference can be sent iff `T: Sync`.
unsafe impl<'a, T: Sync, R: Dim, C: Dim, RStride: Dim, CStride: Dim> Send
for ViewStorage<'a, T, R, C, RStride, CStride>
{
}

/// Safety: Equivalent to a shared reference to `T`. All `Dim` type arguments are `Send + Sync`. A
/// shared reference is `Sync` iff `T: Sync`.
unsafe impl<'a, T: Sync, R: Dim, C: Dim, RStride: Dim, CStride: Dim> Sync
for ViewStorage<'a, T, R, C, RStride, CStride>
{
}

impl<T: Scalar, R: Dim, C: Dim, RStride: Dim, CStride: Dim> Clone
for ViewStorage<'_, T, R, C, RStride, CStride>
{
Expand All @@ -145,6 +151,20 @@ impl<T: Scalar, R: Dim, C: Dim, RStride: Dim, CStride: Dim> Clone
}
}

/// Safety: Equivalent to a unique reference to `T`. All `Dim` type arguments are `Send + Sync`. A
/// unique reference is `Send` iff `T: Send`.
unsafe impl<'a, T: Send, R: Dim, C: Dim, RStride: Dim, CStride: Dim> Send
for ViewStorageMut<'a, T, R, C, RStride, CStride>
{
}

/// Safety: Equivalent to a unique reference to `T`. All `Dim` type arguments are `Send + Sync`. A
/// unique reference is `Sync` iff `T: Sync`.
unsafe impl<'a, T: Sync, R: Dim, C: Dim, RStride: Dim, CStride: Dim> Sync
for ViewStorageMut<'a, T, R, C, RStride, CStride>
{
}

impl<'a, T: Scalar, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
ViewStorageMut<'a, T, R, C, RStride, CStride>
where
Expand Down
41 changes: 40 additions & 1 deletion src/base/uninit.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::mem::MaybeUninit;
use std::ptr;

/// This trait is used to write code that may work on matrices that may or may not
/// be initialized.
Expand Down Expand Up @@ -26,8 +27,21 @@ pub unsafe trait InitStatus<T>: Copy {
/// Retrieve a mutable reference to the element, assuming that it is initialized.
///
/// # Safety
/// This is unsound if the referenced value isnt initialized.
/// This is unsound if the referenced value isn't initialized.
unsafe fn assume_init_mut(t: &mut Self::Value) -> &mut T;

/// Write a value through a raw pointer, initializing the element.
///
/// # Safety
/// `out` must be valid for writes and properly aligned.
unsafe fn init_ptr(out: *mut Self::Value, t: T);

/// Read the initialized value from a raw pointer.
///
/// # Safety
/// `ptr` must be valid for reads, properly aligned, and the pointed-to
/// value must be initialized.
unsafe fn assume_init_read(ptr: *const Self::Value) -> T;
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
Expand All @@ -54,6 +68,18 @@ unsafe impl<T> InitStatus<T> for Init {
unsafe fn assume_init_mut(t: &mut T) -> &mut T {
t
}

#[inline(always)]
unsafe fn init_ptr(out: *mut T, t: T) {
// SAFETY: Caller guarantees `out` is valid for writes and aligned.
unsafe { ptr::write(out, t) }
}

#[inline(always)]
unsafe fn assume_init_read(p: *const T) -> T {
// SAFETY: Caller guarantees `p` is valid for reads, aligned, and initialized.
unsafe { ptr::read(p) }
}
}

unsafe impl<T> InitStatus<T> for Uninit {
Expand All @@ -77,4 +103,17 @@ unsafe impl<T> InitStatus<T> for Uninit {
&mut *t.as_mut_ptr() // TODO: use t.assume_init_mut()
}
}

#[inline(always)]
unsafe fn init_ptr(out: *mut MaybeUninit<T>, t: T) {
// SAFETY: Caller guarantees `out` is valid for writes and aligned.
unsafe { ptr::write(out, MaybeUninit::new(t)) }
}

#[inline(always)]
unsafe fn assume_init_read(p: *const MaybeUninit<T>) -> T {
// SAFETY: Caller guarantees `p` is valid, aligned, and the value is initialized.
// MaybeUninit<T> has the same layout as T, so reading as T is sound.
unsafe { ptr::read(p.cast::<T>()) }
}
}
50 changes: 45 additions & 5 deletions src/linalg/symmetric_eigen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,21 @@ where
diag[start + 1].clone(),
);
let eigvals = m.eigenvalues().unwrap();
let basis = Vector2::new(
eigvals.x.clone() - diag[start + 1].clone(),
off_diag[start].clone(),
);

// Choose the basis least likely to experience cancellation
let basis = if (eigvals.x.clone() - diag[start + 1].clone()).abs()
> (eigvals.x.clone() - diag[start].clone()).abs()
{
Vector2::new(
eigvals.x.clone() - diag[start + 1].clone(),
off_diag[start].clone(),
)
} else {
Vector2::new(
off_diag[start].clone(),
eigvals.x.clone() - diag[start].clone(),
)
};

diag[start] = eigvals[0].clone();
diag[start + 1] = eigvals[1].clone();
Expand Down Expand Up @@ -348,7 +359,36 @@ where

#[cfg(test)]
mod test {
use crate::base::Matrix2;
use crate::base::{Matrix2, Matrix4};

/// Exercises bug reported in issue #1109 of nalgebra (https://github.com/dimforge/nalgebra/issues/1109)
#[test]
fn symmetric_eigen_regression_issue_1109() {
let m = Matrix4::new(
-19884.07f64,
-10.07188,
11.277279,
-188560.63,
-10.07188,
12.518197,
1.3770627,
-102.97504,
11.277279,
1.3770627,
14.587362,
113.26099,
-188560.63,
-102.97504,
113.26099,
-1788112.3,
);
let eig = m.symmetric_eigen();
assert!(relative_eq!(
m.lower_triangle(),
eig.recompose().lower_triangle(),
epsilon = 1.0e-5
));
}

fn expected_shift(m: Matrix2<f64>) -> f64 {
let vals = m.eigenvalues().unwrap();
Expand Down
Loading