Skip to content

Commit 8e67252

Browse files
authored
Add block size support to SuperLU_DIST matrix (#4252)
* Add support for block size - no tests yet. * Reduce comments. * Add vector Laplacian test * Add test for non-equal block sizes * Reduce comment
1 parent 4de489c commit 8e67252

2 files changed

Lines changed: 204 additions & 21 deletions

File tree

cpp/dolfinx/la/superlu_dist.cpp

Lines changed: 98 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@ extern "C"
1414
#include <superlu_zdefs.h>
1515
}
1616
#include <algorithm>
17+
#include <array>
1718
#include <dolfinx/common/Timer.h>
1819
#include <dolfinx/la/MatrixCSR.h>
1920
#include <dolfinx/la/Vector.h>
2021
#include <initializer_list>
22+
#include <numeric>
2123
#include <ranges>
2224
#include <stdexcept>
2325
#include <vector>
@@ -49,25 +51,101 @@ namespace
4951
template <typename...>
5052
constexpr bool always_false_v = false;
5153

54+
// Expand MatrixCSR block column indices to flattened column indices.
5255
std::vector<int_t> col_indices(const auto& A)
5356
{
54-
// Local number of non-zeros
55-
std::int32_t m_loc = A.num_owned_rows();
56-
std::int64_t nnz_loc = A.row_ptr().at(m_loc);
57-
57+
std::array<int, 2> bs = A.block_size();
58+
std::int32_t m_loc_block = A.num_owned_rows();
59+
std::int64_t nnz_loc_block = A.row_ptr().at(m_loc_block);
5860
std::vector global_indices(A.index_map(1)->global_indices());
59-
std::vector<int_t> col_indices(nnz_loc);
60-
std::transform(A.cols().begin(), std::next(A.cols().begin(), nnz_loc),
61-
col_indices.begin(), [&global_indices](auto idx) -> int_t
62-
{ return global_indices[idx]; });
61+
62+
if (bs[0] == 1 and bs[1] == 1)
63+
{
64+
std::vector<int_t> col_indices(nnz_loc_block);
65+
std::transform(A.cols().begin(), std::next(A.cols().begin(), nnz_loc_block),
66+
col_indices.begin(), [&global_indices](auto idx) -> int_t
67+
{ return global_indices[idx]; });
68+
return col_indices;
69+
}
70+
71+
std::vector<int_t> col_indices(nnz_loc_block * bs[0] * bs[1]);
72+
const auto& A_cols = A.cols();
73+
const auto& A_rowptr = A.row_ptr();
74+
std::int64_t pos = 0;
75+
for (std::int32_t i = 0; i < m_loc_block; ++i)
76+
{
77+
for (int i0 = 0; i0 < bs[0]; ++i0)
78+
{
79+
for (std::int64_t j = A_rowptr[i]; j < A_rowptr[i + 1]; ++j)
80+
{
81+
int_t col_block = global_indices[A_cols[j]];
82+
for (int i1 = 0; i1 < bs[1]; ++i1)
83+
col_indices[pos++] = col_block * bs[1] + i1;
84+
}
85+
}
86+
}
6387
return col_indices;
6488
}
6589
//----------------------------------------------------------------------------
90+
// Expand MatrixCSR block row pointer to flattened row pointer.
6691
std::vector<int_t> row_indices(const auto& A)
6792
{
68-
return std::vector<int_t>(
69-
A.row_ptr().begin(),
70-
std::next(A.row_ptr().begin(), A.num_owned_rows() + 1));
93+
std::array<int, 2> bs = A.block_size();
94+
std::int32_t m_loc_block = A.num_owned_rows();
95+
const auto& A_rowptr = A.row_ptr();
96+
97+
if (bs[0] == 1 and bs[1] == 1)
98+
{
99+
return std::vector<int_t>(A_rowptr.begin(),
100+
std::next(A_rowptr.begin(), m_loc_block + 1));
101+
}
102+
103+
// Write the per-scalar-row entry counts into `flattened_rowptr[1:]`, with
104+
// each block-row contributing `bs[0]` copies.
105+
std::vector<int_t> flattened_rowptr(m_loc_block * bs[0] + 1);
106+
for (std::int32_t i = 0; i < m_loc_block; ++i)
107+
{
108+
int_t delta = (A_rowptr[i + 1] - A_rowptr[i]) * bs[1];
109+
std::fill_n(std::next(flattened_rowptr.begin(), 1 + i * bs[0]), bs[0],
110+
delta);
111+
}
112+
std::inclusive_scan(std::next(flattened_rowptr.begin()),
113+
flattened_rowptr.end(),
114+
std::next(flattened_rowptr.begin()));
115+
return flattened_rowptr;
116+
}
117+
//----------------------------------------------------------------------------
118+
// Expand MatrixCSR block values to flattened CSR layout.
119+
template <typename T>
120+
std::vector<T> matrix_values(const MatrixCSR<T>& A)
121+
{
122+
std::array<int, 2> bs = A.block_size();
123+
std::int32_t m_loc_block = A.num_owned_rows();
124+
std::int64_t nnz_loc_block = A.row_ptr().at(m_loc_block);
125+
126+
if (bs[0] == 1 and bs[1] == 1)
127+
{
128+
return std::vector<T>(A.values().begin(),
129+
std::next(A.values().begin(), nnz_loc_block));
130+
}
131+
132+
std::vector<T> flattened_values(nnz_loc_block * bs[0] * bs[1]);
133+
const auto& A_values = A.values();
134+
const auto& A_rowptr = A.row_ptr();
135+
std::int64_t pos = 0;
136+
for (std::int32_t i = 0; i < m_loc_block; ++i)
137+
{
138+
for (int i0 = 0; i0 < bs[0]; ++i0)
139+
{
140+
for (std::int64_t j = A_rowptr[i]; j < A_rowptr[i + 1]; ++j)
141+
{
142+
for (int i1 = 0; i1 < bs[1]; ++i1)
143+
flattened_values[pos++]
144+
= A_values[j * bs[0] * bs[1] + i0 * bs[1] + i1];
145+
}
146+
}
147+
}
148+
return flattened_values;
71149
}
72150
//----------------------------------------------------------------------------
73151
template <typename T>
@@ -78,17 +156,18 @@ create_supermatrix(const auto& A, auto& A_mat_values, auto& rowptr, auto& cols)
78156

