Skip to content

Commit a4b8500

Browse files
committed
Fix SuperLU_DIST memory leak (#4258)
* Fix memory bug in options. * Add check on incoming matrix. * More defensive check for passing new operator A. * Add some state after reviewing against PETSc - unavoidable. * And add to Python interface * Improve docstring
1 parent f701043 commit a4b8500

4 files changed

Lines changed: 84 additions & 15 deletions

File tree

cpp/dolfinx/la/superlu_dist.cpp

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ std::vector<int_t> row_indices(const auto& A)
103103
// Write the per-scalar-row entry counts into `flattened_rowptr[1:]`, with
104104
// each block-row contributing `bs[0]` copies.
105105
std::vector<int_t> flattened_rowptr(m_loc_block * bs[0] + 1);
106+
flattened_rowptr[0] = A_rowptr[0] * bs[1];
106107
for (std::int32_t i = 0; i < m_loc_block; ++i)
107108
{
108109
int_t delta = (A_rowptr[i + 1] - A_rowptr[i]) * bs[1];
@@ -462,8 +463,7 @@ template <typename T>
462463
void SuperLUDistSolver<T>::set_options(
463464
SuperLUDistStructs::superlu_dist_options_t options)
464465
{
465-
_options = std::make_unique<SuperLUDistStructs::superlu_dist_options_t>(
466-
std::move(options));
466+
*_options = options;
467467
}
468468
//----------------------------------------------------------------------------
469469
template <typename T>
@@ -544,13 +544,69 @@ void SuperLUDistSolver<T>::set_option(std::string name, std::string value)
544544
}
545545
//----------------------------------------------------------------------------
546546
template <typename T>
547-
void SuperLUDistSolver<T>::set_A(std::shared_ptr<const SuperLUDistMatrix<T>> A)
547+
SuperLUDistSolver<T>::~SuperLUDistSolver()
{
  // Destructor: if a factorisation has been recorded (`_factored`), call
  // the Destroy_LU routine matching the scalar type `T` on the LU struct.
  // When no factorisation has been recorded there is nothing to do here.
  if (!_factored)
    return;

  // Column count of the currently stored matrix, as required by
  // Destroy_LU.
  const int_t num_cols = _superlu_matA->supermatrix()->ncol;

  // Dispatch on scalar type: z = complex<double>, s = float, d = double.
  if constexpr (std::is_same_v<T, std::complex<double>>)
    zDestroy_LU(num_cols, _gridinfo.get(), _lustruct.get());
  else if constexpr (std::is_same_v<T, float>)
    sDestroy_LU(num_cols, _gridinfo.get(), _lustruct.get());
  else if constexpr (std::is_same_v<T, double>)
    dDestroy_LU(num_cols, _gridinfo.get(), _lustruct.get());
  else
    static_assert(always_false_v<T>, "Invalid scalar type");
}
562+
//----------------------------------------------------------------------------
563+
template <typename T>
564+
void SuperLUDistSolver<T>::set_A(std::shared_ptr<const SuperLUDistMatrix<T>> A,
565+
std::string fact)
566+
{
567+
if (A->supermatrix()->nrow != _superlu_matA->supermatrix()->nrow
568+
or A->supermatrix()->ncol != _superlu_matA->supermatrix()->ncol)
569+
{
570+
throw std::runtime_error(
571+
"New matrix A has different size to the matrix used to construct the "
572+
"solver.");
573+
}
549574
_superlu_matA = A;
575+
576+
// See pddistribute in SuperLU_DIST: on Fact=DOFACT or Fact=SamePattern
577+
// pdgssvx overwrites LUstruct's internal pointers without freeing the
578+
// previous arrays, so a prior factorisation must be released first.
579+
// Fact=SamePattern_SameRowPerm reuses the existing arrays in place.
580+
if (fact == "DOFACT" or fact == "SamePattern")
581+
{
582+
if (_factored)
583+
{
584+
int_t n = _superlu_matA->supermatrix()->ncol;
585+
if constexpr (std::is_same_v<T, double>)
586+
dDestroy_LU(n, _gridinfo.get(), _lustruct.get());
587+
else if constexpr (std::is_same_v<T, float>)
588+
sDestroy_LU(n, _gridinfo.get(), _lustruct.get());
589+
else if constexpr (std::is_same_v<T, std::complex<double>>)
590+
zDestroy_LU(n, _gridinfo.get(), _lustruct.get());
591+
else
592+
static_assert(always_false_v<T>, "Invalid scalar type");
593+
_factored = false;
594+
}
595+
_options->Fact = (fact == "DOFACT") ? DOFACT : SamePattern;
596+
}
597+
else if (fact == "SamePattern_SameRowPerm")
598+
{
599+
_options->Fact = SamePattern_SameRowPerm;
600+
}
601+
else
602+
{
603+
throw std::runtime_error("set_A fact must be one of 'DOFACT', "
604+
"'SamePattern', 'SamePattern_SameRowPerm'");
605+
}
550606
}
551607
//----------------------------------------------------------------------------
552608
template <typename T>
553-
int SuperLUDistSolver<T>::solve(const la::Vector<T>& b, la::Vector<T>& u) const
609+
int SuperLUDistSolver<T>::solve(const la::Vector<T>& b, la::Vector<T>& u)
554610
{
555611
common::Timer tsolve("SuperLU_DIST solve");
556612

@@ -614,6 +670,10 @@ int SuperLUDistSolver<T>::solve(const la::Vector<T>& b, la::Vector<T>& u) const
614670
if (info != 0)
615671
spdlog::info("SuperLU_DIST p*gssvx() error: {}", info);
616672

673+
// pdgssvx allocates LUstruct internals during factorisation. Record this
674+
// so the destructor / set_A can call Destroy_LU to release them.
675+
_factored = true;
676+
617677
PStatPrint(_options.get(), &stat, _gridinfo.get());
618678
PStatFree(&stat);
619679

cpp/dolfinx/la/superlu_dist.h

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ class SuperLUDistSolver
204204
/// Copy assignment
205205
SuperLUDistSolver& operator=(const SuperLUDistSolver&) = delete;
206206

207+
/// @brief Destructor. Frees internal LU arrays before LUstructFree.
208+
~SuperLUDistSolver();
209+
207210
/// @brief Set solver option name to value
208211
///
209212
/// See SuperLU_DIST User's Guide for option names and values.
@@ -237,11 +240,14 @@ class SuperLUDistSolver
237240

238241
/// @brief Set assembled left-hand side matrix A.
239242
///
240-
/// For advanced use with SuperLU_DIST option `Factor` allowing use of
241-
/// previously computed permutations when solving with new matrix A.
243+
/// New matrix must have the same size/parallel layout as the matrix
244+
/// used to construct the solver.
242245
///
243246
/// @param A Assembled left-hand side matrix.
244-
void set_A(std::shared_ptr<const SuperLUDistMatrix<T>> A);
247+
/// @param fact One of `"DOFACT"`, `"SamePattern"`,
248+
/// `"SamePattern_SameRowPerm"`. See the SuperLU_DIST documentation
249+
/// for the meaning of these values.
250+
void set_A(std::shared_ptr<const SuperLUDistMatrix<T>> A, std::string fact);
245251

246252
/// @brief Solve linear system Au = b.
247253
///
@@ -255,7 +261,7 @@ class SuperLUDistSolver
255261
/// @note The values of `A` are modified in-place during the solve.
256262
/// @note To solve with successive right-hand sides the caller must
257263
/// call `solver.set_option("Fact", "FACTORED")` after the first solve.
258-
int solve(const Vector<T>& b, Vector<T>& u) const;
264+
int solve(const Vector<T>& b, Vector<T>& u);
259265

260266
private:
261267
// Assembled left-hand side matrix
@@ -275,6 +281,10 @@ class SuperLUDistSolver
275281
// Pointer to 'typed' struct *SOLVEstruct
276282
std::unique_ptr<typename map_t<T>::SOLVEstruct_t, SolveStructDeleter>
277283
_solvestruct;
284+
285+
// True once pdgssvx has populated LUstruct with per-block-column arrays
286+
// that must be released via Destroy_LU before LUstructFree.
287+
bool _factored = false;
278288
};
279289
} // namespace dolfinx::la
280290
#endif

