-
Notifications
You must be signed in to change notification settings - Fork 0
perf: closed-form determinant for D=1–4 (#27) #34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
a8344e5
40cc2b7
159d04e
1308612
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -196,7 +196,91 @@ impl<const D: usize> Matrix<D> { | |
| Ldlt::factor(self, tol) | ||
| } | ||
|
|
||
| /// Determinant computed via LU decomposition. | ||
| /// Closed-form determinant for dimensions 1–4, bypassing LU factorization. | ||
| /// | ||
| /// Returns `Some(det)` for `D` ∈ {1, 2, 3, 4}, `None` for larger matrices. | ||
| /// This is a `const fn` (Rust 1.94+) and uses fused multiply-add (`mul_add`) | ||
| /// for improved accuracy and performance. | ||
| /// | ||
| /// For a determinant that works for any dimension (falling back to LU for D ≥ 5), | ||
| /// use [`det`](Self::det). | ||
| /// | ||
| /// # Examples | ||
| /// ``` | ||
| /// use la_stack::prelude::*; | ||
| /// | ||
| /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); | ||
| /// assert!((m.det_direct().unwrap() - (-2.0)).abs() <= 1e-12); | ||
| /// | ||
| /// // D ≥ 5 returns None. | ||
| /// assert!(Matrix::<5>::identity().det_direct().is_none()); | ||
| /// ``` | ||
| #[inline] | ||
| #[must_use] | ||
| pub const fn det_direct(&self) -> Option<f64> { | ||
| match D { | ||
| 0 => Some(1.0), | ||
| 1 => Some(self.rows[0][0]), | ||
| 2 => { | ||
| // ad - bc | ||
| Some(self.rows[0][0].mul_add(self.rows[1][1], -(self.rows[0][1] * self.rows[1][0]))) | ||
| } | ||
| 3 => { | ||
| // Cofactor expansion on first row. | ||
| let m00 = | ||
| self.rows[1][1].mul_add(self.rows[2][2], -(self.rows[1][2] * self.rows[2][1])); | ||
| let m01 = | ||
| self.rows[1][0].mul_add(self.rows[2][2], -(self.rows[1][2] * self.rows[2][0])); | ||
| let m02 = | ||
| self.rows[1][0].mul_add(self.rows[2][1], -(self.rows[1][1] * self.rows[2][0])); | ||
| Some( | ||
| self.rows[0][0] | ||
| .mul_add(m00, (-self.rows[0][1]).mul_add(m01, self.rows[0][2] * m02)), | ||
| ) | ||
| } | ||
| 4 => { | ||
| // Cofactor expansion on first row → four 3×3 sub-determinants, | ||
| // each computed inline (closures are not const-compatible). | ||
| let r = &self.rows; | ||
|
|
||
| // Minor M00: rows 1-3, cols 1-3 | ||
| let m00_0 = r[2][2].mul_add(r[3][3], -(r[2][3] * r[3][2])); | ||
| let m00_1 = r[2][1].mul_add(r[3][3], -(r[2][3] * r[3][1])); | ||
| let m00_2 = r[2][1].mul_add(r[3][2], -(r[2][2] * r[3][1])); | ||
| let c00 = r[1][1].mul_add(m00_0, (-r[1][2]).mul_add(m00_1, r[1][3] * m00_2)); | ||
|
|
||
| // Minor M01: rows 1-3, cols 0,2,3 | ||
| let m01_0 = r[2][2].mul_add(r[3][3], -(r[2][3] * r[3][2])); | ||
| let m01_1 = r[2][0].mul_add(r[3][3], -(r[2][3] * r[3][0])); | ||
| let m01_2 = r[2][0].mul_add(r[3][2], -(r[2][2] * r[3][0])); | ||
| let c01 = r[1][0].mul_add(m01_0, (-r[1][2]).mul_add(m01_1, r[1][3] * m01_2)); | ||
|
|
||
| // Minor M02: rows 1-3, cols 0,1,3 | ||
| let m02_0 = r[2][1].mul_add(r[3][3], -(r[2][3] * r[3][1])); | ||
| let m02_1 = r[2][0].mul_add(r[3][3], -(r[2][3] * r[3][0])); | ||
| let m02_2 = r[2][0].mul_add(r[3][1], -(r[2][1] * r[3][0])); | ||
| let c02 = r[1][0].mul_add(m02_0, (-r[1][1]).mul_add(m02_1, r[1][3] * m02_2)); | ||
|
|
||
| // Minor M03: rows 1-3, cols 0,1,2 | ||
| let m03_0 = r[2][1].mul_add(r[3][2], -(r[2][2] * r[3][1])); | ||
| let m03_1 = r[2][0].mul_add(r[3][2], -(r[2][2] * r[3][0])); | ||
| let m03_2 = r[2][0].mul_add(r[3][1], -(r[2][1] * r[3][0])); | ||
| let c03 = r[1][0].mul_add(m03_0, (-r[1][1]).mul_add(m03_1, r[1][2] * m03_2)); | ||
|
|
||
| Some(r[0][0].mul_add( | ||
| c00, | ||
| (-r[0][1]).mul_add(c01, r[0][2].mul_add(c02, -(r[0][3] * c03))), | ||
| )) | ||
| } | ||
| _ => None, | ||
| } | ||
| } | ||
|
|
||
| /// Determinant, using closed-form formulas for D ≤ 4 and LU decomposition for D ≥ 5. | ||
| /// | ||
| /// For D ∈ {1, 2, 3, 4}, this bypasses LU factorization entirely for a significant | ||
| /// speedup (see [`det_direct`](Self::det_direct)). The `tol` parameter is only used | ||
| /// by the LU fallback path for D ≥ 5. | ||
| /// | ||
| /// # Examples | ||
| /// ``` | ||
|
|
@@ -210,9 +294,17 @@ impl<const D: usize> Matrix<D> { | |
| /// ``` | ||
| /// | ||
| /// # Errors | ||
| /// Propagates LU factorization errors (e.g. singular matrices). | ||
| /// Returns [`LaError::NonFinite`] if the result contains NaN or infinity. | ||
| /// For D ≥ 5, propagates LU factorization errors (e.g. [`LaError::Singular`]). | ||
| #[inline] | ||
| pub fn det(self, tol: f64) -> Result<f64, LaError> { | ||
| if let Some(d) = self.det_direct() { | ||
| return if d.is_finite() { | ||
| Ok(d) | ||
| } else { | ||
| Err(LaError::NonFinite { pivot_col: 0 }) | ||
| }; | ||
| } | ||
| self.lu(tol).map(|lu| lu.det()) | ||
|
Comment on lines
294
to
302
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Keep This now makes 🤖 Prompt for AI Agents |
||
| } | ||
| } | ||
|
|
@@ -328,4 +420,182 @@ mod tests { | |
| gen_public_api_matrix_tests!(3); | ||
| gen_public_api_matrix_tests!(4); | ||
| gen_public_api_matrix_tests!(5); | ||
|
|
||
| // === det_direct tests === | ||
|
|
||
| #[test] | ||
| fn det_direct_d0_is_one() { | ||
| assert_eq!(Matrix::<0>::zero().det_direct(), Some(1.0)); | ||
| } | ||
|
|
||
| #[test] | ||
| fn det_direct_d1_returns_element() { | ||
| let m = Matrix::<1>::from_rows([[42.0]]); | ||
| assert_eq!(m.det_direct(), Some(42.0)); | ||
| } | ||
|
|
||
| #[test] | ||
| fn det_direct_d2_known_value() { | ||
| // [[1,2],[3,4]] → det = 1*4 - 2*3 = -2 | ||
| let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); | ||
| assert_abs_diff_eq!(m.det_direct().unwrap(), -2.0, epsilon = 1e-15); | ||
| } | ||
|
|
||
| #[test] | ||
| fn det_direct_d3_known_value() { | ||
| // Classic 3×3: det = 0 | ||
| let m = Matrix::<3>::from_rows([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]); | ||
| assert_abs_diff_eq!(m.det_direct().unwrap(), 0.0, epsilon = 1e-12); | ||
| } | ||
|
|
||
| #[test] | ||
| fn det_direct_d3_nonsingular() { | ||
| // [[2,1,0],[0,3,1],[1,0,2]] → det = 2*(6-0) - 1*(0-1) + 0 = 13 | ||
| let m = Matrix::<3>::from_rows([[2.0, 1.0, 0.0], [0.0, 3.0, 1.0], [1.0, 0.0, 2.0]]); | ||
| assert_abs_diff_eq!(m.det_direct().unwrap(), 13.0, epsilon = 1e-12); | ||
| } | ||
|
|
||
| #[test] | ||
| fn det_direct_d4_identity() { | ||
| assert_abs_diff_eq!( | ||
| Matrix::<4>::identity().det_direct().unwrap(), | ||
| 1.0, | ||
| epsilon = 1e-15 | ||
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn det_direct_d4_known_value() { | ||
| // Diagonal matrix: det = product of diagonal entries. | ||
| let mut rows = [[0.0f64; 4]; 4]; | ||
| rows[0][0] = 2.0; | ||
| rows[1][1] = 3.0; | ||
| rows[2][2] = 5.0; | ||
| rows[3][3] = 7.0; | ||
| let m = Matrix::<4>::from_rows(rows); | ||
| assert_abs_diff_eq!(m.det_direct().unwrap(), 210.0, epsilon = 1e-12); | ||
| } | ||
|
|
||
| #[test] | ||
| fn det_direct_d5_returns_none() { | ||
| assert_eq!(Matrix::<5>::identity().det_direct(), None); | ||
| } | ||
|
|
||
| #[test] | ||
| fn det_direct_d8_returns_none() { | ||
| assert_eq!(Matrix::<8>::zero().det_direct(), None); | ||
| } | ||
|
|
||
| macro_rules! gen_det_direct_agrees_with_lu { | ||
| ($d:literal) => { | ||
| paste! { | ||
| #[test] | ||
| #[allow(clippy::cast_precision_loss)] // r, c, D are tiny integers | ||
| fn [<det_direct_agrees_with_lu_ $d d>]() { | ||
| // Well-conditioned matrix: diagonally dominant. | ||
| let mut rows = [[0.0f64; $d]; $d]; | ||
| for r in 0..$d { | ||
| for c in 0..$d { | ||
| rows[r][c] = if r == c { | ||
| (r as f64) + f64::from($d) + 1.0 | ||
| } else { | ||
| 0.1 / ((r + c + 1) as f64) | ||
| }; | ||
| } | ||
| } | ||
| let m = Matrix::<$d>::from_rows(rows); | ||
| let direct = m.det_direct().unwrap(); | ||
| let lu_det = m.lu(DEFAULT_PIVOT_TOL).unwrap().det(); | ||
| let eps = lu_det.abs().mul_add(1e-12, 1e-12); | ||
| assert_abs_diff_eq!(direct, lu_det, epsilon = eps); | ||
| } | ||
| } | ||
| }; | ||
| } | ||
|
|
||
| gen_det_direct_agrees_with_lu!(1); | ||
| gen_det_direct_agrees_with_lu!(2); | ||
| gen_det_direct_agrees_with_lu!(3); | ||
| gen_det_direct_agrees_with_lu!(4); | ||
|
|
||
| #[test] | ||
| fn det_direct_identity_all_dims() { | ||
| assert_abs_diff_eq!( | ||
| Matrix::<1>::identity().det_direct().unwrap(), | ||
| 1.0, | ||
| epsilon = 0.0 | ||
| ); | ||
| assert_abs_diff_eq!( | ||
| Matrix::<2>::identity().det_direct().unwrap(), | ||
| 1.0, | ||
| epsilon = 0.0 | ||
| ); | ||
| assert_abs_diff_eq!( | ||
| Matrix::<3>::identity().det_direct().unwrap(), | ||
| 1.0, | ||
| epsilon = 0.0 | ||
| ); | ||
| assert_abs_diff_eq!( | ||
| Matrix::<4>::identity().det_direct().unwrap(), | ||
| 1.0, | ||
| epsilon = 0.0 | ||
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn det_direct_zero_matrix() { | ||
| assert_abs_diff_eq!( | ||
| Matrix::<2>::zero().det_direct().unwrap(), | ||
| 0.0, | ||
| epsilon = 0.0 | ||
| ); | ||
| assert_abs_diff_eq!( | ||
| Matrix::<3>::zero().det_direct().unwrap(), | ||
| 0.0, | ||
| epsilon = 0.0 | ||
| ); | ||
| assert_abs_diff_eq!( | ||
| Matrix::<4>::zero().det_direct().unwrap(), | ||
| 0.0, | ||
| epsilon = 0.0 | ||
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn det_returns_nonfinite_error_for_nan_d2() { | ||
| let m = Matrix::<2>::from_rows([[f64::NAN, 1.0], [1.0, 1.0]]); | ||
| assert_eq!( | ||
| m.det(DEFAULT_PIVOT_TOL), | ||
| Err(LaError::NonFinite { pivot_col: 0 }) | ||
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn det_returns_nonfinite_error_for_inf_d3() { | ||
| let m = | ||
| Matrix::<3>::from_rows([[f64::INFINITY, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]); | ||
| assert_eq!( | ||
| m.det(DEFAULT_PIVOT_TOL), | ||
| Err(LaError::NonFinite { pivot_col: 0 }) | ||
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn det_direct_is_const_evaluable_d2() { | ||
| // Const evaluation proves the function is truly const fn. | ||
| const DET: Option<f64> = { | ||
| let m = Matrix::<2>::from_rows([[1.0, 0.0], [0.0, 1.0]]); | ||
| m.det_direct() | ||
| }; | ||
| assert_eq!(DET, Some(1.0)); | ||
| } | ||
|
|
||
| #[test] | ||
| fn det_direct_is_const_evaluable_d3() { | ||
| const DET: Option<f64> = { | ||
| let m = Matrix::<3>::from_rows([[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 5.0]]); | ||
| m.det_direct() | ||
| }; | ||
| assert_eq!(DET, Some(30.0)); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| # Seeds for failure cases proptest has generated in the past. It is | ||
| # automatically read and these particular cases re-run before any | ||
| # novel cases are generated. | ||
| # | ||
| # It is recommended to check this file in to source control so that | ||
| # everyone who runs the test benefits from these saved cases. | ||
| cc a2989aa781cea4ab806e5c6242f6493e5aa1af226b439a7f6eb8c4257cd48dde # shrinks to diag = [-48.0, -29.1, -36.1], b_arr = [0.0, 0.0, 0.0] | ||
| cc d5cecfb996d10952fd516ed975202ebdf27ccd95649c66e671c348d7d0e4de92 # shrinks to diag = [-10.5, -52.4, 6.7, -15.4], b_arr = [0.0, 0.0, 0.0, 0.0] |
Uh oh!
There was an error while loading. Please reload this page.