79157
auto map0 = A.index_map(0);
80158
auto map1 = A.index_map(1);
159+
std::array<int, 2> bs = A.block_size();
81160

82-
// Global size
83-
std::int64_t m = map0->size_global();
84-
std::int64_t n = map1->size_global();
161+
// Global size (scalar, after block expansion)
162+
std::int64_t m = map0->size_global() * bs[0];
163+
std::int64_t n = map1->size_global() * bs[1];
85164
if (m != n)
86165
throw std::runtime_error("Cannot solve non-square system");
87166

88-
// Number of local rows, first row and local number of non-zeros
89-
std::int32_t m_loc = A.num_owned_rows();
90-
std::int64_t first_row = map0->local_range().front();
91-
std::int64_t nnz_loc = A.row_ptr().at(m_loc);
167+
// Number of local rows, first row and local number of non-zeros.
168+
std::int32_t m_loc = A.num_owned_rows() * bs[0];
169+
std::int64_t first_row = map0->local_range().front() * bs[0];
170+
std::int64_t nnz_loc = A.row_ptr().at(A.num_owned_rows()) * bs[0] * bs[1];
92171

93172
// Check values fit into upper range of int_t.
94173
auto check = [](std::int64_t x)
@@ -137,7 +216,7 @@ create_supermatrix(const auto& A, auto& A_mat_values, auto& rowptr, auto& cols)
137216
//----------------------------------------------------------------------------
138217
template <typename T>
139218
SuperLUDistMatrix<T>::SuperLUDistMatrix(const MatrixCSR<T>& A)
140-
: _comm(A.comm()), _matA_values(A.values()),
219+
: _comm(A.comm()), _matA_values(matrix_values(A)),
141220
_cols(std::make_unique<SuperLUDistStructs::vec_int_t>(col_indices(A))),
142221
_rowptr(std::make_unique<SuperLUDistStructs::vec_int_t>(row_indices(A))),
143222
_supermatrix(create_supermatrix<T>(A, _matA_values, *_rowptr, *_cols))

python/test/unit/la/test_superlu_dist.py

Lines changed: 106 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import pytest
1212

1313
import dolfinx
14+
from dolfinx.common import IndexMap
15+
from dolfinx.cpp.la import SparsityPattern
1416
from dolfinx.fem import (
1517
Function,
1618
apply_lifting,
@@ -22,9 +24,9 @@
2224
functionspace,
2325
locate_dofs_topological,
2426
)
25-
from dolfinx.la import InsertMode
27+
from dolfinx.la import InsertMode, matrix_csr, vector
2628
from dolfinx.mesh import create_unit_square, exterior_facet_indices
27-
from ufl import SpatialCoordinate, TestFunction, TrialFunction, div, dx, grad, inner
29+
from ufl import SpatialCoordinate, TestFunction, TrialFunction, as_vector, div, dx, grad, inner
2830

2931

