@@ -14,10 +14,12 @@ extern "C"
1414#include < superlu_zdefs.h>
1515}
1616#include < algorithm>
17+ #include < array>
1718#include < dolfinx/common/Timer.h>
1819#include < dolfinx/la/MatrixCSR.h>
1920#include < dolfinx/la/Vector.h>
2021#include < initializer_list>
22+ #include < numeric>
2123#include < ranges>
2224#include < stdexcept>
2325#include < vector>
@@ -49,25 +51,101 @@ namespace
// Variable template whose value is `false` for every template argument list.
// NOTE(review): its use is not visible in this hunk; presumably it makes a
// static_assert depend on template parameters -- confirm against the full file.
4951template <typename ...>
5052constexpr bool always_false_v = false ;
5153
54+ // Expand MatrixCSR block column indices to flattened column indices.
// Diff view: lines marked '-' are the removed scalar-only implementation,
// lines marked '+' are the block-aware replacement; unmarked lines are context.
//
// Returns the global column index of every scalar entry stored in the owned
// rows of A. Local block column indices from A.cols() are translated through
// A.index_map(1)->global_indices(). For a block size other than (1, 1) the
// result has nnz_loc_block * bs[0] * bs[1] entries.
5255std::vector<int_t > col_indices (const auto & A)
5356{
54- // Local number of non-zeros
55- std::int32_t m_loc = A.num_owned_rows ();
56- std::int64_t nnz_loc = A.row_ptr ().at (m_loc);
57-
57+ std::array<int , 2 > bs = A.block_size ();
58+ std::int32_t m_loc_block = A.num_owned_rows ();
59+ std::int64_t nnz_loc_block = A.row_ptr ().at (m_loc_block);
5860 std::vector global_indices (A.index_map (1 )->global_indices ());
59- std::vector<int_t > col_indices (nnz_loc);
60- std::transform (A.cols ().begin (), std::next (A.cols ().begin (), nnz_loc),
61- col_indices.begin (), [&global_indices](auto idx) -> int_t
62- { return global_indices[idx]; });
61+
// Scalar case: one output index per stored entry, no expansion needed.
62+ if (bs[0 ] == 1 and bs[1 ] == 1 )
63+ {
64+ std::vector<int_t > col_indices (nnz_loc_block);
65+ std::transform (A.cols ().begin (), std::next (A.cols ().begin (), nnz_loc_block),
66+ col_indices.begin (), [&global_indices](auto idx) -> int_t
67+ { return global_indices[idx]; });
68+ return col_indices;
69+ }
70+
// Blocked case: each block row i yields bs[0] scalar rows; within each scalar
// row every stored block contributes bs[1] consecutive columns starting at
// (global block column) * bs[1]. `pos` walks the output in that order.
71+ std::vector<int_t > col_indices (nnz_loc_block * bs[0 ] * bs[1 ]);
72+ const auto & A_cols = A.cols ();
73+ const auto & A_rowptr = A.row_ptr ();
74+ std::int64_t pos = 0 ;
75+ for (std::int32_t i = 0 ; i < m_loc_block; ++i)
76+ {
77+ for (int i0 = 0 ; i0 < bs[0 ]; ++i0)
78+ {
79+ for (std::int64_t j = A_rowptr[i]; j < A_rowptr[i + 1 ]; ++j)
80+ {
81+ int_t col_block = global_indices[A_cols[j]];
82+ for (int i1 = 0 ; i1 < bs[1 ]; ++i1)
83+ col_indices[pos++] = col_block * bs[1 ] + i1;
84+ }
85+ }
86+ }
6387 return col_indices;
6488}
6589// ----------------------------------------------------------------------------
90+ // Expand MatrixCSR block row pointer to flattened row pointer.
// Diff view: lines marked '-' are the removed scalar-only implementation,
// lines marked '+' are the block-aware replacement; unmarked lines are context.
//
// Returns the CSR row pointer for the owned rows of A after block expansion.
// For block size (1, 1) this is a copy of the first num_owned_rows() + 1
// entries of A.row_ptr(); otherwise it has num_owned_rows() * bs[0] + 1
// entries.
6691std::vector<int_t > row_indices (const auto & A)
6792{
68- return std::vector<int_t >(
69- A.row_ptr ().begin (),
70- std::next (A.row_ptr ().begin (), A.num_owned_rows () + 1 ));
93+ std::array<int , 2 > bs = A.block_size ();
94+ std::int32_t m_loc_block = A.num_owned_rows ();
95+ const auto & A_rowptr = A.row_ptr ();
96+
97+ if (bs[0 ] == 1 and bs[1 ] == 1 )
98+ {
99+ return std::vector<int_t >(A_rowptr.begin (),
100+ std::next (A_rowptr.begin (), m_loc_block + 1 ));
101+ }
102+
103+ // Write the per-scalar-row entry counts into `flattened_rowptr[1:]`, with
104+ // each block-row contributing `bs[0]` copies.
// Entry 0 is left at its value-initialized 0, so the scan below turns the
// counts into offsets.
105+ std::vector<int_t > flattened_rowptr (m_loc_block * bs[0 ] + 1 );
106+ for (std::int32_t i = 0 ; i < m_loc_block; ++i)
107+ {
// Scalar entries per flattened row: (blocks in block row i) * bs[1].
108+ int_t delta = (A_rowptr[i + 1 ] - A_rowptr[i]) * bs[1 ];
109+ std::fill_n (std::next (flattened_rowptr.begin (), 1 + i * bs[0 ]), bs[0 ],
110+ delta);
111+ }
// In-place prefix sum over [1:], converting counts to cumulative offsets.
112+ std::inclusive_scan (std::next (flattened_rowptr.begin ()),
113+ flattened_rowptr.end (),
114+ std::next (flattened_rowptr.begin ()));
115+ return flattened_rowptr;
116+ }
117+ // ----------------------------------------------------------------------------
118+ // Expand MatrixCSR block values to flattened CSR layout.
// Diff view: every line marked '+' is newly added; the final closing brace is
// a context line.
//
// Returns the values of the owned rows of A reordered so that each scalar row
// is contiguous. The loop nest (block row, row-in-block i0, stored block j,
// column-in-block i1) is the same one col_indices uses, so values and column
// indices line up entry by entry.
119+ template <typename T>
120+ std::vector<T> matrix_values (const MatrixCSR<T>& A)
121+ {
122+ std::array<int , 2 > bs = A.block_size ();
123+ std::int32_t m_loc_block = A.num_owned_rows ();
124+ std::int64_t nnz_loc_block = A.row_ptr ().at (m_loc_block);
125+
// Scalar case: the owned values are already in flat CSR order; copy them.
126+ if (bs[0 ] == 1 and bs[1 ] == 1 )
127+ {
128+ return std::vector<T>(A.values ().begin (),
129+ std::next (A.values ().begin (), nnz_loc_block));
130+ }
131+
132+ std::vector<T> flattened_values (nnz_loc_block * bs[0 ] * bs[1 ]);
133+ const auto & A_values = A.values ();
134+ const auto & A_rowptr = A.row_ptr ();
135+ std::int64_t pos = 0 ;
136+ for (std::int32_t i = 0 ; i < m_loc_block; ++i)
137+ {
138+ for (int i0 = 0 ; i0 < bs[0 ]; ++i0)
139+ {
140+ for (std::int64_t j = A_rowptr[i]; j < A_rowptr[i + 1 ]; ++j)
141+ {
142+ for (int i1 = 0 ; i1 < bs[1 ]; ++i1)
// The index reads block j as bs[0] * bs[1] contiguous values, row-major
// within the block -- assumes this matches MatrixCSR storage; TODO confirm.
143+ flattened_values[pos++]
144+ = A_values[j * bs[0 ] * bs[1 ] + i0 * bs[1 ] + i1];
145+ }
146+ }
147+ }
148+ return flattened_values;
71149}
72150// ----------------------------------------------------------------------------
73151template <typename T>
@@ -78,17 +156,18 @@ create_supermatrix(const auto& A, auto& A_mat_values, auto& rowptr, auto& cols)
78156
79157 auto map0 = A.index_map (0 );
80158 auto map1 = A.index_map (1 );
159+ std::array<int , 2 > bs = A.block_size ();
81160
82- // Global size
83- std::int64_t m = map0->size_global ();
84- std::int64_t n = map1->size_global ();
161+ // Global size (scalar, after block expansion)
162+ std::int64_t m = map0->size_global () * bs[ 0 ] ;
163+ std::int64_t n = map1->size_global () * bs[ 1 ] ;
85164 if (m != n)
86165 throw std::runtime_error (" Cannot solve non-square system" );
87166
88- // Number of local rows, first row and local number of non-zeros
89- std::int32_t m_loc = A.num_owned_rows ();
90- std::int64_t first_row = map0->local_range ().front ();
91- std::int64_t nnz_loc = A.row_ptr ().at (m_loc) ;
167+ // Number of local rows, first row and local number of non-zeros.
168+ std::int32_t m_loc = A.num_owned_rows () * bs[ 0 ] ;
169+ std::int64_t first_row = map0->local_range ().front () * bs[ 0 ] ;
170+ std::int64_t nnz_loc = A.row_ptr ().at (A. num_owned_rows ()) * bs[ 0 ] * bs[ 1 ] ;
92171
93172 // Check values fit into upper range of int_t.
94173 auto check = [](std::int64_t x)
@@ -137,7 +216,7 @@ create_supermatrix(const auto& A, auto& A_mat_values, auto& rowptr, auto& cols)
137216// ----------------------------------------------------------------------------
138217template <typename T>
139218SuperLUDistMatrix<T>::SuperLUDistMatrix(const MatrixCSR<T>& A)
140- : _comm(A.comm()), _matA_values(A.values( )),
219+ : _comm(A.comm()), _matA_values(matrix_values(A )),
141220 _cols (std::make_unique<SuperLUDistStructs::vec_int_t >(col_indices(A))),
142221 _rowptr(std::make_unique<SuperLUDistStructs::vec_int_t >(row_indices(A))),
143222 _supermatrix(create_supermatrix<T>(A, _matA_values, *_rowptr, *_cols))
0 commit comments