Skip to content

Commit d7458d8

Browse files
author
Lenny Wang
committed
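Remove zero-initialization of the dense point-feature buffer (mask unused point slots in the kernel instead) to match the CUDA implementation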
1 parent cb8a847 commit d7458d8

3 files changed

Lines changed: 11 additions & 10 deletions

File tree

benchmarks/voxelization_benchmark.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
help="Number of input points",
2828
)
2929
@click.option(
30-
"--max_num_points_per_voxel",
30+
"--max-num-points-per-voxel",
3131
required=False,
3232
type=int,
3333
default=4,
@@ -41,7 +41,7 @@
4141
help="Voxel dimension same for x,y,z",
4242
)
4343
@click.option(
44-
"--grid_range",
44+
"--grid-range",
4545
required=False,
4646
type=float,
4747
default=50,
@@ -56,14 +56,14 @@
5656
"--iteration-time-ms",
5757
required=False,
5858
type=int,
59-
default=10000,
59+
default=100,
6060
help="Time in milliseconds to run benchmark",
6161
)
6262
@click.option(
6363
"--warmup-time-ms",
6464
required=False,
6565
type=int,
66-
default=100,
66+
default=10,
6767
help="Time in milliseconds to warmup before recording times",
6868
)
6969
@click.option(

conch/kernels/vision/voxelization.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -136,10 +136,11 @@ def generate_voxels_triton_kernel( # noqa: PLR0913, D417
136136
# store all feature points, including padded 0s
137137
for point_idx in range(0, max_num_points_per_voxel):
138138
input_idx = flat_voxel_idx * max_num_points_per_voxel + point_idx
139-
point_x = tl.load(dense_point_features_ptr + input_idx * 4 + 0, mask=valid_voxel)
140-
point_y = tl.load(dense_point_features_ptr + input_idx * 4 + 1, mask=valid_voxel)
141-
point_z = tl.load(dense_point_features_ptr + input_idx * 4 + 2, mask=valid_voxel)
142-
point_w = tl.load(dense_point_features_ptr + input_idx * 4 + 3, mask=valid_voxel)
139+
valid_point = (point_idx < num_points_in_voxel) and valid_voxel
140+
point_x = tl.load(dense_point_features_ptr + input_idx * 4 + 0, mask=valid_point, other=0)
141+
point_y = tl.load(dense_point_features_ptr + input_idx * 4 + 1, mask=valid_point, other=0)
142+
point_z = tl.load(dense_point_features_ptr + input_idx * 4 + 2, mask=valid_point, other=0)
143+
point_w = tl.load(dense_point_features_ptr + input_idx * 4 + 3, mask=valid_point, other=0)
143144

144145
output_idx = voxel_idx * max_num_points_per_voxel + point_idx
145146
tl.store(point_features_ptr + output_idx * 4 + 0, point_x, mask=valid_voxel)

conch/ops/vision/voxelization.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -68,9 +68,9 @@ def generate_voxels(
6868
# same as original nvidia cuda impl
6969
num_elements_per_voxel_index = 4
7070

71-
# dense (must set to 0s)
71+
# dense
7272
dense_num_points_per_voxel = torch.zeros((param.max_num_voxels), dtype=torch.int32, device=device)
73-
dense_point_features = torch.zeros(
73+
dense_point_features = torch.empty(
7474
(param.max_num_voxels, param.max_num_points_per_voxel, num_features_per_point), dtype=torch.float, device=device
7575
)
7676

0 commit comments

Comments
 (0)