Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ name = "vs_linalg"
harness = false
required-features = [ "bench" ]

[profile.release]
lto = "fat"
codegen-units = 1

[lints.rust]
unsafe_code = "forbid"
missing_docs = "warn"
Expand Down
10 changes: 10 additions & 0 deletions benches/vs_linalg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,16 @@ macro_rules! gen_vs_linalg_benches_for_dim {
});
});

// === Determinant via det_direct (closed-form, no LU) ===
[<group_d $d>].bench_function("la_stack_det_direct", |bencher| {
bencher.iter(|| {
let det = black_box(a)
.det(la_stack::DEFAULT_PIVOT_TOL)
.expect("matrix should be non-singular");
black_box(det);
});
});

// === LU factorization ===
[<group_d $d>].bench_function("la_stack_lu", |bencher| {
bencher.iter(|| {
Expand Down
274 changes: 272 additions & 2 deletions src/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,91 @@ impl<const D: usize> Matrix<D> {
Ldlt::factor(self, tol)
}

/// Determinant computed via LU decomposition.
/// Closed-form determinant for dimensions 1–4, bypassing LU factorization.
///
/// Returns `Some(det)` for `D` ∈ {1, 2, 3, 4}, `None` for larger matrices.
/// This is a `const fn` (Rust 1.94+) and uses fused multiply-add (`mul_add`)
/// for improved accuracy and performance.
///
/// For a determinant that works for any dimension (falling back to LU for D ≥ 5),
/// use [`det`](Self::det).
///
/// # Examples
/// ```
/// use la_stack::prelude::*;
///
/// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
/// assert!((m.det_direct().unwrap() - (-2.0)).abs() <= 1e-12);
///
/// // D ≥ 5 returns None.
/// assert!(Matrix::<5>::identity().det_direct().is_none());
/// ```
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
#[inline]
#[must_use]
pub const fn det_direct(&self) -> Option<f64> {
match D {
0 => Some(1.0),
1 => Some(self.rows[0][0]),
2 => {
// ad - bc
Some(self.rows[0][0].mul_add(self.rows[1][1], -(self.rows[0][1] * self.rows[1][0])))
}
3 => {
// Cofactor expansion on first row.
let m00 =
self.rows[1][1].mul_add(self.rows[2][2], -(self.rows[1][2] * self.rows[2][1]));
let m01 =
self.rows[1][0].mul_add(self.rows[2][2], -(self.rows[1][2] * self.rows[2][0]));
let m02 =
self.rows[1][0].mul_add(self.rows[2][1], -(self.rows[1][1] * self.rows[2][0]));
Some(
self.rows[0][0]
.mul_add(m00, (-self.rows[0][1]).mul_add(m01, self.rows[0][2] * m02)),
)
}
4 => {
// Cofactor expansion on first row → four 3×3 sub-determinants,
// each computed inline (closures are not const-compatible).
let r = &self.rows;

// Minor M00: rows 1-3, cols 1-3
let m00_0 = r[2][2].mul_add(r[3][3], -(r[2][3] * r[3][2]));
let m00_1 = r[2][1].mul_add(r[3][3], -(r[2][3] * r[3][1]));
let m00_2 = r[2][1].mul_add(r[3][2], -(r[2][2] * r[3][1]));
let c00 = r[1][1].mul_add(m00_0, (-r[1][2]).mul_add(m00_1, r[1][3] * m00_2));

// Minor M01: rows 1-3, cols 0,2,3
let m01_0 = r[2][2].mul_add(r[3][3], -(r[2][3] * r[3][2]));
let m01_1 = r[2][0].mul_add(r[3][3], -(r[2][3] * r[3][0]));
let m01_2 = r[2][0].mul_add(r[3][2], -(r[2][2] * r[3][0]));
let c01 = r[1][0].mul_add(m01_0, (-r[1][2]).mul_add(m01_1, r[1][3] * m01_2));

// Minor M02: rows 1-3, cols 0,1,3
let m02_0 = r[2][1].mul_add(r[3][3], -(r[2][3] * r[3][1]));
let m02_1 = r[2][0].mul_add(r[3][3], -(r[2][3] * r[3][0]));
let m02_2 = r[2][0].mul_add(r[3][1], -(r[2][1] * r[3][0]));
let c02 = r[1][0].mul_add(m02_0, (-r[1][1]).mul_add(m02_1, r[1][3] * m02_2));

// Minor M03: rows 1-3, cols 0,1,2
let m03_0 = r[2][1].mul_add(r[3][2], -(r[2][2] * r[3][1]));
let m03_1 = r[2][0].mul_add(r[3][2], -(r[2][2] * r[3][0]));
let m03_2 = r[2][0].mul_add(r[3][1], -(r[2][1] * r[3][0]));
let c03 = r[1][0].mul_add(m03_0, (-r[1][1]).mul_add(m03_1, r[1][2] * m03_2));

Some(r[0][0].mul_add(
c00,
(-r[0][1]).mul_add(c01, r[0][2].mul_add(c02, -(r[0][3] * c03))),
))
}
_ => None,
}
}

/// Determinant, using closed-form formulas for D ≤ 4 and LU decomposition for D ≥ 5.
///
/// For D ∈ {1, 2, 3, 4}, this bypasses LU factorization entirely for a significant
/// speedup (see [`det_direct`](Self::det_direct)). The `tol` parameter is only used
/// by the LU fallback path for D ≥ 5.
///
/// # Examples
/// ```
Expand All @@ -210,9 +294,17 @@ impl<const D: usize> Matrix<D> {
/// ```
///
/// # Errors
/// Propagates LU factorization errors (e.g. singular matrices).
/// Returns [`LaError::NonFinite`] if the result contains NaN or infinity.
/// For D ≥ 5, propagates LU factorization errors (e.g. [`LaError::Singular`]).
#[inline]
pub fn det(self, tol: f64) -> Result<f64, LaError> {
if let Some(d) = self.det_direct() {
return if d.is_finite() {
Ok(d)
} else {
Err(LaError::NonFinite { pivot_col: 0 })
};
}
self.lu(tol).map(|lu| lu.det())
Comment on lines 294 to 302
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Keep det() behavior consistent for singular matrices.

This now makes Matrix::det() dimension-dependent: Matrix::<2>::zero().det(..) returns Ok(0.0) from the direct path, while Matrix::<5>::zero().det(..) still returns Err(LaError::Singular) via LU. That is a breaking behavioral change for callers that use the error to detect degeneracy.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/matrix.rs` around lines 300 - 308, det() currently returns Ok(0.0) when
det_direct() yields a finite 0 for small matrices but defers to lu(tol) (which
returns Err(LaError::Singular)) for larger ones; make behavior consistent by
treating a finite zero determinant from det_direct() as singular instead of Ok,
i.e. in Matrix::det() after getting d from det_direct() return
Err(LaError::Singular) (or the same variant/structure used by lu/tol path) when
d == 0.0 (and preserve the existing NonFinite handling for non-finite d),
leaving the lu(tol).map(|lu| lu.det()) branch unchanged.

}
}
Expand Down Expand Up @@ -328,4 +420,182 @@ mod tests {
gen_public_api_matrix_tests!(3);
gen_public_api_matrix_tests!(4);
gen_public_api_matrix_tests!(5);

// === det_direct tests ===

#[test]
fn det_direct_d0_is_one() {
assert_eq!(Matrix::<0>::zero().det_direct(), Some(1.0));
}

#[test]
fn det_direct_d1_returns_element() {
let m = Matrix::<1>::from_rows([[42.0]]);
assert_eq!(m.det_direct(), Some(42.0));
}

#[test]
fn det_direct_d2_known_value() {
// [[1,2],[3,4]] → det = 1*4 - 2*3 = -2
let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
assert_abs_diff_eq!(m.det_direct().unwrap(), -2.0, epsilon = 1e-15);
}

#[test]
fn det_direct_d3_known_value() {
// Classic 3×3: det = 0
let m = Matrix::<3>::from_rows([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
assert_abs_diff_eq!(m.det_direct().unwrap(), 0.0, epsilon = 1e-12);
}

#[test]
fn det_direct_d3_nonsingular() {
// [[2,1,0],[0,3,1],[1,0,2]] → det = 2*(6-0) - 1*(0-1) + 0 = 13
let m = Matrix::<3>::from_rows([[2.0, 1.0, 0.0], [0.0, 3.0, 1.0], [1.0, 0.0, 2.0]]);
assert_abs_diff_eq!(m.det_direct().unwrap(), 13.0, epsilon = 1e-12);
}

#[test]
fn det_direct_d4_identity() {
assert_abs_diff_eq!(
Matrix::<4>::identity().det_direct().unwrap(),
1.0,
epsilon = 1e-15
);
}

#[test]
fn det_direct_d4_known_value() {
// Diagonal matrix: det = product of diagonal entries.
let mut rows = [[0.0f64; 4]; 4];
rows[0][0] = 2.0;
rows[1][1] = 3.0;
rows[2][2] = 5.0;
rows[3][3] = 7.0;
let m = Matrix::<4>::from_rows(rows);
assert_abs_diff_eq!(m.det_direct().unwrap(), 210.0, epsilon = 1e-12);
}

#[test]
fn det_direct_d5_returns_none() {
assert_eq!(Matrix::<5>::identity().det_direct(), None);
}

#[test]
fn det_direct_d8_returns_none() {
assert_eq!(Matrix::<8>::zero().det_direct(), None);
}

macro_rules! gen_det_direct_agrees_with_lu {
($d:literal) => {
paste! {
#[test]
#[allow(clippy::cast_precision_loss)] // r, c, D are tiny integers
fn [<det_direct_agrees_with_lu_ $d d>]() {
// Well-conditioned matrix: diagonally dominant.
let mut rows = [[0.0f64; $d]; $d];
for r in 0..$d {
for c in 0..$d {
rows[r][c] = if r == c {
(r as f64) + f64::from($d) + 1.0
} else {
0.1 / ((r + c + 1) as f64)
};
}
}
let m = Matrix::<$d>::from_rows(rows);
let direct = m.det_direct().unwrap();
let lu_det = m.lu(DEFAULT_PIVOT_TOL).unwrap().det();
let eps = lu_det.abs().mul_add(1e-12, 1e-12);
assert_abs_diff_eq!(direct, lu_det, epsilon = eps);
}
}
};
}

gen_det_direct_agrees_with_lu!(1);
gen_det_direct_agrees_with_lu!(2);
gen_det_direct_agrees_with_lu!(3);
gen_det_direct_agrees_with_lu!(4);

#[test]
fn det_direct_identity_all_dims() {
assert_abs_diff_eq!(
Matrix::<1>::identity().det_direct().unwrap(),
1.0,
epsilon = 0.0
);
assert_abs_diff_eq!(
Matrix::<2>::identity().det_direct().unwrap(),
1.0,
epsilon = 0.0
);
assert_abs_diff_eq!(
Matrix::<3>::identity().det_direct().unwrap(),
1.0,
epsilon = 0.0
);
assert_abs_diff_eq!(
Matrix::<4>::identity().det_direct().unwrap(),
1.0,
epsilon = 0.0
);
}

#[test]
fn det_direct_zero_matrix() {
assert_abs_diff_eq!(
Matrix::<2>::zero().det_direct().unwrap(),
0.0,
epsilon = 0.0
);
assert_abs_diff_eq!(
Matrix::<3>::zero().det_direct().unwrap(),
0.0,
epsilon = 0.0
);
assert_abs_diff_eq!(
Matrix::<4>::zero().det_direct().unwrap(),
0.0,
epsilon = 0.0
);
}

#[test]
fn det_returns_nonfinite_error_for_nan_d2() {
let m = Matrix::<2>::from_rows([[f64::NAN, 1.0], [1.0, 1.0]]);
assert_eq!(
m.det(DEFAULT_PIVOT_TOL),
Err(LaError::NonFinite { pivot_col: 0 })
);
}

#[test]
fn det_returns_nonfinite_error_for_inf_d3() {
let m =
Matrix::<3>::from_rows([[f64::INFINITY, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);
assert_eq!(
m.det(DEFAULT_PIVOT_TOL),
Err(LaError::NonFinite { pivot_col: 0 })
);
}

#[test]
fn det_direct_is_const_evaluable_d2() {
// Const evaluation proves the function is truly const fn.
const DET: Option<f64> = {
let m = Matrix::<2>::from_rows([[1.0, 0.0], [0.0, 1.0]]);
m.det_direct()
};
assert_eq!(DET, Some(1.0));
}

#[test]
fn det_direct_is_const_evaluable_d3() {
const DET: Option<f64> = {
let m = Matrix::<3>::from_rows([[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 5.0]]);
m.det_direct()
};
assert_eq!(DET, Some(30.0));
}
}
8 changes: 8 additions & 0 deletions tests/proptest_matrix.proptest-regressions
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Seeds for failure cases proptest has generated in the past. It is
# automatically read and these particular cases re-run before any
# novel cases are generated.
#
# It is recommended to check this file in to source control so that
# everyone who runs the test benefits from these saved cases.
cc a2989aa781cea4ab806e5c6242f6493e5aa1af226b439a7f6eb8c4257cd48dde # shrinks to diag = [-48.0, -29.1, -36.1], b_arr = [0.0, 0.0, 0.0]
cc d5cecfb996d10952fd516ed975202ebdf27ccd95649c66e671c348d7d0e4de92 # shrinks to diag = [-10.5, -52.4, 6.7, -15.4], b_arr = [0.0, 0.0, 0.0, 0.0]
6 changes: 5 additions & 1 deletion tests/proptest_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ macro_rules! gen_public_api_matrix_proptests {
}
acc
};
assert_abs_diff_eq!(det, expected_det, epsilon = 1e-12);
// The closed-form and LU paths evaluate the diagonal
// product in different orders, so we allow a few ULPs of
// relative error (floor 1e-12 for near-zero determinants).
let eps = expected_det.abs().mul_add(1e-12, 1e-12);
assert_abs_diff_eq!(det, expected_det, epsilon = eps);

let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
let b = Vector::<$d>::new(b_arr);
Expand Down
Loading