Skip to content

Commit 233f8be

Browse files
committed
Update ball query and radius search to support more than batch size 1.
1 parent efb6abb commit 233f8be

11 files changed

Lines changed: 1259 additions & 592 deletions

File tree

physicsnemo/nn/functional/neighbors/radius_search/_torch_impl.py

Lines changed: 130 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,60 @@
2020
from .utils import format_returns
2121

2222

23+
def _validate_inputs(points: torch.Tensor, queries: torch.Tensor):
24+
"""Validate and normalize inputs to (B, N, 3) shape. Returns (points, queries, was_unbatched)."""
25+
if points.ndim == 2 and queries.ndim == 2:
26+
return points.unsqueeze(0), queries.unsqueeze(0), True
27+
elif points.ndim == 3 and queries.ndim == 3:
28+
if points.shape[0] != queries.shape[0]:
29+
raise ValueError(
30+
f"Batch dimensions must match: points has {points.shape[0]}, "
31+
f"queries has {queries.shape[0]}"
32+
)
33+
return points, queries, False
34+
else:
35+
raise ValueError(
36+
f"points and queries must be 2D (N, 3) or 3D (B, N, 3), "
37+
f"got {points.ndim}D and {queries.ndim}D"
38+
)
39+
40+
41+
def _radius_search_dynamic(
42+
points: torch.Tensor,
43+
queries: torch.Tensor,
44+
radius: float,
45+
return_dists: bool,
46+
return_points: bool,
47+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
48+
"""Single-element dynamic radius search on (N, 3) and (Q, 3) tensors.
49+
50+
Finds ALL neighbors within radius (variable-length output).
51+
"""
52+
53+
# Without the compute mode set, this is numerically unstable.
54+
dists = torch.cdist(
55+
points, queries, p=2.0, compute_mode="donot_use_mm_for_euclid_dist"
56+
)
57+
58+
selection = dists <= radius
59+
selected_indices = torch.nonzero(selection, as_tuple=False).t().contiguous()
60+
selected_indices = selected_indices[[1, 0], :]
61+
62+
if return_points:
63+
points = torch.index_select(points, 0, selected_indices[1])
64+
else:
65+
points = torch.empty(
66+
(0, points.shape[1]), device=points.device, dtype=points.dtype
67+
)
68+
69+
if return_dists:
70+
dists = dists[selection]
71+
else:
72+
dists = torch.empty((0,), device=dists.device, dtype=dists.dtype)
73+
74+
return selected_indices, points, dists
75+
76+
2377
def radius_search_impl(
2478
points: torch.Tensor,
2579
queries: torch.Tensor,
@@ -31,81 +85,97 @@ def radius_search_impl(
3185
"""
3286
Pure PyTorch implementation of the radius search.
3387
88+
Accepts both unbatched (N, 3) and batched (B, N, 3) inputs.
3489
This is a brute force implementation that is not memory efficient.
3590
"""
3691

37-
# Without the compute mode set, this is numerically unstable.
92+
points, queries, was_unbatched = _validate_inputs(points, queries)
93+
B = points.shape[0]
94+
95+
if max_points is None:
96+
# Dynamic output: loop over batch, concatenate with batch indices
97+
all_indices = []
98+
all_pts = []
99+
all_dists = []
100+
for b in range(B):
101+
idx_b, pts_b, dists_b = _radius_search_dynamic(
102+
points[b], queries[b], radius, return_dists, return_points
103+
)
104+
# idx_b is (2, count_b); prepend a batch-index row
105+
batch_row = torch.full(
106+
(1, idx_b.shape[1]), b, dtype=idx_b.dtype, device=idx_b.device
107+
)
108+
all_indices.append(torch.cat([batch_row, idx_b], dim=0))
109+
all_pts.append(pts_b)
110+
all_dists.append(dists_b)
111+
112+
selected_indices = torch.cat(all_indices, dim=1) # (3, total_count)
113+
pts_out = torch.cat(all_pts, dim=0) if return_points else all_pts[0]
114+
dists_out = torch.cat(all_dists, dim=0) if return_dists else all_dists[0]
115+
116+
if was_unbatched:
117+
# Strip the batch-index row to restore (2, count) format
118+
selected_indices = selected_indices[1:]
119+
120+
return selected_indices, pts_out, dists_out
121+
122+
# Deterministic output: fully batched via cdist + topk
123+
# dists: (B, N, Q)
38124
dists = torch.cdist(
39125
points, queries, p=2.0, compute_mode="donot_use_mm_for_euclid_dist"
40126
)
41127

42-
if max_points is None:
43-
# Find all points within radius
44-
selection = dists <= radius
45-
selected_indices = torch.nonzero(selection, as_tuple=False).t().contiguous()
46-
selected_indices = selected_indices[[1, 0], :]
128+
# topk along dim=1 (points dim): (B, max_points, Q)
129+
k = min(max_points, dists.shape[1])
130+
values, indices = torch.topk(dists, k=k, dim=1, largest=False)
47131

48-
if return_points:
49-
points = torch.index_select(points, 0, selected_indices[1])
50-
else:
51-
points = torch.empty(
52-
(0, points.shape[1]), device=points.device, dtype=points.dtype
53-
)
132+
# Pad if k < max_points (fewer points than requested)
133+
if k < max_points:
134+
pad_size = max_points - k
135+
values = torch.nn.functional.pad(values, (0, 0, 0, pad_size), value=0.0)
136+
indices = torch.nn.functional.pad(indices, (0, 0, 0, pad_size), value=0)
54137

55-
if return_dists:
56-
dists = dists[selection]
57-
else:
58-
dists = torch.empty((0,), device=dists.device, dtype=dists.dtype)
138+
# Filter to within radius: (B, max_points, Q)
139+
selection = values <= radius
140+
selected_indices = torch.where(selection, indices, 0)
141+
# Transpose to (B, Q, max_points)
142+
selected_indices = selected_indices.permute(0, 2, 1)
59143

144+
if return_dists:
145+
dists_out = torch.where(selection, values, 0).permute(0, 2, 1)
60146
else:
61-
# Take the max_points lowest distances for each query
62-
closest_points = torch.topk(
63-
dists, k=min(max_points, dists.shape[0]), dim=0, largest=False
147+
dists_out = torch.empty(0, dtype=dists.dtype, device=dists.device)
148+
149+
if return_points:
150+
# Gather points for each (batch, query, neighbor) triple
151+
# selection: (B, max_points, Q), indices: (B, max_points, Q)
152+
safe_locs = torch.where(selection)
153+
batch_loc, mp_loc, query_loc = safe_locs
154+
input_point_locs = indices[batch_loc, mp_loc, query_loc]
155+
selected_points = points[batch_loc, input_point_locs]
156+
output_points = torch.zeros(
157+
B,
158+
queries.shape[1],
159+
max_points,
160+
3,
161+
device=queries.device,
162+
dtype=points.dtype,
163+
)
164+
output_points[batch_loc, query_loc, mp_loc] = selected_points
165+
pts_out = output_points
166+
else:
167+
pts_out = torch.empty(
168+
0, max_points, 3, device=points.device, dtype=points.dtype
64169
)
65-
values, indices = closest_points
66-
# Values and indices have shape [max_points, n_queries]
67-
# The first dim of indices represents the index into input points
68-
69-
# Filter to points within radius
70-
selection = values <= radius
71-
selected_indices = torch.where(selection, indices, 0).t()
72170

171+
if was_unbatched:
172+
selected_indices = selected_indices.squeeze(0)
73173
if return_dists:
74-
dists = torch.where(selection, values, 0).t()
75-
else:
76-
dists = torch.empty(
77-
(0, values.shape[1]), device=values.device, dtype=values.dtype
78-
)
79-
174+
dists_out = dists_out.squeeze(0)
80175
if return_points:
81-
# selected_indices: (num_queries, max_points)
82-
# points: (num_points, point_dim)
83-
# We want: selected_points: (num_queries, max_points, point_dim)
84-
85-
safe_indices = torch.where(selection)
86-
max_points_loc, queries_loc = safe_indices
87-
88-
# Use these to get the input points locations:
89-
input_point_locs = indices[max_points_loc, queries_loc]
90-
selected_points = points[input_point_locs]
91-
# Construct default output points:
92-
output_points = torch.zeros(
93-
queries.shape[0],
94-
max_points,
95-
3,
96-
device=queries.device,
97-
dtype=points.dtype,
98-
)
99-
# Put the selected points in:
100-
output_points[queries_loc, max_points_loc] = selected_points
176+
pts_out = pts_out.squeeze(0)
101177

102-
points = output_points
103-
else:
104-
points = torch.empty(
105-
(0, points.shape[1]), device=points.device, dtype=points.dtype
106-
)
107-
108-
return selected_indices, points, dists
178+
return selected_indices, pts_out, dists_out
109179

110180

111181
def radius_search(
@@ -116,6 +186,7 @@ def radius_search(
116186
return_dists: bool = False,
117187
return_points: bool = False,
118188
):
189+
"""Torch-backend entry point for radius search with formatted returns."""
119190
indices, points_out, distances = radius_search_impl(
120191
points, queries, radius, max_points, return_dists, return_points
121192
)

0 commit comments

Comments
 (0)