diff --git a/AGENTS.md b/AGENTS.md index cf5e0f2..8f77f26 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -28,7 +28,7 @@ When making changes in this repo, prioritize (in order): - Pre-commit validation: `just ci` - Python tests: `just test-python` - Run a single test (by name filter): `cargo test solve_2x2_basic` (or the full path: `cargo test lu::tests::solve_2x2_basic`) -- Run examples: `just examples` (or `cargo run --example det_5x5` / `cargo run --example solve_5x5`) +- Run examples: `just examples` (or `cargo run --example det_5x5` / `cargo run --example solve_5x5` / `cargo run --example const_det_4x4`) - Spell check: `just spell-check` (uses `typos.toml` at repo root; add false positives to `[default.extend-words]`) ## Code structure (big picture) @@ -37,7 +37,7 @@ When making changes in this repo, prioritize (in order): - The linear algebra implementation is split across: - `src/lib.rs`: crate root + shared items (`LaError`, `DEFAULT_SINGULAR_TOL`, `DEFAULT_PIVOT_TOL`) + re-exports - `src/vector.rs`: `Vector` (`[f64; D]`) - - `src/matrix.rs`: `Matrix` (`[[f64; D]; D]`) + helpers (`get`, `set`, `inf_norm`, `det`) + - `src/matrix.rs`: `Matrix` (`[[f64; D]; D]`) + helpers (`get`, `set`, `inf_norm`, `det`, `det_direct`) - `src/lu.rs`: `Lu` factorization with partial pivoting (`solve_vec`, `det`) - `src/ldlt.rs`: `Ldlt` factorization without pivoting for symmetric SPD/PSD matrices (`solve_vec`, `det`) - A minimal `justfile` exists for common workflows (see `just --list`). diff --git a/Cargo.toml b/Cargo.toml index 1f89fde..4c2af4a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,10 @@ name = "vs_linalg" harness = false required-features = [ "bench" ] +[profile.release] +lto = "fat" +codegen-units = 1 + [lints.rust] unsafe_code = "forbid" missing_docs = "warn" diff --git a/README.md b/README.md index 0043247..da85d72 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ while keeping the API intentionally small and explicit. 
- ✅ `Copy` types where possible - ✅ Const-generic dimensions (no dynamic sizes) +- ✅ `const fn` where possible (compile-time evaluation of determinants, dot products, etc.) - ✅ Explicit algorithms (LU, solve, determinant) - ✅ No runtime dependencies (dev-dependencies are for contributors only) - ✅ Stack storage only (no heap allocation in core types) @@ -103,12 +104,38 @@ let det = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap().det(); assert!((det - 1.0).abs() <= 1e-12); ``` +## ⚡ Compile-time determinants (D ≤ 4) + +`det_direct()` is a `const fn` providing closed-form determinants for D=0–4, +using fused multiply-add where applicable. `Matrix::<0>::zero().det_direct()` +returns `Some(1.0)` (the empty-product convention). For D=1–4, cofactor +expansion bypasses LU factorization entirely. This enables compile-time +evaluation when inputs are known at compile time: + +```rust +use la_stack::prelude::*; + +// Evaluated entirely at compile time — no runtime cost. +const DET: Option<f64> = { + let m = Matrix::<3>::from_rows([ + [2.0, 0.0, 0.0], + [0.0, 3.0, 0.0], + [0.0, 0.0, 5.0], + ]); + m.det_direct() +}; +assert_eq!(DET, Some(30.0)); +``` + +The public `det()` method automatically dispatches through the closed-form path +for D ≤ 4 and falls back to LU for D ≥ 5 — no API change needed.
+ ## 🧩 API at a glance | Type | Storage | Purpose | Key methods | |---|---|---|---| | `Vector` | `[f64; D]` | Fixed-length vector | `new`, `zero`, `dot`, `norm2_sq` | -| `Matrix` | `[[f64; D]; D]` | Fixed-size square matrix | `from_rows`, `zero`, `identity`, `lu`, `ldlt`, `det` | +| `Matrix` | `[[f64; D]; D]` | Fixed-size square matrix | `from_rows`, `zero`, `identity`, `lu`, `ldlt`, `det`, `det_direct` | | `Lu` | `Matrix` + pivot array | Factorization for solves/det | `solve_vec`, `det` | | `Ldlt` | `Matrix` | Factorization for symmetric SPD/PSD solves/det | `solve_vec`, `det` | @@ -123,6 +150,7 @@ just examples # or: cargo run --example solve_5x5 cargo run --example det_5x5 +cargo run --example const_det_4x4 ``` ## 🤝 Contributing diff --git a/benches/vs_linalg.rs b/benches/vs_linalg.rs index 4493c91..66505a8 100644 --- a/benches/vs_linalg.rs +++ b/benches/vs_linalg.rs @@ -179,6 +179,16 @@ macro_rules! gen_vs_linalg_benches_for_dim { }); }); + // === Determinant via det() (closed-form for D≤4, LU for D≥5) === + [].bench_function("la_stack_det", |bencher| { + bencher.iter(|| { + let det = black_box(a) + .det(la_stack::DEFAULT_PIVOT_TOL) + .expect("matrix should be non-singular"); + black_box(det); + }); + }); + // === LU factorization === [].bench_function("la_stack_lu", |bencher| { bencher.iter(|| { diff --git a/examples/const_det_4x4.rs b/examples/const_det_4x4.rs new file mode 100644 index 0000000..4b3bb75 --- /dev/null +++ b/examples/const_det_4x4.rs @@ -0,0 +1,36 @@ +//! Compile-time 4×4 determinant via `det_direct()`. +//! +//! Because `det_direct` is a `const fn` (Rust 1.94+), the determinant is +//! evaluated entirely at compile time — zero runtime cost. + +use la_stack::prelude::*; + +/// An example 4×4 matrix with small integer entries. +const MAT: Matrix<4> = Matrix::<4>::from_rows([ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [2.0, 6.0, 1.0, 5.0], + [3.0, 8.0, 2.0, 9.0], +]); + +/// Determinant computed at compile time. 
+const DET: f64 = match MAT.det_direct() { + Some(d) => d, + None => panic!("det_direct only supports D <= 4"), +}; + +fn main() { + println!("4×4 matrix:"); + for r in 0..4 { + print!(" ["); + for c in 0..4 { + if c > 0 { + print!(", "); + } + print!("{:5.1}", MAT.get(r, c).unwrap()); + } + println!("]"); + } + println!(); + println!("det (computed at compile time) = {DET}"); +} diff --git a/justfile b/justfile index be9a895..b238f7c 100644 --- a/justfile +++ b/justfile @@ -187,6 +187,7 @@ doc-check: examples: cargo run --quiet --example det_5x5 cargo run --quiet --example solve_5x5 + cargo run --quiet --example const_det_4x4 # Fix (mutating): apply formatters/auto-fixes fix: toml-fmt fmt python-fix shell-fmt markdown-fix yaml-fix diff --git a/src/matrix.rs b/src/matrix.rs index 06b57af..fb0f591 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -196,7 +196,85 @@ impl Matrix { Ldlt::factor(self, tol) } - /// Determinant computed via LU decomposition. + /// Closed-form determinant for dimensions 0–4, bypassing LU factorization. + /// + /// Returns `Some(det)` for `D` ∈ {0, 1, 2, 3, 4}, `None` for D ≥ 5. + /// `D = 0` returns `Some(1.0)` (empty product). + /// This is a `const fn` (Rust 1.94+) and uses fused multiply-add (`mul_add`) + /// for improved accuracy and performance. + /// + /// For a determinant that works for any dimension (falling back to LU for D ≥ 5), + /// use [`det`](Self::det). + /// + /// # Examples + /// ``` + /// use la_stack::prelude::*; + /// + /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); + /// assert!((m.det_direct().unwrap() - (-2.0)).abs() <= 1e-12); + /// + /// // D = 0 is the empty product. + /// assert_eq!(Matrix::<0>::zero().det_direct(), Some(1.0)); + /// + /// // D ≥ 5 returns None. 
+ /// assert!(Matrix::<5>::identity().det_direct().is_none()); + /// ``` + #[inline] + #[must_use] + pub const fn det_direct(&self) -> Option<f64> { + match D { + 0 => Some(1.0), + 1 => Some(self.rows[0][0]), + 2 => { + // ad - bc + Some(self.rows[0][0].mul_add(self.rows[1][1], -(self.rows[0][1] * self.rows[1][0]))) + } + 3 => { + // Cofactor expansion on first row. + let m00 = + self.rows[1][1].mul_add(self.rows[2][2], -(self.rows[1][2] * self.rows[2][1])); + let m01 = + self.rows[1][0].mul_add(self.rows[2][2], -(self.rows[1][2] * self.rows[2][0])); + let m02 = + self.rows[1][0].mul_add(self.rows[2][1], -(self.rows[1][1] * self.rows[2][0])); + Some( + self.rows[0][0] + .mul_add(m00, (-self.rows[0][1]).mul_add(m01, self.rows[0][2] * m02)), + ) + } + 4 => { + // Cofactor expansion on first row → four 3×3 sub-determinants. + // Hoist the 6 unique 2×2 minors from rows 2–3 (each used twice). + let r = &self.rows; + + // 2×2 minors: s_ij = r[2][i]*r[3][j] - r[2][j]*r[3][i] + let s23 = r[2][2].mul_add(r[3][3], -(r[2][3] * r[3][2])); // cols 2,3 + let s13 = r[2][1].mul_add(r[3][3], -(r[2][3] * r[3][1])); // cols 1,3 + let s12 = r[2][1].mul_add(r[3][2], -(r[2][2] * r[3][1])); // cols 1,2 + let s03 = r[2][0].mul_add(r[3][3], -(r[2][3] * r[3][0])); // cols 0,3 + let s02 = r[2][0].mul_add(r[3][2], -(r[2][2] * r[3][0])); // cols 0,2 + let s01 = r[2][0].mul_add(r[3][1], -(r[2][1] * r[3][0])); // cols 0,1 + + // 3×3 cofactors via row 1 expansion using hoisted minors. + let c00 = r[1][1].mul_add(s23, (-r[1][2]).mul_add(s13, r[1][3] * s12)); + let c01 = r[1][0].mul_add(s23, (-r[1][2]).mul_add(s03, r[1][3] * s02)); + let c02 = r[1][0].mul_add(s13, (-r[1][1]).mul_add(s03, r[1][3] * s01)); + let c03 = r[1][0].mul_add(s12, (-r[1][1]).mul_add(s02, r[1][2] * s01)); + + Some(r[0][0].mul_add( + c00, + (-r[0][1]).mul_add(c01, r[0][2].mul_add(c02, -(r[0][3] * c03))), + )) + } + _ => None, + } + } + + /// Determinant, using closed-form formulas for D ≤ 4 and LU decomposition for D ≥ 5.
+ /// + /// For D ∈ {1, 2, 3, 4}, this bypasses LU factorization entirely for a significant + /// speedup (see [`det_direct`](Self::det_direct)). The `tol` parameter is only used + /// by the LU fallback path for D ≥ 5. /// /// # Examples /// ``` @@ -210,9 +288,17 @@ impl<const D: usize> Matrix<D> { /// ``` /// /// # Errors - /// Propagates LU factorization errors (e.g. singular matrices). + /// Returns [`LaError::NonFinite`] if the result contains NaN or infinity. + /// For D ≥ 5, propagates LU factorization errors (e.g. [`LaError::Singular`]). #[inline] pub fn det(self, tol: f64) -> Result<f64, LaError> { + if let Some(d) = self.det_direct() { + return if d.is_finite() { + Ok(d) + } else { + Err(LaError::NonFinite { pivot_col: 0 }) + }; + } self.lu(tol).map(|lu| lu.det()) } } @@ -231,6 +317,7 @@ mod tests { use approx::assert_abs_diff_eq; use pastey::paste; + use std::hint::black_box; macro_rules! gen_public_api_matrix_tests { ($d:literal) => { @@ -328,4 +415,188 @@ mod tests { gen_public_api_matrix_tests!(3); gen_public_api_matrix_tests!(4); gen_public_api_matrix_tests!(5); + + // === det_direct tests === + + #[test] + fn det_direct_d0_is_one() { + assert_eq!(Matrix::<0>::zero().det_direct(), Some(1.0)); + } + + #[test] + fn det_direct_d1_returns_element() { + let m = Matrix::<1>::from_rows([[42.0]]); + assert_eq!(m.det_direct(), Some(42.0)); + } + + #[test] + fn det_direct_d2_known_value() { + // [[1,2],[3,4]] → det = 1*4 - 2*3 = -2 + // black_box prevents compile-time constant folding of the const fn.
+ let m = black_box(Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]])); + assert_abs_diff_eq!(m.det_direct().unwrap(), -2.0, epsilon = 1e-15); + } + + #[test] + fn det_direct_d3_known_value() { + // Classic 3×3: det = 0 + let m = black_box(Matrix::<3>::from_rows([ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], + ])); + assert_abs_diff_eq!(m.det_direct().unwrap(), 0.0, epsilon = 1e-12); + } + + #[test] + fn det_direct_d3_nonsingular() { + // [[2,1,0],[0,3,1],[1,0,2]] → det = 2*(6-0) - 1*(0-1) + 0 = 13 + let m = black_box(Matrix::<3>::from_rows([ + [2.0, 1.0, 0.0], + [0.0, 3.0, 1.0], + [1.0, 0.0, 2.0], + ])); + assert_abs_diff_eq!(m.det_direct().unwrap(), 13.0, epsilon = 1e-12); + } + + #[test] + fn det_direct_d4_identity() { + let m = black_box(Matrix::<4>::identity()); + assert_abs_diff_eq!(m.det_direct().unwrap(), 1.0, epsilon = 1e-15); + } + + #[test] + fn det_direct_d4_known_value() { + // Diagonal matrix: det = product of diagonal entries. + let mut rows = [[0.0f64; 4]; 4]; + rows[0][0] = 2.0; + rows[1][1] = 3.0; + rows[2][2] = 5.0; + rows[3][3] = 7.0; + let m = black_box(Matrix::<4>::from_rows(rows)); + assert_abs_diff_eq!(m.det_direct().unwrap(), 210.0, epsilon = 1e-12); + } + + #[test] + fn det_direct_d5_returns_none() { + assert_eq!(Matrix::<5>::identity().det_direct(), None); + } + + #[test] + fn det_direct_d8_returns_none() { + assert_eq!(Matrix::<8>::zero().det_direct(), None); + } + + macro_rules! gen_det_direct_agrees_with_lu { + ($d:literal) => { + paste! { + #[test] + #[allow(clippy::cast_precision_loss)] // r, c, D are tiny integers + fn [<det_direct_agrees_with_lu_ $d>]() { + // Well-conditioned matrix: diagonally dominant.
+ let mut rows = [[0.0f64; $d]; $d]; + for r in 0..$d { + for c in 0..$d { + rows[r][c] = if r == c { + (r as f64) + f64::from($d) + 1.0 + } else { + 0.1 / ((r + c + 1) as f64) + }; + } + } + let m = Matrix::<$d>::from_rows(rows); + let direct = m.det_direct().unwrap(); + let lu_det = m.lu(DEFAULT_PIVOT_TOL).unwrap().det(); + let eps = lu_det.abs().mul_add(1e-12, 1e-12); + assert_abs_diff_eq!(direct, lu_det, epsilon = eps); + } + } + }; + } + + gen_det_direct_agrees_with_lu!(1); + gen_det_direct_agrees_with_lu!(2); + gen_det_direct_agrees_with_lu!(3); + gen_det_direct_agrees_with_lu!(4); + + #[test] + fn det_direct_identity_all_dims() { + assert_abs_diff_eq!( + Matrix::<1>::identity().det_direct().unwrap(), + 1.0, + epsilon = 0.0 + ); + assert_abs_diff_eq!( + Matrix::<2>::identity().det_direct().unwrap(), + 1.0, + epsilon = 0.0 + ); + assert_abs_diff_eq!( + Matrix::<3>::identity().det_direct().unwrap(), + 1.0, + epsilon = 0.0 + ); + assert_abs_diff_eq!( + Matrix::<4>::identity().det_direct().unwrap(), + 1.0, + epsilon = 0.0 + ); + } + + #[test] + fn det_direct_zero_matrix() { + assert_abs_diff_eq!( + Matrix::<2>::zero().det_direct().unwrap(), + 0.0, + epsilon = 0.0 + ); + assert_abs_diff_eq!( + Matrix::<3>::zero().det_direct().unwrap(), + 0.0, + epsilon = 0.0 + ); + assert_abs_diff_eq!( + Matrix::<4>::zero().det_direct().unwrap(), + 0.0, + epsilon = 0.0 + ); + } + + #[test] + fn det_returns_nonfinite_error_for_nan_d2() { + let m = Matrix::<2>::from_rows([[f64::NAN, 1.0], [1.0, 1.0]]); + assert_eq!( + m.det(DEFAULT_PIVOT_TOL), + Err(LaError::NonFinite { pivot_col: 0 }) + ); + } + + #[test] + fn det_returns_nonfinite_error_for_inf_d3() { + let m = + Matrix::<3>::from_rows([[f64::INFINITY, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]); + assert_eq!( + m.det(DEFAULT_PIVOT_TOL), + Err(LaError::NonFinite { pivot_col: 0 }) + ); + } + + #[test] + fn det_direct_is_const_evaluable_d2() { + // Const evaluation proves the function is truly const fn. 
+ const DET: Option<f64> = { + let m = Matrix::<2>::from_rows([[1.0, 0.0], [0.0, 1.0]]); + m.det_direct() + }; + assert_eq!(DET, Some(1.0)); + } + + #[test] + fn det_direct_is_const_evaluable_d3() { + const DET: Option<f64> = { + let m = Matrix::<3>::from_rows([[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 5.0]]); + m.det_direct() + }; + assert_eq!(DET, Some(30.0)); + } } diff --git a/tests/proptest_matrix.proptest-regressions b/tests/proptest_matrix.proptest-regressions new file mode 100644 index 0000000..141cd57 --- /dev/null +++ b/tests/proptest_matrix.proptest-regressions @@ -0,0 +1,8 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc a2989aa781cea4ab806e5c6242f6493e5aa1af226b439a7f6eb8c4257cd48dde # shrinks to diag = [-48.0, -29.1, -36.1], b_arr = [0.0, 0.0, 0.0] +cc d5cecfb996d10952fd516ed975202ebdf27ccd95649c66e671c348d7d0e4de92 # shrinks to diag = [-10.5, -52.4, 6.7, -15.4], b_arr = [0.0, 0.0, 0.0, 0.0] diff --git a/tests/proptest_matrix.rs b/tests/proptest_matrix.rs index 338e213..1dd1098 100644 --- a/tests/proptest_matrix.rs +++ b/tests/proptest_matrix.rs @@ -97,7 +97,11 @@ macro_rules! gen_public_api_matrix_proptests { } acc }; - assert_abs_diff_eq!(det, expected_det, epsilon = 1e-12); + // The closed-form and LU paths evaluate the diagonal + // product in different orders, so we allow a few ULPs of + // relative error (floor 1e-12 for near-zero determinants). + let eps = expected_det.abs().mul_add(1e-12, 1e-12); + assert_abs_diff_eq!(det, expected_det, epsilon = eps); let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap(); let b = Vector::<$d>::new(b_arr);