Skip to content

Commit ce9d7a4

Browse files
committed
Fixing and valiating distributed ball query with batched updates.
1 parent 233f8be commit ce9d7a4

2 files changed

Lines changed: 84 additions & 30 deletions

File tree

physicsnemo/domain_parallel/shard_utils/point_cloud_ops.py

Lines changed: 79 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def ring_ball_query(
8585
# We've already checked that the mesh is 1D so call the '0' index.
8686

8787
points_shard_sizes = points._spec.sharding_shapes()[0]
88+
points_shard_dim = points._spec.placements[0].dim
8889

8990
# Call the differentiable version of the ring-ball-query:
9091
indices_shard, outputs_shard, _, num_neighbors_shard = RingBallQuery.apply(
@@ -93,6 +94,7 @@ def ring_ball_query(
9394
mesh,
9495
ring_config,
9596
points_shard_sizes,
97+
points_shard_dim,
9698
bq_kwargs,
9799
)
98100

@@ -109,19 +111,31 @@ def ring_ball_query(
109111
outputs_shard_shapes = "infer"
110112
elif isinstance(queries._spec.placements[0], Shard):
111113
queries_shard_sizes = queries._spec.sharding_shapes()[0]
114+
q_shard_dim = queries._spec.placements[0].dim
112115

113116
# This conversion to shard tensor can be done explicitly computing the output shapes.
117+
# For batched inputs, shard sizes have a batch prefix (dims before the shard dim)
118+
# that must be preserved in the output shapes.
114119

115120
mp = indices_shard.shape[-1]
116121
d = queries.shape[-1]
117122
indices_shard_output_sharding = {
118-
0: tuple(torch.Size([s[0], mp]) for s in queries_shard_sizes),
123+
0: tuple(
124+
torch.Size([*s[:q_shard_dim], s[q_shard_dim], mp])
125+
for s in queries_shard_sizes
126+
),
119127
}
120128
num_neighbors_shard_output_sharding = {
121-
0: tuple(torch.Size([s[0]]) for s in queries_shard_sizes),
129+
0: tuple(
130+
torch.Size([*s[:q_shard_dim], s[q_shard_dim]])
131+
for s in queries_shard_sizes
132+
),
122133
}
123134
outputs_shard_output_sharding = {
124-
0: tuple(torch.Size([s[0], mp, d]) for s in queries_shard_sizes),
135+
0: tuple(
136+
torch.Size([*s[:q_shard_dim], s[q_shard_dim], mp, d])
137+
for s in queries_shard_sizes
138+
),
125139
}
126140

127141
indices_shard_shapes = indices_shard_output_sharding
@@ -197,15 +211,24 @@ def ringless_ball_query(
197211
num_neighbors_placement = {}
198212
output_points_placement = {}
199213

200-
# Output sharding should match the query shapes:
214+
# Output sharding should match the query shapes.
215+
# For batched inputs, shard sizes have a batch prefix (dims before the shard dim)
216+
# that must be preserved in the output shapes.
217+
queries_placement = queries._spec.placements[0]
218+
q_shard_dim = queries_placement.dim if queries_placement.is_shard() else 0
219+
201220
for i_dim, s in queries._spec.sharding_shapes().items():
202-
n_points = [int(_s[0]) for _s in s]
203221
indices_placement[i_dim] = tuple(
204-
torch.Size([np, max_points]) for np in n_points
222+
torch.Size([*_s[:q_shard_dim], _s[q_shard_dim], max_points])
223+
for _s in s
224+
)
225+
num_neighbors_placement[i_dim] = tuple(
226+
torch.Size([*_s[:q_shard_dim], _s[q_shard_dim]])
227+
for _s in s
205228
)
206-
num_neighbors_placement[i_dim] = tuple(torch.Size([np]) for np in n_points)
207229
output_points_placement[i_dim] = tuple(
208-
torch.Size([np, max_points, 3]) for np in n_points
230+
torch.Size([*_s[:q_shard_dim], _s[q_shard_dim], max_points, 3])
231+
for _s in s
209232
)
210233

211234
indices = ShardTensor.from_local(
@@ -307,8 +330,6 @@ def merge_indices_and_points(
307330
):
308331
return incoming_indices, incoming_num_neighbors, incoming_points
309332

310-
n_points, max_neighbors = current_indices.shape
311-
312333
# This is a gather/scatter operation:
313334
# We need to merge the incoming values into the current arrays. The arrays
314335
# are essentially a ragged tensor that has been padded to a consistent shape.
@@ -320,21 +341,51 @@ def merge_indices_and_points(
320341
# - gather / scatter from incoming to current.
321342
# - Update the current num neighbors correctly
322343

344+
# The warp kernel expects 2D indices, 1D num_neighbors, 3D points.
345+
# For batched inputs (3D indices, 2D num_neighbors, 4D points),
346+
# loop over the batch dimension.
347+
batched = current_indices.ndim == 3
348+
if batched:
349+
B = current_indices.shape[0]
350+
n_points = current_indices.shape[1]
351+
max_neighbors = current_indices.shape[2]
352+
else:
353+
n_points = current_indices.shape[0]
354+
max_neighbors = current_indices.shape[1]
355+
323356
stream = wp.stream_from_torch(current_indices.device)
324-
wp.launch(
325-
merge_indices_and_points,
326-
dim=n_points,
327-
inputs=[
328-
wp.from_torch(current_indices, return_ctype=True),
329-
wp.from_torch(current_num_neighbors, return_ctype=True),
330-
wp.from_torch(current_points, return_ctype=True),
331-
wp.from_torch(incoming_indices, return_ctype=True),
332-
wp.from_torch(incoming_num_neighbors, return_ctype=True),
333-
wp.from_torch(incoming_points, return_ctype=True),
334-
max_neighbors,
335-
],
336-
stream=stream,
337-
)
357+
358+
if batched:
359+
for b in range(B):
360+
wp.launch(
361+
merge_indices_and_points,
362+
dim=n_points,
363+
inputs=[
364+
wp.from_torch(current_indices[b], return_ctype=True),
365+
wp.from_torch(current_num_neighbors[b], return_ctype=True),
366+
wp.from_torch(current_points[b], return_ctype=True),
367+
wp.from_torch(incoming_indices[b], return_ctype=True),
368+
wp.from_torch(incoming_num_neighbors[b], return_ctype=True),
369+
wp.from_torch(incoming_points[b], return_ctype=True),
370+
max_neighbors,
371+
],
372+
stream=stream,
373+
)
374+
else:
375+
wp.launch(
376+
merge_indices_and_points,
377+
dim=n_points,
378+
inputs=[
379+
wp.from_torch(current_indices, return_ctype=True),
380+
wp.from_torch(current_num_neighbors, return_ctype=True),
381+
wp.from_torch(current_points, return_ctype=True),
382+
wp.from_torch(incoming_indices, return_ctype=True),
383+
wp.from_torch(incoming_num_neighbors, return_ctype=True),
384+
wp.from_torch(incoming_points, return_ctype=True),
385+
max_neighbors,
386+
],
387+
stream=stream,
388+
)
338389

339390
return current_indices, current_num_neighbors, current_points
340391

@@ -354,6 +405,7 @@ def forward(
354405
mesh: Any,
355406
ring_config: RingPassingConfig,
356407
shard_sizes: list,
408+
shard_dim: int,
357409
bq_kwargs: Any,
358410
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
359411
r"""Forward pass for distributed ball query computation.
@@ -372,6 +424,8 @@ def forward(
372424
Configuration for ring passing.
373425
shard_sizes : list
374426
Sizes of each shard across ranks.
427+
shard_dim : int
428+
The tensor dimension along which points are sharded.
375429
bq_kwargs : Any
376430
Keyword arguments for the ball query operation.
377431
@@ -399,7 +453,7 @@ def forward(
399453
# Store results from each rank to merge in the correct order
400454
rank_results = [None] * world_size
401455
# For uneven point clouds, the global stide is important:
402-
strides = [s[0] for s in shard_sizes]
456+
strides = [s[shard_dim] for s in shard_sizes]
403457

404458
ctx.max_points = bq_kwargs["max_points"]
405459
ctx.radius = bq_kwargs["radius"]

test/domain_parallel/ops/test_radius_search.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,9 @@ def run_radius_search_module(model, data_dict, reverse_mapping):
105105
# Bounding box grid
106106
s_grid = data_dict["surf_grid"]
107107

108-
# Scaling factors
109-
surf_max = data_dict["surface_min_max"][:, 1]
110-
surf_min = data_dict["surface_min_max"][:, 0]
108+
# Scaling factors -- unsqueeze to (B, 1, 3) for broadcasting with (B, N, 3)
109+
surf_max = data_dict["surface_min_max"][:, 1].unsqueeze(1)
110+
surf_min = data_dict["surface_min_max"][:, 0].unsqueeze(1)
111111

112112
# Normalize based on BBox around surface (car)
113113
geo_centers_surf = 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1
@@ -121,8 +121,9 @@ def run_radius_search_module(model, data_dict, reverse_mapping):
121121
@pytest.mark.parametrize("shard_points", [True, False])
122122
@pytest.mark.parametrize("shard_grid", [True, False])
123123
@pytest.mark.parametrize("reverse_mapping", [True])
124+
@pytest.mark.parametrize("bsize", [1, 2])
124125
def test_sharded_radius_search_layer_forward(
125-
distributed_mesh, shard_points, shard_grid, reverse_mapping
126+
distributed_mesh, shard_points, shard_grid, reverse_mapping, bsize
126127
):
127128
from physicsnemo.nn import BQWarp
128129

@@ -131,7 +132,6 @@ def test_sharded_radius_search_layer_forward(
131132
device = dm.device
132133

133134
# Create the input dict:
134-
bsize = 1
135135
npoints = 8 * 17
136136
nx, ny, nz = 8 * 12, 6, 4
137137
# This is pretty aggressive, it'd never actually be this many.

0 commit comments

Comments
 (0)