Skip to content

Commit 81ecb35

Browse files
acgetchelloz-agent
andcommitted
feat: const-ify Lu/Ldlt det + solve_vec and Matrix inf_norm + det_errbound
Make six public methods `const fn`, completing const-evaluation parity with `Matrix::det_direct` and `Vector::dot` now that MSRV 1.95 exposes `f64::mul_add` as const (1.94) and `core::hint::cold_path` as const (1.95): - `Lu::det`, `Lu::solve_vec` - `Ldlt::det`, `Ldlt::solve_vec` - `Matrix::inf_norm`, `Matrix::det_errbound` Iterator chains (`.iter().map().sum()`, `.enumerate().take(i)`, `.enumerate().skip(i + 1)`, `.iter_mut().enumerate().take(D)`) were rewritten as `while` loops since they are not const-stable. Fix error-variant correctness in both solve_vec paths: - A corrupt stored `U` / `D` diagonal at `(i, i)` now surfaces as `LaError::NonFinite { row: Some(i), col: i }`, matching the convention used by `Matrix::det`, `Lu::factor`, and `Ldlt::factor`. - A computed-intermediate overflow keeps `row: None, col: i`. - Previously both were conflated into `row: None, col: i`, defeating debuggability for callers who construct factorizations directly. Sharpen `LaError::NonFinite` variant docs and the `# Errors` sections of `Lu::solve_vec` / `Ldlt::solve_vec` to spell out the `(row, col)` contract for each failure mode. Cleanup: - Import `ERR_COEFF_{2,3,4}` at `src/matrix.rs` top; drop `crate::` prefixes inside `Matrix::det_errbound`. - Rename intermediate bindings `q` / `v` → `quotient` to stay under `clippy::many_single_char_names`. - Rename const-evaluability tests from `*_is_const_evaluable_*` to `*_const_eval_*` for a concise, consistent style. Tests: - Added `{lu,ldlt,inf_norm,det_errbound}_const_eval_*` tests that force compile-time evaluation inside `const { … }` initializers. - Added `solve_vec_defensive_non_finite_diagonal_{2,3,4,5}d` in `src/lu.rs` covering the new split error path (previously only the `Singular` defensive path was exercised). - Updated matching `Ldlt` defensive tests to assert the new `row: Some(D - 1), col: D - 1` payload. Validation: `just ci` passes — 164 unit tests, 324 integration/feature tests, 28 doctests; clippy pedantic + nursery as errors; all examples. Co-Authored-By: Oz <oz-agent@warp.dev>
1 parent aba1dc7 commit 81ecb35

4 files changed

Lines changed: 342 additions & 56 deletions

File tree

src/ldlt.rs