3032
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.complex128])
@@ -130,3 +132,105 @@ def solve_and_check(solver, b):
130132
solver_2.set_option("Fact", "SamePattern")
131133
uh_2 = solve_and_check(solver_2, b_2)
132134
check_error(u_ex, uh_2)
135+
136+
137+
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.complex128])
@pytest.mark.skipif(not dolfinx.has_superlu_dist, reason="No SuperLU_DIST")
def test_superlu_solver_blocked(dtype):
    """Solve a two-component vector Poisson problem with SuperLU_DIST.

    The operator is assembled on a degree-3 vector Lagrange space, so the
    assembled matrix has block size [2, 2]. The computed solution is
    compared against the manufactured solution (y^3, x^3).
    """
    from dolfinx.la.superlu_dist import superlu_dist_matrix, superlu_dist_solver

    # Geometry uses the real scalar type that matches `dtype`
    real_dtype = dtype().real.dtype
    mesh = create_unit_square(MPI.COMM_WORLD, 5, 5, dtype=real_dtype)
    V = functionspace(mesh, ("Lagrange", 3, (2,)))
    trial, test = TrialFunction(V), TestFunction(V)

    bilinear = form(inner(grad(trial), grad(test)) * dx, dtype=dtype)

    def exact_solution(points):
        return np.vstack((points[1] ** 3, points[0] ** 3))

    coords = SpatialCoordinate(mesh)
    exact_ufl = as_vector((coords[1] ** 3, coords[0] ** 3))
    source = -div(grad(exact_ufl))
    linear = form(inner(source, test) * dx, dtype=dtype)

    # Dirichlet data on the whole boundary, taken from the exact solution
    g = Function(V, dtype=dtype)
    g.interpolate(exact_solution)

    tdim = mesh.topology.dim
    fdim = tdim - 1
    mesh.topology.create_connectivity(fdim, tdim)
    boundary_facets = exterior_facet_indices(mesh.topology)
    boundary_dofs = locate_dofs_topological(V, fdim, boundary_facets)
    bc = dirichletbc(g, boundary_dofs)

    # Right-hand side with lifting applied
    b = assemble_vector(linear)
    apply_lifting(b.array, [bilinear], bcs=[[bc]])
    b.scatter_reverse(InsertMode.add)
    bc.set(b.array)

    # Blocked operator
    A = assemble_matrix(bilinear, bcs=[bc])
    A.scatter_reverse()
    assert A.block_size == [2, 2]

    solver = superlu_dist_solver(superlu_dist_matrix(A))
    solver.set_option("SymmetricMode", "YES")
    uh = Function(V, dtype=dtype)
    status = solver.solve(b, uh.x)
    assert status == 0
    uh.x.scatter_forward()

    # Squared L2 error, summed over all ranks
    diff = exact_ufl - uh
    error_form = form(inner(diff, diff) * dx, dtype=dtype)
    error = mesh.comm.allreduce(assemble_scalar(error_form), op=MPI.SUM)
    tol = np.sqrt(np.finfo(dtype).eps)
    assert np.isclose(error, 0.0, atol=tol)
188+
189+
190+
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.complex128])
@pytest.mark.skipif(not dolfinx.has_superlu_dist, reason="No SuperLU_DIST")
@pytest.mark.skipif(MPI.COMM_WORLD.size > 1, reason="Hand-built single-rank matrix")
def test_superlu_solver_asymmetric_blocks(dtype):
    """Solve with a hand-built MatrixCSR whose row and column block sizes differ.

    Three block rows of size 2 and two block columns of size 3 give a
    6 x 6 scalar system. The result is checked against a dense NumPy solve.
    """
    from dolfinx.la.superlu_dist import superlu_dist_matrix, superlu_dist_solver

    bs0, bs1 = 2, 3
    n_row_blocks, n_col_blocks = 3, 2

    comm = MPI.COMM_WORLD
    im_row = IndexMap(comm, n_row_blocks)
    im_col = IndexMap(comm, n_col_blocks)

    # Fully populated block sparsity pattern
    sp = SparsityPattern(comm, [im_row, im_col], [bs0, bs1])
    for row_block in range(n_row_blocks):
        for col_block in range(n_col_blocks):
            sp.insert(row_block, col_block)
    sp.finalize()

    A = matrix_csr(sp, dtype=dtype)
    assert A.block_size == [bs0, bs1]

    rng = np.random.default_rng(0)
    A_dense = (np.eye(6) * 10.0 + rng.standard_normal((6, 6))).astype(dtype)

    # Reorder the dense matrix into block storage: entry (i0, i1) of block
    # (i, j) goes to offset (i * n_col_blocks + j) * bs0 * bs1 + i0 * bs1 + i1.
    blocks = A_dense.reshape(n_row_blocks, bs0, n_col_blocks, bs1).transpose(0, 2, 1, 3)
    A.data[: blocks.size] = blocks.ravel()

    b_np = rng.standard_normal(6).astype(dtype)
    x_expected = np.linalg.solve(A_dense, b_np)

    b = vector(im_row, bs=bs0, dtype=dtype)
    b.array[:] = b_np
    u = vector(im_col, bs=bs1, dtype=dtype)

    solver = superlu_dist_solver(superlu_dist_matrix(A))
    status = solver.solve(b, u)
    assert status == 0
    u.scatter_forward()

    assert np.allclose(u.array, x_expected)

0 commit comments

Comments
 (0)