99
1010from conch .ops .vision .voxelization import VoxelizationParameter
1111
12- @triton .jit
13- def filter_and_label_points_triton_kernel ( # noqa: PLR0913, D417
12+
13+ @triton .jit # type: ignore[misc]
14+ def filter_and_label_points_triton_kernel (
1415 # input
1516 points_ptr : torch .Tensor ,
1617 num_points : int ,
@@ -62,7 +63,7 @@ def filter_and_label_points_triton_kernel( # noqa: PLR0913, D417
6263 tl .store (point_voxel_indices_ptr + point_idx , flat_voxel_idx , mask = point_mask )
6364
6465
65- def filter_and_label_points_torch ( # noqa: PLR0913, D417
66+ def filter_and_label_points_torch (
6667 points : torch .Tensor ,
6768 min_range : tuple [float , float , float ],
6869 voxel_dim : tuple [float , float , float ],
@@ -122,7 +123,6 @@ def voxelization_stable(
122123 contiguous with segment size specified in num_points_per_voxel.
123124 flat_voxel_indices, shape [num_filled_voxels].
124125 """
125- assert points .is_cuda
126126 device = points .device
127127 num_points , num_features_per_point = points .shape
128128
@@ -154,7 +154,7 @@ def voxelization_stable(
154154 param .max_num_voxels ,
155155 point_voxel_indices ,
156156 cxpr_block_size = block_size ,
157- num_warps = block_size // num_threads_per_warp , # pyright: ignore[reportCallIssue]
157+ num_warps = block_size // num_threads_per_warp ,
158158 )
159159 else :
160160 filter_and_label_points_torch (
@@ -177,8 +177,8 @@ def voxelization_stable(
177177 return num_points_per_voxel .to (torch .int32 ), sorted_raw_indices , voxel_indices
178178
179179
180- @triton .jit
181- def collect_point_features_triton_kernel ( # noqa: PLR0913, D417
180+ @triton .jit # type: ignore[misc]
181+ def collect_point_features_triton_kernel (
182182 # input
183183 points_ptr : torch .Tensor ,
184184 num_features_per_point : int ,
@@ -227,7 +227,7 @@ def collect_point_features_triton_kernel( # noqa: PLR0913, D417
227227 tl .store (point_features_ptr + output_idx * num_features_per_point + feature_idx , value , mask = voxel_mask )
228228
229229
230- def collect_point_features_torch ( # noqa: PLR0913, D417
230+ def collect_point_features_torch (
231231 points : torch .Tensor ,
232232 num_points_per_voxel : torch .Tensor ,
233233 segment_offsets : torch .Tensor ,
@@ -287,7 +287,6 @@ def collect_point_features(
287287 filled with 0.
288288 capped_num_points_per_voxel: shape [num_filled_voxels], number of points in each voxel after max capping.
289289 """
290- assert points .is_cuda
291290 device = points .device
292291 num_points , num_features_per_point = points .shape
293292
@@ -319,7 +318,7 @@ def collect_point_features(
319318 point_features ,
320319 capped_num_points_per_voxel ,
321320 cxpr_block_size = block_size ,
322- num_warps = block_size // num_threads_per_warp , # pyright: ignore[reportCallIssue]
321+ num_warps = block_size // num_threads_per_warp ,
323322 )
324323 else :
325324 collect_point_features_torch (
0 commit comments