Lines changed: 113 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,12 @@ impl<const D: usize> Ldlt<D> {
124124
/// ```
125125
#[inline]
126126
#[must_use]
127-
pub fn det(&self) -> f64 {
127+
pub const fn det(&self) -> f64 {
128128
let mut det = 1.0;
129-
for i in 0..D {
129+
let mut i = 0;
130+
while i < D {
130131
det *= self.factors.rows[i][i];
132+
i += 1;
131133
}
132134
det
133135
}
@@ -154,57 +156,82 @@ impl<const D: usize> Ldlt<D> {
154156
/// # Errors
155157
/// Returns [`LaError::Singular`] if a diagonal entry `d = D[i,i]` satisfies `d <= tol`
156158
/// (non-positive or too small), where `tol` is the tolerance that was used during factorization.
157-
/// Returns [`LaError::NonFinite`] if NaN/∞ is detected.
159+
///
160+
/// Returns [`LaError::NonFinite`] if NaN/∞ is detected. The `row`/`col` coordinates
161+
/// follow the convention documented on [`LaError::NonFinite`]:
162+
///
163+
/// - `row: Some(i), col: i` — the stored `D` diagonal at `(i, i)` is non-finite
164+
/// (only reachable via direct `Ldlt` construction; [`Matrix::ldlt`](crate::Matrix::ldlt)
165+
/// rejects such factorizations).
166+
/// - `row: None, col: i` — a computed intermediate (forward/back-substitution
167+
/// accumulator or the quotient `x[i] / diag`) overflowed to NaN/∞ at step `i`.
158168
#[inline]
159-
pub fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
169+
pub const fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
160170
let mut x = b.data;
161171

162172
// Forward substitution: L y = b (L has unit diagonal).
163-
for i in 0..D {
173+
let mut i = 0;
174+
while i < D {
164175
let mut sum = x[i];
165176
let row = self.factors.rows[i];
166-
for (j, x_j) in x.iter().enumerate().take(i) {
167-
sum = (-row[j]).mul_add(*x_j, sum);
177+
let mut j = 0;
178+
while j < i {
179+
sum = (-row[j]).mul_add(x[j], sum);
180+
j += 1;
168181
}
169182
if !sum.is_finite() {
170183
cold_path();
171184
return Err(LaError::NonFinite { row: None, col: i });
172185
}
173186
x[i] = sum;
187+
i += 1;
174188
}
175189

176190
// Diagonal solve: D z = y.
177-
for (i, x_i) in x.iter_mut().enumerate().take(D) {
191+
let mut i = 0;
192+
while i < D {
178193
let diag = self.factors.rows[i][i];
194+
// A corrupt stored diagonal is a specific matrix cell (i, i),
195+
// distinct from a computed overflow — report it with
196+
// `row: Some(i)` per the `LaError::NonFinite` convention used by
197+
// `Matrix::det`, `Lu::factor`, and `Ldlt::factor`.
179198
if !diag.is_finite() {
180199
cold_path();
181-
return Err(LaError::NonFinite { row: None, col: i });
200+
return Err(LaError::NonFinite {
201+
row: Some(i),
202+
col: i,
203+
});
182204
}
183205
if diag <= self.tol {
184206
cold_path();
185207
return Err(LaError::Singular { pivot_col: i });
186208
}
187209

188-
let v = *x_i / diag;
189-
if !v.is_finite() {
210+
let quotient = x[i] / diag;
211+
if !quotient.is_finite() {
190212
cold_path();
191213
return Err(LaError::NonFinite { row: None, col: i });
192214
}
193-
*x_i = v;
215+
x[i] = quotient;
216+
i += 1;
194217
}
195218

196219
// Back substitution: Lᵀ x = z.
197-
for ii in 0..D {
220+
let mut ii = 0;
221+
while ii < D {
198222
let i = D - 1 - ii;
199223
let mut sum = x[i];
200-
for (j, x_j) in x.iter().enumerate().skip(i + 1) {
201-
sum = (-self.factors.rows[j][i]).mul_add(*x_j, sum);
224+
let mut j = i + 1;
225+
while j < D {
226+
sum = (-self.factors.rows[j][i]).mul_add(x[j], sum);
227+
j += 1;
202228
}
203229
if !sum.is_finite() {
204230
cold_path();
205231
return Err(LaError::NonFinite { row: None, col: i });
206232
}
207233
x[i] = sum;
234+
ii += 1;
208235
}
209236

210237
Ok(Vector::new(x))
@@ -487,7 +514,9 @@ mod tests {
487514
paste! {
488515
/// `solve_vec` must surface `NonFinite` when a stored
489516
/// diagonal is NaN, even though `factor` cannot produce
490-
/// such a factorization.
517+
/// such a factorization. The error must pinpoint the
518+
/// corrupt cell at `(D-1, D-1)` per the
519+
/// [`LaError::NonFinite`] convention.
491520
#[test]
492521
fn [<solve_vec_defensive_non_finite_diagonal_ $d d>]() {
493522
let mut factors = Matrix::<$d>::identity();
@@ -498,7 +527,13 @@ mod tests {
498527
};
499528
let b = Vector::<$d>::new([1.0; $d]);
500529
let err = ldlt.solve_vec(b).unwrap_err();
501-
assert_eq!(err, LaError::NonFinite { row: None, col: $d - 1 });
530+
assert_eq!(
531+
err,
532+
LaError::NonFinite {
533+
row: Some($d - 1),
534+
col: $d - 1,
535+
}
536+
);
502537
}
503538

504539
/// `solve_vec` must surface `Singular` when a stored
@@ -524,4 +559,65 @@ mod tests {
524559
gen_solve_vec_defensive_tests!(3);
525560
gen_solve_vec_defensive_tests!(4);
526561
gen_solve_vec_defensive_tests!(5);
562+
563+
// -----------------------------------------------------------------------
564+
// Const-evaluability tests.
565+
//
566+
// These prove that `Ldlt::det` and `Ldlt::solve_vec` are truly `const fn`
567+
// by forcing the compiler to evaluate them inside a `const` initializer.
568+
// `Ldlt::factor` is not (yet) `const fn` because the rank-1 update loop
569+
// uses array indexing patterns that still require non-const helpers on
570+
// some toolchains; we therefore construct `Ldlt<D>` directly.
571+
// -----------------------------------------------------------------------
572+
573+
#[test]
574+
fn ldlt_det_const_eval_d2() {
575+
const DET: f64 = {
576+
// Diagonal D = [4.0, 0.25] ⇒ det = 1.0.
577+
let mut factors = Matrix::<2>::identity();
578+
factors.rows[0][0] = 4.0;
579+
factors.rows[1][1] = 0.25;
580+
let ldlt = Ldlt::<2> {
581+
factors,
582+
tol: DEFAULT_SINGULAR_TOL,
583+
};
584+
ldlt.det()
585+
};
586+
assert!((DET - 1.0).abs() <= 1e-12);
587+
}
588+
589+
#[test]
590+
fn ldlt_det_const_eval_d3() {
591+
const DET: f64 = {
592+
// Diagonal D = [2.0, 3.0, 5.0] ⇒ det = 30.0.
593+
let mut factors = Matrix::<3>::identity();
594+
factors.rows[0][0] = 2.0;
595+
factors.rows[1][1] = 3.0;
596+
factors.rows[2][2] = 5.0;
597+
let ldlt = Ldlt::<3> {
598+
factors,
599+
tol: DEFAULT_SINGULAR_TOL,
600+
};
601+
ldlt.det()
602+
};
603+
assert!((DET - 30.0).abs() <= 1e-12);
604+
}
605+
606+
#[test]
607+
fn ldlt_solve_vec_const_eval_d2() {
608+
// Identity factors ⇒ solve_vec returns the RHS untouched.
609+
const X: [f64; 2] = {
610+
let ldlt = Ldlt::<2> {
611+
factors: Matrix::<2>::identity(),
612+
tol: DEFAULT_SINGULAR_TOL,
613+
};
614+
let b = Vector::<2>::new([1.0, 2.0]);
615+
match ldlt.solve_vec(b) {
616+
Ok(v) => v.into_array(),
617+
Err(_) => [0.0, 0.0],
618+
}
619+
};
620+
assert!((X[0] - 1.0).abs() <= 1e-12);
621+
assert!((X[1] - 2.0).abs() <= 1e-12);
622+
}
527623
}

src/lib.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,22 @@ pub enum LaError {
119119
pivot_col: usize,
120120
},
121121
/// A non-finite value (NaN/∞) was encountered.
122+
///
123+
/// The `(row, col)` coordinate follows a consistent convention across the crate:
124+
///
125+
/// - `row: Some(r), col: c` — a *stored* matrix cell at `(r, c)` is non-finite.
126+
/// Used by `Matrix::det`, `Lu::factor`, `Ldlt::factor`, and the `solve_vec`
127+
/// paths when they detect a corrupt stored factor (only reachable via
128+
/// direct struct construction; `factor` itself rejects such inputs).
129+
/// - `row: None, col: c` — the non-finite value is either a *vector input*
130+
/// entry at index `c`, or a *computed intermediate* at step `c`
131+
/// (e.g. an accumulator that overflowed during forward/back substitution).
122132
NonFinite {
123-
/// Row of the non-finite entry (for matrix inputs), or `None` when
124-
/// the error originates from a vector input or a computed intermediate.
133+
/// Row of the non-finite entry for a stored matrix cell, or `None` for
134+
/// a vector-input entry or a computed intermediate. See the variant
135+
/// docs for the full convention.
125136
row: Option<usize>,
126-
/// Column index (for matrix inputs), vector index, or factorization
137+
/// Column index (stored cell), vector index, or factorization/solve
127138
/// step where the non-finite value was detected.
128139
col: usize,
129140
},

0 commit comments

Comments
 (0)