Skip to content

Commit 4256600

Browse files
committed
Use one validate inputs function instead of 2
1 parent 0623efa commit 4256600

3 files changed

Lines changed: 22 additions & 40 deletions

File tree

physicsnemo/nn/functional/neighbors/radius_search/_torch_impl.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,7 @@
1717

1818
import torch
1919

20-
from .utils import format_returns
21-
22-
23-
def _validate_inputs(points: torch.Tensor, queries: torch.Tensor):
24-
"""Validate and normalize inputs to (B, N, 3) shape. Returns (points, queries, was_unbatched)."""
25-
if points.ndim == 2 and queries.ndim == 2:
26-
return points.unsqueeze(0), queries.unsqueeze(0), True
27-
elif points.ndim == 3 and queries.ndim == 3:
28-
if points.shape[0] != queries.shape[0]:
29-
raise ValueError(
30-
f"Batch dimensions must match: points has {points.shape[0]}, "
31-
f"queries has {queries.shape[0]}"
32-
)
33-
return points, queries, False
34-
else:
35-
raise ValueError(
36-
f"points and queries must be 2D (N, 3) or 3D (B, N, 3), "
37-
f"got {points.ndim}D and {queries.ndim}D"
38-
)
20+
from .utils import format_returns, validate_inputs
3921

4022

4123
def _radius_search_dynamic(
@@ -89,7 +71,7 @@ def radius_search_impl(
8971
This is a brute force implementation that is not memory efficient.
9072
"""
9173

92-
points, queries, was_unbatched = _validate_inputs(points, queries)
74+
points, queries, was_unbatched = validate_inputs(points, queries)
9375
B = points.shape[0]
9476

9577
if max_points is None:

physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
scatter_add_batched,
3838
scatter_add_unlimited,
3939
)
40-
from .utils import format_returns
40+
from .utils import format_returns, validate_inputs
4141

4242
wp.config.quiet = True
4343

@@ -46,24 +46,6 @@
4646
BLOCK_DIM = 32
4747

4848

49-
def _validate_inputs(points: torch.Tensor, queries: torch.Tensor):
50-
"""Validate and normalize inputs. Returns (points, queries, was_unbatched)."""
51-
if points.ndim == 2 and queries.ndim == 2:
52-
return points.unsqueeze(0), queries.unsqueeze(0), True
53-
elif points.ndim == 3 and queries.ndim == 3:
54-
if points.shape[0] != queries.shape[0]:
55-
raise ValueError(
56-
f"Batch dimensions must match: points has {points.shape[0]}, "
57-
f"queries has {queries.shape[0]}"
58-
)
59-
return points, queries, False
60-
else:
61-
raise ValueError(
62-
f"points and queries must be 2D (N, 3) or 3D (B, N, 3), "
63-
f"got {points.ndim}D and {queries.ndim}D"
64-
)
65-
66-
6749
def count_neighbors(
6850
grid: wp.HashGrid,
6951
wp_points: wp.array(dtype=wp.vec3),
@@ -226,7 +208,7 @@ def radius_search_impl(
226208
if points.device != queries.device:
227209
raise ValueError("points and queries must be on the same device")
228210

229-
points, queries, was_unbatched = _validate_inputs(points, queries)
211+
points, queries, was_unbatched = validate_inputs(points, queries)
230212
B = points.shape[0]
231213
N_queries = queries.shape[1]
232214

physicsnemo/nn/functional/neighbors/radius_search/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,24 @@
1717
import torch
1818

1919

20+
def validate_inputs(points: torch.Tensor, queries: torch.Tensor):
21+
"""Validate and normalize inputs to (B, N, 3) shape. Returns (points, queries, was_unbatched)."""
22+
if points.ndim == 2 and queries.ndim == 2:
23+
return points.unsqueeze(0), queries.unsqueeze(0), True
24+
elif points.ndim == 3 and queries.ndim == 3:
25+
if points.shape[0] != queries.shape[0]:
26+
raise ValueError(
27+
f"Batch dimensions must match: points has {points.shape[0]}, "
28+
f"queries has {queries.shape[0]}"
29+
)
30+
return points, queries, False
31+
else:
32+
raise ValueError(
33+
f"points and queries must be 2D (N, 3) or 3D (B, N, 3), "
34+
f"got {points.ndim}D and {queries.ndim}D"
35+
)
36+
37+
2038
def format_returns(
2139
indices: torch.Tensor,
2240
points: torch.Tensor,

0 commit comments

Comments
 (0)