Skip to content

Commit ec67bd3

Browse files
authored
Introduce threading for collision detection (#4123)
* compute distances for multiple bodies towards single convex hull * Fix tolerance in GJK algorithm and remove looping in tests * Try to easen tolerance a bit Co-authored-by: Jørgen Schartum Dokken <dokken92@gmail.com> * Fix tolerance and simplify chunking * Stricter tolerance again * Add new termination criterion * Use <=
1 parent 428e0f0 commit ec67bd3

5 files changed

Lines changed: 179 additions & 44 deletions

File tree

cpp/dolfinx/geometry/gjk.h

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ std::array<T, 3> compute_distance_gjk(std::span<const T> p0,
318318
constexpr int maxk = 15; // Maximum number of iterations of the GJK algorithm
319319

320320
// Tolerance
321-
const U eps = 1.0e4 * std::numeric_limits<T>::epsilon();
321+
const U eps = 1000 * std::numeric_limits<U>::epsilon();
322322

323323
// Initialise vector and simplex
324324
std::array<U, 3> v = {p[0] - q[0], p[1] - q[1], p[2] - q[2]};
@@ -382,16 +382,75 @@ std::array<T, 3> compute_distance_gjk(std::span<const T> p0,
382382
SPDLOG_DEBUG("new s size={}", 3 * j);
383383
s.resize(3 * j);
384384

385+
// 2nd exit condition - strict monotonicity
386+
// Floating point can cause the algorithm to stagnate. Then we terminate.
385387
const U vn = impl_gjk::dot3(v, v);
386-
// 2nd exit condition - intersecting or touching
388+
if (vnorm2 <= vn)
389+
break;
390+
391+
// 3rd exit condition - intersecting or touching
387392
if (vn < eps * eps)
388393
break;
389394
}
390395

391396
if (k == maxk)
392397
throw std::runtime_error("GJK error - max iteration limit reached");
393-
394398
return {static_cast<T>(v[0]), static_cast<T>(v[1]), static_cast<T>(v[2])};
395399
}
396400

401+
/// @brief Compute the distance between a sequence of convex bodies `p0, ...,
402+
/// pN` and `q`, each defined by a set of points.
403+
///
404+
/// Uses the Gilbert–Johnson–Keerthi (GJK) distance algorithm.
405+
///
406+
/// @param[in] bodies List of the list of points that make up each of N bodies
407+
/// considered as body 1. `shape=(num_bodies, (num_points_body_j, 3))`. Row-major
408+
/// storage.
409+
/// @param[in] q Body 2 list of points, `shape=(num_points, 3)`. Row-major
410+
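/// storage.
/// @param[in] num_threads Number of threads to use. Clamped to the number of
/// bodies; if the result is 0 or 1 all bodies are processed on the calling
/// thread. With more than one thread, an exception thrown by
/// `compute_distance_gjk` inside a worker thread cannot be caught by the
/// caller (`std::terminate` is called).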
411+
/// @tparam T Floating point type
412+
/// @tparam U Floating point type used for geometry computations internally,
413+
/// which should be higher precision than T, to maintain accuracy.
414+
/// @return For each body `p_j` in `bodies`, the shortest distance vector to
415+
/// body 2. Shape (num_bodies, 3), flattened row-major.
416+
template <std::floating_point T,
417+
typename U = boost::multiprecision::cpp_bin_float_double_extended>
418+
std::vector<T>
419+
compute_distances_gjk(const std::vector<std::span<const T>>& bodies,
420+
std::span<const T> q, size_t num_threads)
421+
{
422+
size_t total_size = bodies.size();
423+
num_threads = std::min(num_threads, total_size);
424+
425+
std::vector<T> results(total_size * 3);
426+
auto compute_chunk
427+
= [&results, &bodies](size_t c0, size_t c1, std::span<const T> q_ref)
428+
{
429+
for (size_t i = c0; i < c1; ++i)
430+
{
431+
// Using U explicitly as the internal precision type
432+
std::array<T, 3> dist = compute_distance_gjk<T, U>(bodies[i], q_ref);
433+
results[3 * i + 0] = dist[0];
434+
results[3 * i + 1] = dist[1];
435+
results[3 * i + 2] = dist[2];
436+
}
437+
};
438+
439+
if (num_threads <= 1)
440+
{
441+
compute_chunk(0, total_size, q);
442+
}
443+
else
444+
{
445+
std::vector<std::jthread> threads(num_threads);
446+
for (size_t i = 0; i < num_threads; ++i)
447+
{
448+
auto [c0, c1] = dolfinx::MPI::local_range(i, total_size, num_threads);
449+
threads[i] = std::jthread(compute_chunk, c0, c1, q);
450+
}
451+
}
452+
453+
return results;
454+
}
455+
397456
} // namespace dolfinx::geometry

cpp/dolfinx/geometry/utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,7 @@ determine_point_ownership(const mesh::Mesh<T>& mesh, std::span<const T> points,
872872
std::array<T, 3> point;
873873
std::copy_n(std::next(received_points.begin(), 3 * i), 3, point.begin());
874874

875-
// Find shortest distance among cells with colldiing bounding box
875+
// Find shortest distance among cells with colliding bounding box
876876
T shortest_distance = std::numeric_limits<T>::max();
877877
std::int32_t closest_cell = -1;
878878
for (auto cell : candidate_collisions.links(i))

python/dolfinx/geometry.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
"compute_collisions_points",
2929
"compute_collisions_trees",
3030
"compute_distance_gjk",
31+
"compute_distances_gjk",
3132
"create_midpoint_tree",
3233
"determine_point_ownership",
3334
"squared_distance",
@@ -290,6 +291,34 @@ def compute_distance_gjk(
290291
raise RuntimeError("Invalid dtype in compute_distance_gjk")
291292

292293

294+
def compute_distances_gjk(
295+
bodies: list[npt.NDArray[np.floating]], q: npt.NDArray[np.floating], num_threads: int
296+
) -> npt.NDArray[np.floating]:
297+
"""Compute the distance from each of a set of convex bodies to one body.
298+
299+
For each convex body defined in `bodies`
300+
(a set of 3D points for each body) find the shortest distance vector
301+
to the body `q` defined by another set of 3D points.
302+
The method uses the
303+
Gilbert-Johnson-Keerthi (GJK) distance algorithm.
304+
305+
Args:
306+
bodies: List of bodies, where each body is an array of
307+
points (``shape=(num_points_i, 3)``).
308+
q: Body 2 list of points (``shape=(num_points_2, 3)``).
309+
num_threads: Number of threads to use for GJK computation.
310+
311+
Returns:
312+
Shortest distance vector from each body to ``q`` (``shape=(num_bodies, 3)``).
313+
"""
314+
assert all([p.dtype == q.dtype for p in bodies])
315+
if np.issubdtype(q.dtype, np.float32):
316+
return _cpp.geometry.compute_distances_gjk_float32(bodies, q, num_threads)
317+
elif np.issubdtype(q.dtype, np.float64):
318+
return _cpp.geometry.compute_distances_gjk_float64(bodies, q, num_threads)
319+
raise RuntimeError("Invalid dtype in compute_distances_gjk")
320+
321+
293322
def determine_point_ownership(
294323
mesh: Mesh,
295324
points: npt.NDArray[np.floating],

python/dolfinx/wrappers/dolfinx_wrappers/geometry.h

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (C) 2017-2025 Chris N. Richardson and Garth N. Wells
1+
// Copyright (C) 2017-2026 Chris N. Richardson and Garth N. Wells
22
//
33
// This file is part of DOLFINx (https://www.fenicsproject.org)
44
//
@@ -18,6 +18,7 @@
1818
#include <nanobind/nanobind.h>
1919
#include <nanobind/ndarray.h>
2020
#include <nanobind/stl/optional.h>
21+
#include <nanobind/stl/vector.h>
2122
#include <optional>
2223
#include <span>
2324
#include <string>
@@ -196,6 +197,40 @@ void declare_bbtree(nb::module_& m, const std::string& type)
196197
},
197198
nb::arg("p"), nb::arg("q"));
198199

200+
std::string gjks_name = "compute_distances_gjk_" + type;
201+
m.def(
202+
gjks_name.c_str(),
203+
[](const std::vector<nb::ndarray<const T, nb::c_contig>>& bodies,
204+
nb::ndarray<const T, nb::c_contig> q, size_t num_threads)
205+
{
206+
// If array is 1D assume single point
207+
std::size_t q_s0 = q.ndim() == 1 ? 1 : q.shape(0);
208+
std::span<const T> _q(q.data(), 3 * q_s0);
209+
210+
std::vector<std::span<const T>> _bodies;
211+
_bodies.reserve(bodies.size());
212+
213+
std::ranges::transform(
214+
bodies, std::back_inserter(_bodies),
215+
[](auto& body)
216+
{
217+
// If sub-array is 1D assume single point
218+
std::size_t body_s0 = body.ndim() == 1 ? 1 : body.shape(0);
219+
return std::span<const T>(body.data(), 3 * body_s0);
220+
});
221+
222+
using U = typename std::conditional<
223+
std::is_same_v<T, float>, double,
224+
boost::multiprecision::cpp_bin_float_double_extended>::type;
225+
226+
std::vector<T> distances
227+
= dolfinx::geometry::compute_distances_gjk<T, U>(_bodies, _q,
228+
num_threads);
229+
return dolfinx_wrappers::as_nbarray(std::move(distances),
230+
{distances.size() / 3, 3});
231+
},
232+
nb::arg("bodies"), nb::arg("q"), nb::arg("num_threads"));
233+
199234
m.def(
200235
"squared_distance",
201236
[](const dolfinx::mesh::Mesh<T>& mesh, int dim,

python/test/unit/geometry/test_bounding_box_tree.py

Lines changed: 51 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (C) 2013-2025 Anders Logg, Jørgen S. Dokken, Chris Richardson
1+
# Copyright (C) 2013-2026 Anders Logg, Jørgen S. Dokken, Chris Richardson
22
#
33
# This file is part of DOLFINx (https://www.fenicsproject.org)
44
#
@@ -17,6 +17,7 @@
1717
compute_collisions_points,
1818
compute_collisions_trees,
1919
compute_distance_gjk,
20+
compute_distances_gjk,
2021
create_midpoint_tree,
2122
determine_point_ownership,
2223
)
@@ -50,47 +51,47 @@ def extract_geometricial_data(mesh, dim, entities):
5051
return mesh_nodes
5152

5253

53-
def expand_bbox(bbox, dtype):
54-
"""Expand min max bbox to convex hull."""
55-
return np.array(
56-
[
57-
[bbox[0][0], bbox[0][1], bbox[0][2]],
58-
[bbox[0][0], bbox[0][1], bbox[1][2]],
59-
[bbox[0][0], bbox[1][1], bbox[0][2]],
60-
[bbox[1][0], bbox[0][1], bbox[0][2]],
61-
[bbox[1][0], bbox[0][1], bbox[1][2]],
62-
[bbox[1][0], bbox[1][1], bbox[0][2]],
63-
[bbox[0][0], bbox[1][1], bbox[1][2]],
64-
[bbox[1][0], bbox[1][1], bbox[1][2]],
65-
],
66-
dtype=dtype,
67-
)
54+
def expand_bboxes(bboxes, dtype):
55+
"""Expand an array of min/max bboxes to convex hulls."""
56+
if len(bboxes.shape) == 2:
57+
bboxes = bboxes.reshape(1, *bboxes.shape)
58+
idx_x = [0, 0, 0, 1, 1, 1, 0, 1]
59+
idx_y = [0, 0, 1, 0, 0, 1, 1, 1]
60+
idx_z = [0, 1, 0, 0, 1, 0, 1, 1]
61+
62+
x = bboxes[:, idx_x, 0]
63+
y = bboxes[:, idx_y, 1]
64+
z = bboxes[:, idx_z, 2]
65+
return np.stack((x, y, z), axis=-1).astype(dtype)
6866

6967

70-
def find_colliding_cells(mesh, bbox, dtype):
68+
def find_colliding_cells(mesh, bbox, dtype, num_threads):
7169
"""Given a mesh and a bounding box((xmin, ymin, zmin), (xmax, ymax,
7270
zmax)) find all colliding cells.
7371
"""
7472
# Find actual cells using known bounding box tree
75-
colliding_cells = []
7673
num_cells = mesh.topology.index_map(mesh.topology.dim).size_local
7774
x_indices = entities_to_geometry(
7875
mesh, mesh.topology.dim, np.arange(num_cells, dtype=np.int32), False
7976
)
8077
points = mesh.geometry.x
81-
bounding_box = expand_bbox(bbox, dtype)
82-
for cell in range(num_cells):
83-
vertex_coords = points[x_indices[cell]]
84-
bbox_cell = np.array([vertex_coords[0], vertex_coords[0]])
85-
# Create bounding box for cell
86-
for i in range(1, vertex_coords.shape[0]):
87-
for j in range(3):
88-
bbox_cell[0, j] = min(bbox_cell[0, j], vertex_coords[i, j])
89-
bbox_cell[1, j] = max(bbox_cell[1, j], vertex_coords[i, j])
90-
distance = compute_distance_gjk(expand_bbox(bbox_cell, dtype), bounding_box)
91-
if np.dot(distance, distance) < 1e-16:
92-
colliding_cells.append(cell)
93-
78+
bounding_box = expand_bboxes(bbox, dtype)[0]
79+
80+
# Pack the data for each of the cell bounding boxes
81+
# First pack min and max values for each cell
82+
min_in_cell = np.min(points[x_indices], axis=1)
83+
max_in_cell = np.max(points[x_indices], axis=1)
84+
bboxes = np.zeros((num_cells, 2, 3))
85+
bboxes[:, 0, :] = min_in_cell
86+
bboxes[:, 1, :] = max_in_cell
87+
# Expand min and max values to bounding box
88+
body_1 = list(expand_bboxes(bboxes, dtype))
89+
90+
# Compute distances and check for collision
91+
distances = compute_distances_gjk(body_1, bounding_box, num_threads)
92+
norm_dist = np.linalg.norm(distances, axis=1) ** 2
93+
tol = 10 * np.finfo(dtype).eps
94+
colliding_cells = np.flatnonzero(norm_dist < tol).astype(np.int32)
9495
return colliding_cells
9596

9697

@@ -224,9 +225,10 @@ def locator_B(x):
224225

225226

226227
@pytest.mark.skip_in_parallel
228+
@pytest.mark.parametrize("num_threads", [1, 2])
227229
@pytest.mark.parametrize("point", [np.array([0.52, 0.51, 0.0]), np.array([0.9, -0.9, 0.0])])
228230
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
229-
def test_compute_collisions_tree_2d(point, dtype):
231+
def test_compute_collisions_tree_2d(point, dtype, num_threads):
230232
mesh_A = create_unit_square(MPI.COMM_WORLD, 3, 3, dtype=dtype)
231233
mesh_B = create_unit_square(MPI.COMM_WORLD, 5, 5, dtype=dtype)
232234
bgeom = mesh_B.geometry.x
@@ -237,18 +239,24 @@ def test_compute_collisions_tree_2d(point, dtype):
237239

238240
entities_A = np.sort(np.unique([q[0] for q in entities]))
239241
entities_B = np.sort(np.unique([q[1] for q in entities]))
240-
cells_A = find_colliding_cells(mesh_A, tree_B.get_bbox(tree_B.num_bboxes - 1), dtype)
241-
cells_B = find_colliding_cells(mesh_B, tree_A.get_bbox(tree_A.num_bboxes - 1), dtype)
242+
cells_A = find_colliding_cells(
243+
mesh_A, tree_B.get_bbox(tree_B.num_bboxes - 1), dtype, num_threads
244+
)
245+
cells_B = find_colliding_cells(
246+
mesh_B, tree_A.get_bbox(tree_A.num_bboxes - 1), dtype, num_threads
247+
)
242248
assert np.allclose(entities_A, cells_A)
243249
assert np.allclose(entities_B, cells_B)
244250

245251

246252
@pytest.mark.skip_in_parallel
253+
@pytest.mark.parametrize("num_threads", [1, 2])
247254
@pytest.mark.parametrize("point", [np.array([0.52, 0.51, 0.3]), np.array([0.9, -0.9, 0.3])])
248255
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
249-
def test_compute_collisions_tree_3d(point, dtype):
250-
mesh_A = create_unit_cube(MPI.COMM_WORLD, 2, 2, 2, dtype=dtype)
251-
mesh_B = create_unit_cube(MPI.COMM_WORLD, 2, 2, 2, dtype=dtype)
256+
def test_compute_collisions_tree_3d(point, dtype, num_threads):
257+
M = 10
258+
mesh_A = create_unit_cube(MPI.COMM_WORLD, M, M, M, dtype=dtype)
259+
mesh_B = create_unit_cube(MPI.COMM_WORLD, M, M, M, dtype=dtype)
252260

253261
bgeom = mesh_B.geometry.x
254262
bgeom += point
@@ -258,8 +266,12 @@ def test_compute_collisions_tree_3d(point, dtype):
258266
entities = compute_collisions_trees(tree_A, tree_B)
259267
entities_A = np.sort(np.unique([q[0] for q in entities]))
260268
entities_B = np.sort(np.unique([q[1] for q in entities]))
261-
cells_A = find_colliding_cells(mesh_A, tree_B.get_bbox(tree_B.num_bboxes - 1), dtype)
262-
cells_B = find_colliding_cells(mesh_B, tree_A.get_bbox(tree_A.num_bboxes - 1), dtype)
269+
cells_A = find_colliding_cells(
270+
mesh_A, tree_B.get_bbox(tree_B.num_bboxes - 1), dtype, num_threads
271+
)
272+
cells_B = find_colliding_cells(
273+
mesh_B, tree_A.get_bbox(tree_A.num_bboxes - 1), dtype, num_threads
274+
)
263275
assert np.allclose(entities_A, cells_A)
264276
assert np.allclose(entities_B, cells_B)
265277

0 commit comments

Comments
 (0)