@@ -103,6 +103,7 @@ std::vector<int_t> row_indices(const auto& A)
103103 // Write the per-scalar-row entry counts into `flattened_rowptr[1:]`, with
104104 // each block-row contributing `bs[0]` copies.
105105 std::vector<int_t > flattened_rowptr (m_loc_block * bs[0 ] + 1 );
106+ flattened_rowptr[0 ] = A_rowptr[0 ] * bs[1 ];
106107 for (std::int32_t i = 0 ; i < m_loc_block; ++i)
107108 {
108109 int_t delta = (A_rowptr[i + 1 ] - A_rowptr[i]) * bs[1 ];
@@ -462,8 +463,7 @@ template <typename T>
462463void SuperLUDistSolver<T>::set_options(
463464 SuperLUDistStructs::superlu_dist_options_t options)
464465{
465- _options = std::make_unique<SuperLUDistStructs::superlu_dist_options_t >(
466- std::move (options));
466+ *_options = options;
467467}
468468// ----------------------------------------------------------------------------
469469template <typename T>
@@ -544,13 +544,69 @@ void SuperLUDistSolver<T>::set_option(std::string name, std::string value)
544544}
545545// ----------------------------------------------------------------------------
546546template <typename T>
547- void SuperLUDistSolver<T>::set_A(std::shared_ptr< const SuperLUDistMatrix<T>> A )
547+ SuperLUDistSolver<T>::~SuperLUDistSolver ( )
548548{
549+ if (_factored)
550+ {
551+ int_t n = _superlu_matA->supermatrix ()->ncol ;
552+ if constexpr (std::is_same_v<T, double >)
553+ dDestroy_LU (n, _gridinfo.get (), _lustruct.get ());
554+ else if constexpr (std::is_same_v<T, float >)
555+ sDestroy_LU (n, _gridinfo.get (), _lustruct.get ());
556+ else if constexpr (std::is_same_v<T, std::complex <double >>)
557+ zDestroy_LU (n, _gridinfo.get (), _lustruct.get ());
558+ else
559+ static_assert (always_false_v<T>, " Invalid scalar type" );
560+ }
561+ }
562+ // ----------------------------------------------------------------------------
563+ template <typename T>
564+ void SuperLUDistSolver<T>::set_A(std::shared_ptr<const SuperLUDistMatrix<T>> A,
565+ std::string fact)
566+ {
567+ if (A->supermatrix ()->nrow != _superlu_matA->supermatrix ()->nrow
568+ or A->supermatrix ()->ncol != _superlu_matA->supermatrix ()->ncol )
569+ {
570+ throw std::runtime_error (
571+ " New matrix A has different size to the matrix used to construct the "
572+ " solver." );
573+ }
549574 _superlu_matA = A;
575+
576+ // See pddistribute in SuperLU_DIST: on Fact=DOFACT or Fact=SamePattern
577+ // pdgssvx overwrites LUstruct's internal pointers without freeing the
578+ // previous arrays, so a prior factorisation must be released first.
579+ // Fact=SamePattern_SameRowPerm reuses the existing arrays in place.
580+ if (fact == " DOFACT" or fact == " SamePattern" )
581+ {
582+ if (_factored)
583+ {
584+ int_t n = _superlu_matA->supermatrix ()->ncol ;
585+ if constexpr (std::is_same_v<T, double >)
586+ dDestroy_LU (n, _gridinfo.get (), _lustruct.get ());
587+ else if constexpr (std::is_same_v<T, float >)
588+ sDestroy_LU (n, _gridinfo.get (), _lustruct.get ());
589+ else if constexpr (std::is_same_v<T, std::complex <double >>)
590+ zDestroy_LU (n, _gridinfo.get (), _lustruct.get ());
591+ else
592+ static_assert (always_false_v<T>, " Invalid scalar type" );
593+ _factored = false ;
594+ }
595+ _options->Fact = (fact == " DOFACT" ) ? DOFACT : SamePattern;
596+ }
597+ else if (fact == " SamePattern_SameRowPerm" )
598+ {
599+ _options->Fact = SamePattern_SameRowPerm;
600+ }
601+ else
602+ {
603+ throw std::runtime_error (" set_A fact must be one of 'DOFACT', "
604+ " 'SamePattern', 'SamePattern_SameRowPerm'" );
605+ }
550606}
551607// ----------------------------------------------------------------------------
552608template <typename T>
553- int SuperLUDistSolver<T>::solve(const la::Vector<T>& b, la::Vector<T>& u) const
609+ int SuperLUDistSolver<T>::solve(const la::Vector<T>& b, la::Vector<T>& u)
554610{
555611 common::Timer tsolve (" SuperLU_DIST solve" );
556612
@@ -614,6 +670,10 @@ int SuperLUDistSolver<T>::solve(const la::Vector<T>& b, la::Vector<T>& u) const
614670 if (info != 0 )
615671 spdlog::info (" SuperLU_DIST p*gssvx() error: {}" , info);
616672
673+ // pdgssvx allocates LUstruct internals during factorisation. Record this
674+ // so the destructor / set_A can call Destroy_LU to release them.
675+ _factored = true ;
676+
617677 PStatPrint (_options.get (), &stat, _gridinfo.get ());
618678 PStatFree (&stat);
619679
0 commit comments