Skip to content

Commit f26c181

Browse files
Merge pull request #40 from stackav-oss/feature/jmanning/voxelization-cleanup-v2
Change voxelization interface to match CUDA for 1:1 benchmarking
2 parents f5142d8 + 40a074a commit f26c181

2 files changed

Lines changed: 13 additions & 12 deletions

File tree

conch/ops/vision/voxelization.py

Lines changed: 7 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -47,19 +47,19 @@ def _compute_grid_dim(self) -> tuple[int, int, int]:
4747

4848
def generate_voxels(
4949
points: torch.Tensor, param: VoxelizationParameter
50-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
50+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
5151
"""Generates voxels from input points, output voxels and points are randomly ordered due to use of atomics.
5252
5353
Args:
5454
points: input points; expected dimensions (num_points, 4), last dimension should have fields of x,y,z,_.
5555
param: parameters.
5656
5757
Returns:
58-
tuple of voxels:
59-
capped_num_points_per_voxel, shape [num_filled_voxels], note per voxel point counters are capped with max_num_points_per_voxel.
60-
point_features, shape [num_filled_voxels, max_num_points_per_voxel, num_features_per_point],
61-
empty points are filled with 0.
58+
tuple of tensors:
59+
num_of_filled_voxels, shape [1, cpu]
60+
point_features, shape [num_filled_voxels, max_num_points_per_voxel, num_features_per_point], empty points are filled with 0.
6261
voxel_indices, shape [num_filled_voxels, 4], only first 3 fields are used for x,y,z indices.
62+
capped_num_points_per_voxel, shape [num_filled_voxels], note per voxel point counters are capped with max_num_points_per_voxel.
6363
"""
6464
device = points.device
6565
num_points, num_features_per_point = points.shape
@@ -124,10 +124,6 @@ def generate_voxels(
124124
num_warps=block_size // num_threads_per_warp,
125125
)
126126

127-
total_filled_voxels = num_filled_voxels.cpu()[0]
127+
total_filled_voxels = num_filled_voxels.cpu()
128128

129-
return (
130-
num_points_per_voxel[:total_filled_voxels],
131-
point_features[:total_filled_voxels, :, :],
132-
voxel_indices[:total_filled_voxels, :],
133-
)
129+
return total_filled_voxels, point_features, voxel_indices, num_points_per_voxel

tests/voxelization_test.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,12 @@ def test_voxelization(
5555
print(f"Max number of points per voxel: {param.max_num_points_per_voxel}")
5656

5757
# pure triton version
58-
num_points_per_voxel, point_features, voxel_indices = generate_voxels(points, param)
58+
total_filled_voxels, point_features, voxel_indices, num_points_per_voxel = generate_voxels(points, param)
59+
60+
filled_voxels = total_filled_voxels[0]
61+
num_points_per_voxel = num_points_per_voxel[:filled_voxels]
62+
point_features = point_features[:filled_voxels, :, :]
63+
voxel_indices = voxel_indices[:filled_voxels, :]
5964

6065
# triton/torch hybrid, 2-step, stable voxelization first then generate a feature tensor the same format as above
6166
actual_num_points_per_voxel, point_raw_indices, flat_voxel_indices_stable = voxelization_stable(

0 commit comments

Comments
 (0)