@@ -85,6 +85,7 @@ def ring_ball_query(
8585 # We've already checked that the mesh is 1D so call the '0' index.
8686
8787 points_shard_sizes = points ._spec .sharding_shapes ()[0 ]
88+ points_shard_dim = points ._spec .placements [0 ].dim
8889
8990 # Call the differentiable version of the ring-ball-query:
9091 indices_shard , outputs_shard , _ , num_neighbors_shard = RingBallQuery .apply (
@@ -93,6 +94,7 @@ def ring_ball_query(
9394 mesh ,
9495 ring_config ,
9596 points_shard_sizes ,
97+ points_shard_dim ,
9698 bq_kwargs ,
9799 )
98100
@@ -109,19 +111,31 @@ def ring_ball_query(
109111 outputs_shard_shapes = "infer"
110112 elif isinstance (queries ._spec .placements [0 ], Shard ):
111113 queries_shard_sizes = queries ._spec .sharding_shapes ()[0 ]
114+ q_shard_dim = queries ._spec .placements [0 ].dim
112115
113116 # This conversion to shard tensor can be done explicitly computing the output shapes.
117+ # For batched inputs, shard sizes have a batch prefix (dims before the shard dim)
118+ # that must be preserved in the output shapes.
114119
115120 mp = indices_shard .shape [- 1 ]
116121 d = queries .shape [- 1 ]
117122 indices_shard_output_sharding = {
118- 0 : tuple (torch .Size ([s [0 ], mp ]) for s in queries_shard_sizes ),
123+ 0 : tuple (
124+ torch .Size ([* s [:q_shard_dim ], s [q_shard_dim ], mp ])
125+ for s in queries_shard_sizes
126+ ),
119127 }
120128 num_neighbors_shard_output_sharding = {
121- 0 : tuple (torch .Size ([s [0 ]]) for s in queries_shard_sizes ),
129+ 0 : tuple (
130+ torch .Size ([* s [:q_shard_dim ], s [q_shard_dim ]])
131+ for s in queries_shard_sizes
132+ ),
122133 }
123134 outputs_shard_output_sharding = {
124- 0 : tuple (torch .Size ([s [0 ], mp , d ]) for s in queries_shard_sizes ),
135+ 0 : tuple (
136+ torch .Size ([* s [:q_shard_dim ], s [q_shard_dim ], mp , d ])
137+ for s in queries_shard_sizes
138+ ),
125139 }
126140
127141 indices_shard_shapes = indices_shard_output_sharding
@@ -197,15 +211,24 @@ def ringless_ball_query(
197211 num_neighbors_placement = {}
198212 output_points_placement = {}
199213
200- # Output sharding should match the query shapes:
214+ # Output sharding should match the query shapes.
215+ # For batched inputs, shard sizes have a batch prefix (dims before the shard dim)
216+ # that must be preserved in the output shapes.
217+ queries_placement = queries ._spec .placements [0 ]
218+ q_shard_dim = queries_placement .dim if queries_placement .is_shard () else 0
219+
201220 for i_dim , s in queries ._spec .sharding_shapes ().items ():
202- n_points = [int (_s [0 ]) for _s in s ]
203221 indices_placement [i_dim ] = tuple (
204- torch .Size ([np , max_points ]) for np in n_points
222+ torch .Size ([* _s [:q_shard_dim ], _s [q_shard_dim ], max_points ])
223+ for _s in s
224+ )
225+ num_neighbors_placement [i_dim ] = tuple (
226+ torch .Size ([* _s [:q_shard_dim ], _s [q_shard_dim ]])
227+ for _s in s
205228 )
206- num_neighbors_placement [i_dim ] = tuple (torch .Size ([np ]) for np in n_points )
207229 output_points_placement [i_dim ] = tuple (
208- torch .Size ([np , max_points , 3 ]) for np in n_points
230+ torch .Size ([* _s [:q_shard_dim ], _s [q_shard_dim ], max_points , 3 ])
231+ for _s in s
209232 )
210233
211234 indices = ShardTensor .from_local (
@@ -307,8 +330,6 @@ def merge_indices_and_points(
307330 ):
308331 return incoming_indices , incoming_num_neighbors , incoming_points
309332
310- n_points , max_neighbors = current_indices .shape
311-
312333 # This is a gather/scatter operation:
313334 # We need to merge the incoming values into the current arrays. The arrays
314335 # are essentially a ragged tensor that has been padded to a consistent shape.
@@ -320,21 +341,51 @@ def merge_indices_and_points(
320341 # - gather / scatter from incoming to current.
321342 # - Update the current num neighbors correctly
322343
344+ # The warp kernel expects 2D indices, 1D num_neighbors, 3D points.
345+ # For batched inputs (3D indices, 2D num_neighbors, 4D points),
346+ # loop over the batch dimension.
347+ batched = current_indices .ndim == 3
348+ if batched :
349+ B = current_indices .shape [0 ]
350+ n_points = current_indices .shape [1 ]
351+ max_neighbors = current_indices .shape [2 ]
352+ else :
353+ n_points = current_indices .shape [0 ]
354+ max_neighbors = current_indices .shape [1 ]
355+
323356 stream = wp .stream_from_torch (current_indices .device )
324- wp .launch (
325- merge_indices_and_points ,
326- dim = n_points ,
327- inputs = [
328- wp .from_torch (current_indices , return_ctype = True ),
329- wp .from_torch (current_num_neighbors , return_ctype = True ),
330- wp .from_torch (current_points , return_ctype = True ),
331- wp .from_torch (incoming_indices , return_ctype = True ),
332- wp .from_torch (incoming_num_neighbors , return_ctype = True ),
333- wp .from_torch (incoming_points , return_ctype = True ),
334- max_neighbors ,
335- ],
336- stream = stream ,
337- )
357+
358+ if batched :
359+ for b in range (B ):
360+ wp .launch (
361+ merge_indices_and_points ,
362+ dim = n_points ,
363+ inputs = [
364+ wp .from_torch (current_indices [b ], return_ctype = True ),
365+ wp .from_torch (current_num_neighbors [b ], return_ctype = True ),
366+ wp .from_torch (current_points [b ], return_ctype = True ),
367+ wp .from_torch (incoming_indices [b ], return_ctype = True ),
368+ wp .from_torch (incoming_num_neighbors [b ], return_ctype = True ),
369+ wp .from_torch (incoming_points [b ], return_ctype = True ),
370+ max_neighbors ,
371+ ],
372+ stream = stream ,
373+ )
374+ else :
375+ wp .launch (
376+ merge_indices_and_points ,
377+ dim = n_points ,
378+ inputs = [
379+ wp .from_torch (current_indices , return_ctype = True ),
380+ wp .from_torch (current_num_neighbors , return_ctype = True ),
381+ wp .from_torch (current_points , return_ctype = True ),
382+ wp .from_torch (incoming_indices , return_ctype = True ),
383+ wp .from_torch (incoming_num_neighbors , return_ctype = True ),
384+ wp .from_torch (incoming_points , return_ctype = True ),
385+ max_neighbors ,
386+ ],
387+ stream = stream ,
388+ )
338389
339390 return current_indices , current_num_neighbors , current_points
340391
@@ -354,6 +405,7 @@ def forward(
354405 mesh : Any ,
355406 ring_config : RingPassingConfig ,
356407 shard_sizes : list ,
408+ shard_dim : int ,
357409 bq_kwargs : Any ,
358410 ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
359411 r"""Forward pass for distributed ball query computation.
@@ -372,6 +424,8 @@ def forward(
372424 Configuration for ring passing.
373425 shard_sizes : list
374426 Sizes of each shard across ranks.
427+ shard_dim : int
428+ The tensor dimension along which points are sharded.
375429 bq_kwargs : Any
376430 Keyword arguments for the ball query operation.
377431
@@ -399,7 +453,7 @@ def forward(
399453 # Store results from each rank to merge in the correct order
400454 rank_results = [None ] * world_size
401455 # For uneven point clouds, the global stride is important:
402- strides = [s [0 ] for s in shard_sizes ]
456+ strides = [s [shard_dim ] for s in shard_sizes ]
403457
404458 ctx .max_points = bq_kwargs ["max_points" ]
405459 ctx .radius = bq_kwargs ["radius" ]
0 commit comments