@@ -71,7 +71,7 @@ template <rk2_tensor Matrix> class DenseLUFactor {
7171 // put permutation in block_perm in place
7272 // put pivot perturbation in has_pivot_perturbation in place
7373 template <class Derived >
74- static void factorize_block_in_place (Eigen::MatrixBase<Derived>&& matrix, BlockPerm& block_perm,
74+ static void factorize_block_in_place (Eigen::MatrixBase<Derived>& matrix, BlockPerm& block_perm,
7575 double perturb_threshold, bool use_pivot_perturbation,
7676 bool & has_pivot_perturbation)
7777 requires(std::same_as<typename Derived::Scalar, Scalar> && rk2_tensor<Derived> &&
@@ -150,7 +150,66 @@ template <rk2_tensor Matrix> class DenseLUFactor {
150150 throw SparseMatrixError{}; // can not specify error code
151151 }
152152 }
153- capturing::into_the_void (std::move (matrix));
153+ }
154+
155+ // Forward substitution with the L matrix. The diagonal entries of L are implicit 1.0
156+ // The rhs may be a vector or a matrix; matrix rhs columns are solved simultaneously.
157+ template <class LUDerived , class RHSDerived >
158+ static void forward_substitute_inplace (Eigen::MatrixBase<LUDerived> const & lu_matrix, RHSDerived& rhs)
159+ requires(std::same_as<typename LUDerived::Scalar, Scalar> &&
160+ std::same_as<typename RHSDerived::Scalar, Scalar> && rk2_tensor<LUDerived> &&
161+ (LUDerived::RowsAtCompileTime == size) && (LUDerived::ColsAtCompileTime == size) &&
162+ (RHSDerived::RowsAtCompileTime == size))
163+ {
164+ for (int8_t row = 0 ; row < size; ++row) {
165+ for (int8_t col = 0 ; col < row; ++col) {
166+ rhs.row (row) -= lu_matrix (row, col) * rhs.row (col);
167+ }
168+ }
169+ }
170+
171+ // Backward substitution with the U matrix stored in lu_matrix.
172+ // The rhs may be a vector or a matrix; matrix rhs columns are solved simultaneously.
173+ template <class LUDerived , class RHSDerived >
174+ static void backward_substitute_inplace (Eigen::MatrixBase<LUDerived> const & lu_matrix, RHSDerived& rhs)
175+ requires(std::same_as<typename LUDerived::Scalar, Scalar> &&
176+ std::same_as<typename RHSDerived::Scalar, Scalar> && rk2_tensor<LUDerived> &&
177+ (LUDerived::RowsAtCompileTime == size) && (LUDerived::ColsAtCompileTime == size) &&
178+ (RHSDerived::RowsAtCompileTime == size))
179+ {
180+ for (int8_t row = size - 1 ; row > -1 ; --row) {
181+ for (int8_t col = size - 1 ; col > row; --col) {
182+ rhs.row (row) -= lu_matrix (row, col) * rhs.row (col);
183+ }
184+ rhs.row (row) /= lu_matrix (row, row);
185+ }
186+ }
187+
188+ // given the factorized block satisfies L * U = P * A * Q
189+ // compute the inverse of the factorized L * U * X = I
190+ // returns X = (L * U)^-1 = (P * A * Q)^-1
191+ template <class Derived >
192+ static Matrix inverse_factorized_block (Eigen::MatrixBase<Derived> const & lu_matrix)
193+ requires(std::same_as<typename Derived::Scalar, Scalar> && rk2_tensor<Derived> &&
194+ (Derived::RowsAtCompileTime == size) && (Derived::ColsAtCompileTime == size))
195+ {
196+ Matrix inverse = Matrix::Identity ();
197+ forward_substitute_inplace (lu_matrix, inverse);
198+ backward_substitute_inplace (lu_matrix, inverse);
199+ return inverse;
200+ }
201+
202+ template <class Derived >
203+ static Matrix dense_inverse (Eigen::MatrixBase<Derived> const & lu_matrix, BlockPerm const & block_perm)
204+ requires(std::same_as<typename Derived::Scalar, Scalar> && rk2_tensor<Derived> &&
205+ (Derived::RowsAtCompileTime == size) && (Derived::ColsAtCompileTime == size))
206+ {
207+ // given the factorized block satisfies L * U = P * A * Q
208+ // return A^-1 = Q * (L * U)^-1 * P
209+ // lu_matrix is read-only: the packed L/U factor is preserved.
210+ // inverse_factorized_block() performs in-place substitutions only on
211+ // its local Identity() RHS, not on lu_matrix.
212+ return block_perm.q * inverse_factorized_block (lu_matrix) * block_perm.p ;
154213 }
155214};
156215
@@ -291,9 +350,9 @@ template <class Tensor, class RHSVector, class XVector> class SparseLUSolver {
291350 if constexpr (is_block) {
292351 // use machine precision by default
293352 // record block permutation
294- LUFactor::factorize_block_in_place ( lu_matrix[pivot_idx].matrix (), block_perm_array[pivot_row_col],
295- perturb_threshold, use_pivot_perturbation ,
296- has_pivot_perturbation_);
353+ auto pivot_matrix = lu_matrix[pivot_idx].matrix ();
354+ LUFactor::factorize_block_in_place (pivot_matrix, block_perm_array[pivot_row_col], perturb_threshold ,
355+ use_pivot_perturbation, has_pivot_perturbation_);
297356 return block_perm_array[pivot_row_col];
298357 } else {
299358 if (use_pivot_perturbation) {
@@ -612,13 +671,8 @@ template <class Tensor, class RHSVector, class XVector> class SparseLUSolver {
612671 }
613672 // forward substitution inside block, for block matrix
614673 if constexpr (is_block) {
615- XVector& xb = x[row];
616674 Tensor const & pivot = lu_matrix[diag_lu[row]];
617- for (Idx br = 0 ; br < block_size; ++br) {
618- for (Idx bc = 0 ; bc < br; ++bc) {
619- xb (br) -= pivot (br, bc) * xb (bc);
620- }
621- }
675+ LUFactor::forward_substitute_inplace (pivot.matrix (), x[row]);
622676 }
623677 }
624678
@@ -635,14 +689,8 @@ template <class Tensor, class RHSVector, class XVector> class SparseLUSolver {
635689 // solve the diagonal pivot
636690 if constexpr (is_block) {
637691 // backward substitution inside block
638- XVector& xb = x[row];
639692 Tensor const & pivot = lu_matrix[diag_lu[row]];
640- for (Idx br = block_size - 1 ; br != -1 ; --br) {
641- for (Idx bc = block_size - 1 ; bc > br; --bc) {
642- xb (br) -= pivot (br, bc) * xb (bc);
643- }
644- xb (br) = xb (br) / pivot (br, br);
645- }
693+ LUFactor::backward_substitute_inplace (pivot.matrix (), x[row]);
646694 } else {
647695 x[row] = x[row] / lu_matrix[diag_lu[row]];
648696 }
0 commit comments