Skip to content

Commit f5142d8

Browse files
Merge pull request #39 from stackav-oss/feature/jmanning/voxelization-cleanup
Cleanup voxelization
2 parents bf1e621 + 73d156c commit f5142d8

5 files changed

Lines changed: 41 additions & 27 deletions

File tree

benchmarks/voxelization_benchmark.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from conch.ops.vision.voxelization import VoxelizationParameter, generate_voxels
1212
from conch.platforms import current_platform
1313
from conch.reference.vision.voxelization import collect_point_features, voxelization_stable
14+
from conch.third_party.vllm.utils import seed_everything
1415
from conch.utils.benchmark import BenchmarkMetadata, benchmark_it
1516

1617

@@ -116,6 +117,8 @@ def main(
116117
compile_ref: Flag to torch.compile() the pure torch reference implementation.
117118
cuda_ref: Flag to enable CUDA reference implementation.
118119
"""
120+
seed: Final = 0
121+
seed_everything(seed)
119122

120123
device: Final = torch.device(gpu)
121124
torch.set_default_device(device)
@@ -139,7 +142,7 @@ def main(
139142

140143
def generate_voxels_torch(
141144
points: torch.Tensor, param: VoxelizationParameter
142-
) -> tuple[torch.tensor, torch.tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
145+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
143146
"""Reference Triton/torch hybrid in two steps: stable voxelization first, then generation of a feature tensor."""
144147
use_triton = not torch_ref
145148
actual_num_points_per_voxel, point_raw_indices, flat_voxel_indices = voxelization_stable(
@@ -163,7 +166,7 @@ def generate_voxels_torch(
163166
print(f"number of filled voxels: {actual_num_points_per_voxel.shape[0]}")
164167
print(f"Avg number of points per voxel: {torch.mean(actual_num_points_per_voxel, dtype=torch.float32)}")
165168
print(f"Max number of points per voxel: {torch.max(actual_num_points_per_voxel)}")
166-
overflow_count = (actual_num_points_per_voxel > param.max_num_points_per_voxel).int().sum()
169+
overflow_count = (actual_num_points_per_voxel > param.max_num_points_per_voxel).to(torch.int32).sum()
167170
print(f"Number of voxels with overflowing points: {overflow_count}")
168171

169172
# Benchmark implementations

conch/kernels/vision/voxelization.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import triton.language as tl
99

1010

11-
@triton.jit
12-
def generate_dense_voxels_triton_kernel( # noqa: PLR0913, D417
11+
@triton.jit # type: ignore[misc]
12+
def generate_dense_voxels_triton_kernel(
1313
# input
1414
points_ptr: torch.Tensor,
1515
num_points: int,
@@ -73,8 +73,8 @@ def generate_dense_voxels_triton_kernel( # noqa: PLR0913, D417
7373
tl.store(dense_point_features_ptr + output_idx * 4 + 3, point_w, mask=output_mask)
7474

7575

76-
@triton.jit
77-
def generate_voxels_triton_kernel( # noqa: PLR0913, D417
76+
@triton.jit # type: ignore[misc]
77+
def generate_voxels_triton_kernel(
7878
# input
7979
dense_point_features_ptr: torch.Tensor,
8080
dense_num_points_per_voxel_ptr: torch.Tensor,

conch/ops/vision/voxelization.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,9 @@ def generate_voxels(
6161
empty points are filled with 0.
6262
voxel_indices, shape [num_filled_voxels, 4], only first 3 fields are used for x,y,z indices.
6363
"""
64-
assert points.is_cuda
6564
device = points.device
6665
num_points, num_features_per_point = points.shape
67-
assert num_features_per_point == 4 # noqa: PLR2004
66+
assert num_features_per_point == 4
6867
# same as the original NVIDIA CUDA implementation
6968
num_elements_per_voxel_index = 4
7069

@@ -105,7 +104,7 @@ def generate_voxels(
105104
dense_num_points_per_voxel,
106105
dense_point_features,
107106
cxpr_block_size=block_size,
108-
num_warps=block_size // num_threads_per_warp, # pyright: ignore[reportCallIssue]
107+
num_warps=block_size // num_threads_per_warp,
109108
)
110109

111110
# compress into contiguous/sparse filled voxels
@@ -122,7 +121,7 @@ def generate_voxels(
122121
point_features,
123122
voxel_indices,
124123
cxpr_block_size=block_size,
125-
num_warps=block_size // num_threads_per_warp, # pyright: ignore[reportCallIssue]
124+
num_warps=block_size // num_threads_per_warp,
126125
)
127126

128127
total_filled_voxels = num_filled_voxels.cpu()[0]

conch/reference/vision/voxelization.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010
from conch.ops.vision.voxelization import VoxelizationParameter
1111

12-
@triton.jit
13-
def filter_and_label_points_triton_kernel( # noqa: PLR0913, D417
12+
13+
@triton.jit # type: ignore[misc]
14+
def filter_and_label_points_triton_kernel(
1415
# input
1516
points_ptr: torch.Tensor,
1617
num_points: int,
@@ -62,7 +63,7 @@ def filter_and_label_points_triton_kernel( # noqa: PLR0913, D417
6263
tl.store(point_voxel_indices_ptr + point_idx, flat_voxel_idx, mask=point_mask)
6364

6465

65-
def filter_and_label_points_torch( # noqa: PLR0913, D417
66+
def filter_and_label_points_torch(
6667
points: torch.Tensor,
6768
min_range: tuple[float, float, float],
6869
voxel_dim: tuple[float, float, float],
@@ -122,7 +123,6 @@ def voxelization_stable(
122123
contiguous with segment size specified in num_points_per_voxel.
123124
flat_voxel_indices, shape [num_filled_voxels].
124125
"""
125-
assert points.is_cuda
126126
device = points.device
127127
num_points, num_features_per_point = points.shape
128128

@@ -154,7 +154,7 @@ def voxelization_stable(
154154
param.max_num_voxels,
155155
point_voxel_indices,
156156
cxpr_block_size=block_size,
157-
num_warps=block_size // num_threads_per_warp, # pyright: ignore[reportCallIssue]
157+
num_warps=block_size // num_threads_per_warp,
158158
)
159159
else:
160160
filter_and_label_points_torch(
@@ -177,8 +177,8 @@ def voxelization_stable(
177177
return num_points_per_voxel.to(torch.int32), sorted_raw_indices, voxel_indices
178178

179179

180-
@triton.jit
181-
def collect_point_features_triton_kernel( # noqa: PLR0913, D417
180+
@triton.jit # type: ignore[misc]
181+
def collect_point_features_triton_kernel(
182182
# input
183183
points_ptr: torch.Tensor,
184184
num_features_per_point: int,
@@ -227,7 +227,7 @@ def collect_point_features_triton_kernel( # noqa: PLR0913, D417
227227
tl.store(point_features_ptr + output_idx * num_features_per_point + feature_idx, value, mask=voxel_mask)
228228

229229

230-
def collect_point_features_torch( # noqa: PLR0913, D417
230+
def collect_point_features_torch(
231231
points: torch.Tensor,
232232
num_points_per_voxel: torch.Tensor,
233233
segment_offsets: torch.Tensor,
@@ -287,7 +287,6 @@ def collect_point_features(
287287
filled with 0.
288288
capped_num_points_per_voxel: shape [num_filled_voxels], number of points in each voxel after max capping.
289289
"""
290-
assert points.is_cuda
291290
device = points.device
292291
num_points, num_features_per_point = points.shape
293292

@@ -319,7 +318,7 @@ def collect_point_features(
319318
point_features,
320319
capped_num_points_per_voxel,
321320
cxpr_block_size=block_size,
322-
num_warps=block_size // num_threads_per_warp, # pyright: ignore[reportCallIssue]
321+
num_warps=block_size // num_threads_per_warp,
323322
)
324323
else:
325324
collect_point_features_torch(

tests/voxelization_test.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
# Copyright 2025 Stack AV Co.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
# pyright: reportPrivateUsage=false
54
"""Test voxelization."""
65

6+
from typing import Final
7+
78
import pytest
89
import torch
910

1011
from conch.ops.vision.voxelization import VoxelizationParameter, generate_voxels
12+
from conch.platforms import current_platform
1113
from conch.reference.vision.voxelization import collect_point_features, voxelization_stable
14+
from conch.third_party.vllm.utils import seed_everything
1215

1316

1417
def voxel_coords_to_flat_indices(coords: torch.Tensor, grid_dim: tuple[int, int, int]) -> torch.Tensor:
@@ -23,18 +26,28 @@ def voxel_coords_to_flat_indices(coords: torch.Tensor, grid_dim: tuple[int, int,
2326

2427
# whether or not to use Triton for the reference Torch implementation
2528
@pytest.mark.parametrize("use_triton", [True, False])
26-
def test_voxelization(use_triton: bool) -> None:
29+
@pytest.mark.parametrize("num_points", [1000, 500000])
30+
@pytest.mark.parametrize("range_xyz", [50.0, 80.0])
31+
@pytest.mark.parametrize("voxel_dim", [2.5, 5.0])
32+
@pytest.mark.parametrize("max_num_points_per_voxel", [4, 10])
33+
def test_voxelization(
34+
use_triton: bool, num_points: int, range_xyz: float, voxel_dim: float, max_num_points_per_voxel: int
35+
) -> None:
2736
"""Test triton/pytorch voxelization."""
28-
num_points = 500000
37+
seed: Final = 0
38+
seed_everything(seed)
39+
40+
device: Final = torch.device(current_platform.device)
41+
torch.set_default_device(device)
42+
2943
num_features_per_point = 4
30-
range_xyz = 50.0
31-
points = torch.randn((num_points, num_features_per_point), device="cuda") * range_xyz
44+
points = torch.randn((num_points, num_features_per_point)) * range_xyz
3245

3346
param = VoxelizationParameter(
3447
min_range=(-range_xyz, -range_xyz, -range_xyz),
3548
max_range=(range_xyz, range_xyz, range_xyz),
36-
voxel_dim=(2.5, 2.5, 2.5),
37-
max_num_points_per_voxel=4,
49+
voxel_dim=(voxel_dim, voxel_dim, voxel_dim),
50+
max_num_points_per_voxel=max_num_points_per_voxel,
3851
)
3952

4053
print(f"Grid dimensions: {param.grid_dim}")

0 commit comments

Comments
 (0)