python/dolfinx/la/superlu_dist.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,16 +119,16 @@ def set_option(self, name: str, value: str):
119119
"""
120120
self._cpp_object.set_option(name, value)
121121

122-
def set_A(self, A: SuperLUDistMatrix[_T]):
122+
def set_A(self, A: SuperLUDistMatrix[_T], fact: str):
123123
"""Set assembled left-hand side matrix.
124124
125-
For advanced use with SuperLU_DIST option `Factor` allowing use of
126-
previously computed permutations when solving with new matrix A.
127-
128125
Args:
129126
A: Assembled left-hand side matrix :math:`A`.
127+
fact: One of ``"DOFACT"``, ``"SamePattern"``,
128+
``"SamePattern_SameRowPerm"``. See the SuperLU_DIST
129+
documentation for the meaning of these values.
130130
"""
131-
self._cpp_object.set_A(A._cpp_object)
131+
self._cpp_object.set_A(A._cpp_object, fact)
132132

133133
def solve(self, b: dolfinx.la.Vector[_T], u: dolfinx.la.Vector[_T]) -> int:
134134
"""Solve linear system :math:`Au = b`.

python/test/unit/la/test_superlu_dist.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,7 @@ def solve_and_check(solver, b):
128128
A_3.scatter_reverse()
129129

130130
A_superlu_3 = superlu_dist_matrix(A_3)
131-
solver_2.set_A(A_superlu_3)
132-
solver_2.set_option("Fact", "SamePattern")
131+
solver_2.set_A(A_superlu_3, "SamePattern")
133132
uh_2 = solve_and_check(solver_2, b_2)
134133
check_error(u_ex, uh_2)
135134

0 commit comments

Comments
 (0)