2020from .utils import format_returns
2121
2222
23+ def _validate_inputs (points : torch .Tensor , queries : torch .Tensor ):
24+ """Validate and normalize inputs to (B, N, 3) shape. Returns (points, queries, was_unbatched)."""
25+ if points .ndim == 2 and queries .ndim == 2 :
26+ return points .unsqueeze (0 ), queries .unsqueeze (0 ), True
27+ elif points .ndim == 3 and queries .ndim == 3 :
28+ if points .shape [0 ] != queries .shape [0 ]:
29+ raise ValueError (
30+ f"Batch dimensions must match: points has { points .shape [0 ]} , "
31+ f"queries has { queries .shape [0 ]} "
32+ )
33+ return points , queries , False
34+ else :
35+ raise ValueError (
36+ f"points and queries must be 2D (N, 3) or 3D (B, N, 3), "
37+ f"got { points .ndim } D and { queries .ndim } D"
38+ )
39+
40+
41+ def _radius_search_dynamic (
42+ points : torch .Tensor ,
43+ queries : torch .Tensor ,
44+ radius : float ,
45+ return_dists : bool ,
46+ return_points : bool ,
47+ ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
48+ """Single-element dynamic radius search on (N, 3) and (Q, 3) tensors.
49+
50+ Finds ALL neighbors within radius (variable-length output).
51+ """
52+
53+ # Without the compute mode set, this is numerically unstable.
54+ dists = torch .cdist (
55+ points , queries , p = 2.0 , compute_mode = "donot_use_mm_for_euclid_dist"
56+ )
57+
58+ selection = dists <= radius
59+ selected_indices = torch .nonzero (selection , as_tuple = False ).t ().contiguous ()
60+ selected_indices = selected_indices [[1 , 0 ], :]
61+
62+ if return_points :
63+ points = torch .index_select (points , 0 , selected_indices [1 ])
64+ else :
65+ points = torch .empty (
66+ (0 , points .shape [1 ]), device = points .device , dtype = points .dtype
67+ )
68+
69+ if return_dists :
70+ dists = dists [selection ]
71+ else :
72+ dists = torch .empty ((0 ,), device = dists .device , dtype = dists .dtype )
73+
74+ return selected_indices , points , dists
75+
76+
2377def radius_search_impl (
2478 points : torch .Tensor ,
2579 queries : torch .Tensor ,
@@ -31,81 +85,97 @@ def radius_search_impl(
3185 """
3286 Pure PyTorch implementation of the radius search.
3387
88+ Accepts both unbatched (N, 3) and batched (B, N, 3) inputs.
3489 This is a brute force implementation that is not memory efficient.
3590 """
3691
37- # Without the compute mode set, this is numerically unstable.
92+ points , queries , was_unbatched = _validate_inputs (points , queries )
93+ B = points .shape [0 ]
94+
95+ if max_points is None :
96+ # Dynamic output: loop over batch, concatenate with batch indices
97+ all_indices = []
98+ all_pts = []
99+ all_dists = []
100+ for b in range (B ):
101+ idx_b , pts_b , dists_b = _radius_search_dynamic (
102+ points [b ], queries [b ], radius , return_dists , return_points
103+ )
104+ # idx_b is (2, count_b); prepend a batch-index row
105+ batch_row = torch .full (
106+ (1 , idx_b .shape [1 ]), b , dtype = idx_b .dtype , device = idx_b .device
107+ )
108+ all_indices .append (torch .cat ([batch_row , idx_b ], dim = 0 ))
109+ all_pts .append (pts_b )
110+ all_dists .append (dists_b )
111+
112+ selected_indices = torch .cat (all_indices , dim = 1 ) # (3, total_count)
113+ pts_out = torch .cat (all_pts , dim = 0 ) if return_points else all_pts [0 ]
114+ dists_out = torch .cat (all_dists , dim = 0 ) if return_dists else all_dists [0 ]
115+
116+ if was_unbatched :
117+ # Strip the batch-index row to restore (2, count) format
118+ selected_indices = selected_indices [1 :]
119+
120+ return selected_indices , pts_out , dists_out
121+
122+ # Deterministic output: fully batched via cdist + topk
123+ # dists: (B, N, Q)
38124 dists = torch .cdist (
39125 points , queries , p = 2.0 , compute_mode = "donot_use_mm_for_euclid_dist"
40126 )
41127
42- if max_points is None :
43- # Find all points within radius
44- selection = dists <= radius
45- selected_indices = torch .nonzero (selection , as_tuple = False ).t ().contiguous ()
46- selected_indices = selected_indices [[1 , 0 ], :]
128+ # topk along dim=1 (points dim): (B, max_points, Q)
129+ k = min (max_points , dists .shape [1 ])
130+ values , indices = torch .topk (dists , k = k , dim = 1 , largest = False )
47131
48- if return_points :
49- points = torch .index_select (points , 0 , selected_indices [1 ])
50- else :
51- points = torch .empty (
52- (0 , points .shape [1 ]), device = points .device , dtype = points .dtype
53- )
132+ # Pad if k < max_points (fewer points than requested)
133+ if k < max_points :
134+ pad_size = max_points - k
135+ values = torch .nn .functional .pad (values , (0 , 0 , 0 , pad_size ), value = 0.0 )
136+ indices = torch .nn .functional .pad (indices , (0 , 0 , 0 , pad_size ), value = 0 )
54137
55- if return_dists :
56- dists = dists [selection ]
57- else :
58- dists = torch .empty ((0 ,), device = dists .device , dtype = dists .dtype )
138+ # Filter to within radius: (B, max_points, Q)
139+ selection = values <= radius
140+ selected_indices = torch .where (selection , indices , 0 )
141+ # Transpose to (B, Q, max_points)
142+ selected_indices = selected_indices .permute (0 , 2 , 1 )
59143
144+ if return_dists :
145+ dists_out = torch .where (selection , values , 0 ).permute (0 , 2 , 1 )
60146 else :
61- # Take the max_points lowest distances for each query
62- closest_points = torch .topk (
63- dists , k = min (max_points , dists .shape [0 ]), dim = 0 , largest = False
147+ dists_out = torch .empty (0 , dtype = dists .dtype , device = dists .device )
148+
149+ if return_points :
150+ # Gather points for each (batch, query, neighbor) triple
151+ # selection: (B, max_points, Q), indices: (B, max_points, Q)
152+ safe_locs = torch .where (selection )
153+ batch_loc , mp_loc , query_loc = safe_locs
154+ input_point_locs = indices [batch_loc , mp_loc , query_loc ]
155+ selected_points = points [batch_loc , input_point_locs ]
156+ output_points = torch .zeros (
157+ B ,
158+ queries .shape [1 ],
159+ max_points ,
160+ 3 ,
161+ device = queries .device ,
162+ dtype = points .dtype ,
163+ )
164+ output_points [batch_loc , query_loc , mp_loc ] = selected_points
165+ pts_out = output_points
166+ else :
167+ pts_out = torch .empty (
168+ 0 , max_points , 3 , device = points .device , dtype = points .dtype
64169 )
65- values , indices = closest_points
66- # Values and indices have shape [max_points, n_queries]
67- # The first dim of indices represents the index into input points
68-
69- # Filter to points within radius
70- selection = values <= radius
71- selected_indices = torch .where (selection , indices , 0 ).t ()
72170
171+ if was_unbatched :
172+ selected_indices = selected_indices .squeeze (0 )
73173 if return_dists :
74- dists = torch .where (selection , values , 0 ).t ()
75- else :
76- dists = torch .empty (
77- (0 , values .shape [1 ]), device = values .device , dtype = values .dtype
78- )
79-
174+ dists_out = dists_out .squeeze (0 )
80175 if return_points :
81- # selected_indices: (num_queries, max_points)
82- # points: (num_points, point_dim)
83- # We want: selected_points: (num_queries, max_points, point_dim)
84-
85- safe_indices = torch .where (selection )
86- max_points_loc , queries_loc = safe_indices
87-
88- # Use these to get the input points locations:
89- input_point_locs = indices [max_points_loc , queries_loc ]
90- selected_points = points [input_point_locs ]
91- # Construct default output points:
92- output_points = torch .zeros (
93- queries .shape [0 ],
94- max_points ,
95- 3 ,
96- device = queries .device ,
97- dtype = points .dtype ,
98- )
99- # Put the selected points in:
100- output_points [queries_loc , max_points_loc ] = selected_points
176+ pts_out = pts_out .squeeze (0 )
101177
102- points = output_points
103- else :
104- points = torch .empty (
105- (0 , points .shape [1 ]), device = points .device , dtype = points .dtype
106- )
107-
108- return selected_indices , points , dists
178+ return selected_indices , pts_out , dists_out
109179
110180
111181def radius_search (
@@ -116,6 +186,7 @@ def radius_search(
116186 return_dists : bool = False ,
117187 return_points : bool = False ,
118188):
189+ """Torch-backend entry point for radius search with formatted returns."""
119190 indices , points_out , distances = radius_search_impl (
120191 points , queries , radius , max_points , return_dists , return_points
121192 )
0 commit comments