Skip to content

Commit cbb1bf1

Browse files
authored
Batched ball query (#1602)
* Update ball query and radius search to support more than batch size 1. * Fixing and validating distributed ball query with batched updates. * Update changelog * Fix formatting issues * Update DoMINO to enable multi-batch tests to pass. Deliberately looping an index_select over the batch index, rather than something vectorized, because we don't want to expand that op and it's domain parallelized. * Use one validate inputs function instead of 2
1 parent 4c52a45 commit cbb1bf1

17 files changed

Lines changed: 1336 additions & 633 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
9090
Flare, GeoTransolver with Flare-attention, bring your own!). Leverages
9191
mesh datasets and non-dimensionalization to enable dataset mixing and
9292
matching at runtime. Train with surface or volume data.
93+
- Added support for batched radius search, which enables DoMINO
94+
and GeoTransolver with local features and batch size > 1.
9395

9496
### Changed
9597

physicsnemo/domain_parallel/shard_utils/point_cloud_ops.py

Lines changed: 76 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def ring_ball_query(
8585
# We've already checked that the mesh is 1D so call the '0' index.
8686

8787
points_shard_sizes = points._spec.sharding_shapes()[0]
88+
points_shard_dim = points._spec.placements[0].dim
8889

8990
# Call the differentiable version of the ring-ball-query:
9091
indices_shard, outputs_shard, _, num_neighbors_shard = RingBallQuery.apply(
@@ -93,6 +94,7 @@ def ring_ball_query(
9394
mesh,
9495
ring_config,
9596
points_shard_sizes,
97+
points_shard_dim,
9698
bq_kwargs,
9799
)
98100

@@ -109,19 +111,31 @@ def ring_ball_query(
109111
outputs_shard_shapes = "infer"
110112
elif isinstance(queries._spec.placements[0], Shard):
111113
queries_shard_sizes = queries._spec.sharding_shapes()[0]
114+
q_shard_dim = queries._spec.placements[0].dim
112115

113116
# This conversion to shard tensor can be done by explicitly computing the output shapes.
117+
# For batched inputs, shard sizes have a batch prefix (dims before the shard dim)
118+
# that must be preserved in the output shapes.
114119

115120
mp = indices_shard.shape[-1]
116121
d = queries.shape[-1]
117122
indices_shard_output_sharding = {
118-
0: tuple(torch.Size([s[0], mp]) for s in queries_shard_sizes),
123+
0: tuple(
124+
torch.Size([*s[:q_shard_dim], s[q_shard_dim], mp])
125+
for s in queries_shard_sizes
126+
),
119127
}
120128
num_neighbors_shard_output_sharding = {
121-
0: tuple(torch.Size([s[0]]) for s in queries_shard_sizes),
129+
0: tuple(
130+
torch.Size([*s[:q_shard_dim], s[q_shard_dim]])
131+
for s in queries_shard_sizes
132+
),
122133
}
123134
outputs_shard_output_sharding = {
124-
0: tuple(torch.Size([s[0], mp, d]) for s in queries_shard_sizes),
135+
0: tuple(
136+
torch.Size([*s[:q_shard_dim], s[q_shard_dim], mp, d])
137+
for s in queries_shard_sizes
138+
),
125139
}
126140

127141
indices_shard_shapes = indices_shard_output_sharding
@@ -197,15 +211,21 @@ def ringless_ball_query(
197211
num_neighbors_placement = {}
198212
output_points_placement = {}
199213

200-
# Output sharding should match the query shapes:
214+
# Output sharding should match the query shapes.
215+
# For batched inputs, shard sizes have a batch prefix (dims before the shard dim)
216+
# that must be preserved in the output shapes.
217+
queries_placement = queries._spec.placements[0]
218+
q_shard_dim = queries_placement.dim if queries_placement.is_shard() else 0
219+
201220
for i_dim, s in queries._spec.sharding_shapes().items():
202-
n_points = [int(_s[0]) for _s in s]
203221
indices_placement[i_dim] = tuple(
204-
torch.Size([np, max_points]) for np in n_points
222+
torch.Size([*_s[:q_shard_dim], _s[q_shard_dim], max_points]) for _s in s
223+
)
224+
num_neighbors_placement[i_dim] = tuple(
225+
torch.Size([*_s[:q_shard_dim], _s[q_shard_dim]]) for _s in s
205226
)
206-
num_neighbors_placement[i_dim] = tuple(torch.Size([np]) for np in n_points)
207227
output_points_placement[i_dim] = tuple(
208-
torch.Size([np, max_points, 3]) for np in n_points
228+
torch.Size([*_s[:q_shard_dim], _s[q_shard_dim], max_points, 3]) for _s in s
209229
)
210230

211231
indices = ShardTensor.from_local(
@@ -307,8 +327,6 @@ def merge_indices_and_points(
307327
):
308328
return incoming_indices, incoming_num_neighbors, incoming_points
309329

310-
n_points, max_neighbors = current_indices.shape
311-
312330
# This is a gather/scatter operation:
313331
# We need to merge the incoming values into the current arrays. The arrays
314332
# are essentially a ragged tensor that has been padded to a consistent shape.
@@ -320,21 +338,51 @@ def merge_indices_and_points(
320338
# - gather / scatter from incoming to current.
321339
# - Update the current num neighbors correctly
322340

341+
# The warp kernel expects 2D indices, 1D num_neighbors, 3D points.
342+
# For batched inputs (3D indices, 2D num_neighbors, 4D points),
343+
# loop over the batch dimension.
344+
batched = current_indices.ndim == 3
345+
if batched:
346+
B = current_indices.shape[0]
347+
n_points = current_indices.shape[1]
348+
max_neighbors = current_indices.shape[2]
349+
else:
350+
n_points = current_indices.shape[0]
351+
max_neighbors = current_indices.shape[1]
352+
323353
stream = wp.stream_from_torch(current_indices.device)
324-
wp.launch(
325-
merge_indices_and_points,
326-
dim=n_points,
327-
inputs=[
328-
wp.from_torch(current_indices, return_ctype=True),
329-
wp.from_torch(current_num_neighbors, return_ctype=True),
330-
wp.from_torch(current_points, return_ctype=True),
331-
wp.from_torch(incoming_indices, return_ctype=True),
332-
wp.from_torch(incoming_num_neighbors, return_ctype=True),
333-
wp.from_torch(incoming_points, return_ctype=True),
334-
max_neighbors,
335-
],
336-
stream=stream,
337-
)
354+
355+
if batched:
356+
for b in range(B):
357+
wp.launch(
358+
merge_indices_and_points,
359+
dim=n_points,
360+
inputs=[
361+
wp.from_torch(current_indices[b], return_ctype=True),
362+
wp.from_torch(current_num_neighbors[b], return_ctype=True),
363+
wp.from_torch(current_points[b], return_ctype=True),
364+
wp.from_torch(incoming_indices[b], return_ctype=True),
365+
wp.from_torch(incoming_num_neighbors[b], return_ctype=True),
366+
wp.from_torch(incoming_points[b], return_ctype=True),
367+
max_neighbors,
368+
],
369+
stream=stream,
370+
)
371+
else:
372+
wp.launch(
373+
merge_indices_and_points,
374+
dim=n_points,
375+
inputs=[
376+
wp.from_torch(current_indices, return_ctype=True),
377+
wp.from_torch(current_num_neighbors, return_ctype=True),
378+
wp.from_torch(current_points, return_ctype=True),
379+
wp.from_torch(incoming_indices, return_ctype=True),
380+
wp.from_torch(incoming_num_neighbors, return_ctype=True),
381+
wp.from_torch(incoming_points, return_ctype=True),
382+
max_neighbors,
383+
],
384+
stream=stream,
385+
)
338386

339387
return current_indices, current_num_neighbors, current_points
340388

@@ -354,6 +402,7 @@ def forward(
354402
mesh: Any,
355403
ring_config: RingPassingConfig,
356404
shard_sizes: list,
405+
shard_dim: int,
357406
bq_kwargs: Any,
358407
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
359408
r"""Forward pass for distributed ball query computation.
@@ -372,6 +421,8 @@ def forward(
372421
Configuration for ring passing.
373422
shard_sizes : list
374423
Sizes of each shard across ranks.
424+
shard_dim : int
425+
The tensor dimension along which points are sharded.
375426
bq_kwargs : Any
376427
Keyword arguments for the ball query operation.
377428
@@ -399,7 +450,7 @@ def forward(
399450
# Store results from each rank to merge in the correct order
400451
rank_results = [None] * world_size
401452
# For uneven point clouds, the global stride is important:
402-
strides = [s[0] for s in shard_sizes]
453+
strides = [s[shard_dim] for s in shard_sizes]
403454

404455
ctx.max_points = bq_kwargs["max_points"]
405456
ctx.radius = bq_kwargs["radius"]

physicsnemo/models/domino/encodings.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,11 @@ def forward(
159159
for j in range(encoding_g.shape[1]):
160160
geo_encoding = rearrange(encoding_g[:, j], "b nx ny nz -> b 1 (nx ny nz)")
161161

162-
geo_encoding_sampled = torch.index_select(
163-
geo_encoding, 2, mapping.flatten()
164-
)
165-
geo_encoding_sampled = torch.reshape(geo_encoding_sampled, mask.shape)
162+
sampled = []
163+
for b in range(batch_size):
164+
s = torch.index_select(geo_encoding[b], 1, mapping[b].flatten())
165+
sampled.append(s.reshape(1, *mask.shape[1:]))
166+
geo_encoding_sampled = torch.cat(sampled, dim=0)
166167
geo_encoding_sampled = geo_encoding_sampled * mask
167168

168169
encoding_g_inner.append(geo_encoding_sampled)

physicsnemo/models/domino/geometry_rep.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def forward(
195195
self.grid_resolution[1],
196196
self.grid_resolution[2],
197197
)
198-
grid = grid.reshape(1, nx * ny * nz, 3, 1)
198+
grid = grid.reshape(grid.shape[0], nx * ny * nz, 3, 1)
199199

200200
# Rearrange input to flatten spatial and neighbor dimensions
201201
x = rearrange(

physicsnemo/models/domino/model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -574,8 +574,8 @@ def forward(
574574

575575
# Normalize geometry coordinates based on computational domain
576576
if "volume_min_max" in data_dict.keys():
577-
vol_max = data_dict["volume_min_max"][:, 1]
578-
vol_min = data_dict["volume_min_max"][:, 0]
577+
vol_max = data_dict["volume_min_max"][:, 1].unsqueeze(1)
578+
vol_min = data_dict["volume_min_max"][:, 0].unsqueeze(1)
579579
geo_centers_vol = (
580580
2.0 * (geo_centers - vol_min) / (vol_max - vol_min) - 1
581581
)
@@ -618,8 +618,8 @@ def forward(
618618
# Represent geometry on bounding box
619619
# Normalize geometry coordinates based on surface bounding box
620620
if "surface_min_max" in data_dict.keys():
621-
surf_max = data_dict["surface_min_max"][:, 1]
622-
surf_min = data_dict["surface_min_max"][:, 0]
621+
surf_max = data_dict["surface_min_max"][:, 1].unsqueeze(1)
622+
surf_min = data_dict["surface_min_max"][:, 0].unsqueeze(1)
623623
geo_centers_surf = (
624624
2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1
625625
)

0 commit comments

Comments
 (0)