11//! LU decomposition and solves.
22
3+ use core:: hint:: cold_path;
4+
35use crate :: LaError ;
46use crate :: matrix:: Matrix ;
57use crate :: vector:: Vector ;
@@ -31,6 +33,7 @@ impl<const D: usize> Lu<D> {
3133 let mut pivot_row = k;
3234 let mut pivot_abs = lu. rows [ k] [ k] . abs ( ) ;
3335 if !pivot_abs. is_finite ( ) {
36+ cold_path ( ) ;
3437 return Err ( LaError :: NonFinite {
3538 row : Some ( k) ,
3639 col : k,
@@ -40,6 +43,7 @@ impl<const D: usize> Lu<D> {
4043 for r in ( k + 1 ) ..D {
4144 let v = lu. rows [ r] [ k] . abs ( ) ;
4245 if !v. is_finite ( ) {
46+ cold_path ( ) ;
4347 return Err ( LaError :: NonFinite {
4448 row : Some ( r) ,
4549 col : k,
@@ -52,6 +56,7 @@ impl<const D: usize> Lu<D> {
5256 }
5357
5458 if pivot_abs <= tol {
59+ cold_path ( ) ;
5560 return Err ( LaError :: Singular { pivot_col : k } ) ;
5661 }
5762
@@ -63,6 +68,7 @@ impl<const D: usize> Lu<D> {
6368
6469 let pivot = lu. rows [ k] [ k] ;
6570 if !pivot. is_finite ( ) {
71+ cold_path ( ) ;
6672 return Err ( LaError :: NonFinite {
6773 row : Some ( k) ,
6874 col : k,
@@ -73,6 +79,7 @@ impl<const D: usize> Lu<D> {
7379 for r in ( k + 1 ) ..D {
7480 let mult = lu. rows [ r] [ k] / pivot;
7581 if !mult. is_finite ( ) {
82+ cold_path ( ) ;
7683 return Err ( LaError :: NonFinite {
7784 row : Some ( r) ,
7885 col : k,
@@ -132,6 +139,7 @@ impl<const D: usize> Lu<D> {
132139 sum = ( -row[ j] ) . mul_add ( * x_j, sum) ;
133140 }
134141 if !sum. is_finite ( ) {
142+ cold_path ( ) ;
135143 return Err ( LaError :: NonFinite { row : None , col : i } ) ;
136144 }
137145 x[ i] = sum;
@@ -148,14 +156,17 @@ impl<const D: usize> Lu<D> {
148156
149157 let diag = row[ i] ;
150158 if !diag. is_finite ( ) || !sum. is_finite ( ) {
159+ cold_path ( ) ;
151160 return Err ( LaError :: NonFinite { row : None , col : i } ) ;
152161 }
153162 if diag. abs ( ) <= self . tol {
163+ cold_path ( ) ;
154164 return Err ( LaError :: Singular { pivot_col : i } ) ;
155165 }
156166
157167 let q = sum / diag;
158168 if !q. is_finite ( ) {
169+ cold_path ( ) ;
159170 return Err ( LaError :: NonFinite { row : None , col : i } ) ;
160171 }
161172 x[ i] = q;
@@ -474,4 +485,20 @@ mod tests {
474485 let err = lu. solve_vec ( b) . unwrap_err ( ) ;
475486 assert_eq ! ( err, LaError :: NonFinite { row: None , col: 1 } ) ;
476487 }
488+
489+ #[ test]
490+ fn solve_vec_nonfinite_back_substitution_sum_overflow ( ) {
491+ // Upper-triangular U with a very large off-diagonal in row 1 and a
492+ // very large x[2] produced by the RHS. The back-substitution
493+ // accumulator `sum = (-row[j]).mul_add(x[j], sum)` overflows while
494+ // reducing row 1, so the failure is detected via the `!sum.is_finite()`
495+ // branch of the combined diag/sum check (distinct from the
496+ // `q = sum / diag` overflow path covered above).
497+ let a = Matrix :: < 3 > :: from_rows ( [ [ 1.0 , 0.0 , 0.0 ] , [ 0.0 , 1.0 , 1.0e200 ] , [ 0.0 , 0.0 , 1.0 ] ] ) ;
498+ let lu = a. lu ( DEFAULT_PIVOT_TOL ) . unwrap ( ) ;
499+
500+ let b = Vector :: < 3 > :: new ( [ 0.0 , 0.0 , 1.0e200 ] ) ;
501+ let err = lu. solve_vec ( b) . unwrap_err ( ) ;
502+ assert_eq ! ( err, LaError :: NonFinite { row: None , col: 1 } ) ;
503+ }
477504}
0 commit comments