88import triton .language as tl
99
1010
@triton.jit
def filter_and_label_points_triton_kernel(  # noqa: PLR0913, D417
    # input
    points_ptr: torch.Tensor,
    num_points: int,
    num_features_per_point: int,
    # parameters
    min_x: float,
    min_y: float,
    min_z: float,
    max_x: float,
    max_y: float,
    max_z: float,
    voxel_dim_x: float,
    voxel_dim_y: float,
    voxel_dim_z: float,
    grid_dim_x: int,
    grid_dim_y: int,
    grid_dim_z: int,
    max_num_voxels: int,
    # output
    point_voxel_indices_ptr: torch.Tensor,
    # Constants
    cxpr_block_size: tl.constexpr,
) -> None:
    """Filter valid points and label each with a flattened voxel index.

    Each program instance handles ``cxpr_block_size`` consecutive points. A point is valid when
    none of its x/y/z coordinates is NaN and its voxel coordinates lie inside the grid. Valid
    points are labelled with ``(voxel_z * grid_dim_y + voxel_y) * grid_dim_x + voxel_x``; every
    other point with an index below ``num_points`` is labelled with ``max_num_voxels``.

    Args:
        points_ptr: input points, shape (num_points, num_features_per_point); features 0, 1 and 2
            of each point are read as x, y and z.
        num_points: number of points; lanes at or beyond this index are neither read nor written.
        num_features_per_point: stride, in elements, between consecutive points.
        min_x, min_y, min_z: lower corner of the voxel grid, subtracted before dividing.
        max_x, max_y, max_z: only used to build an out-of-grid fill value for masked-off loads.
        voxel_dim_x, voxel_dim_y, voxel_dim_z: voxel edge length along each axis.
        grid_dim_x, grid_dim_y, grid_dim_z: number of voxels along each axis.
        max_num_voxels: label written for points that are not valid.
        point_voxel_indices_ptr: output per point flattened voxel indices, shape (num_points).
        cxpr_block_size: number of points handled by one program instance.
    """
    block_idx = tl.program_id(axis=0)
    point_idx = block_idx * cxpr_block_size + tl.arange(0, cxpr_block_size)
    point_mask = point_idx < num_points

    # Masked-off lanes get a coordinate one voxel past the maximum so they can never look valid.
    point_base = points_ptr + point_idx * num_features_per_point
    point_x = tl.load(point_base + 0, mask=point_mask, other=max_x + voxel_dim_x)
    point_y = tl.load(point_base + 1, mask=point_mask, other=max_y + voxel_dim_y)
    point_z = tl.load(point_base + 2, mask=point_mask, other=max_z + voxel_dim_z)

    # NaN is the only value that compares unequal to itself. The check has to happen on the
    # floating-point values: converting NaN to int32 gives no meaningful voxel coordinate
    # (NOTE(review): NVIDIA hardware is documented to produce 0, which would pass the range
    # test below and label the point as belonging to a real voxel).
    not_nan = (point_x == point_x) & (point_y == point_y) & (point_z == point_z)

    voxel_x = tl.floor((point_x - min_x) / voxel_dim_x).to(tl.int32)
    voxel_y = tl.floor((point_y - min_y) / voxel_dim_y).to(tl.int32)
    voxel_z = tl.floor((point_z - min_z) / voxel_dim_z).to(tl.int32)

    valid_x = (voxel_x >= 0) & (voxel_x < grid_dim_x)
    valid_y = (voxel_y >= 0) & (voxel_y < grid_dim_y)
    valid_z = (voxel_z >= 0) & (voxel_z < grid_dim_z)
    valid_point = point_mask & not_nan & valid_x & valid_y & valid_z

    # x varies fastest, then y, then z; invalid points receive the max_num_voxels sentinel.
    flat_voxel_idx = tl.where(valid_point, ((voxel_z * grid_dim_y + voxel_y) * grid_dim_x + voxel_x), max_num_voxels)
    tl.store(point_voxel_indices_ptr + point_idx, flat_voxel_idx, mask=point_mask)
62-
63-
6411@triton .jit
6512def generate_dense_voxels_triton_kernel ( # noqa: PLR0913, D417
6613 # input
@@ -187,7 +134,7 @@ def generate_voxels_triton_kernel( # noqa: PLR0913, D417
187134 tl .store (voxel_indices_ptr + voxel_idx * 4 + 2 , voxel_z , mask = valid_voxel )
188135
189136 # store all feature points, including padded 0s
190- for point_idx in range (0 , max_num_points_per_voxel , 1 ):
137+ for point_idx in range (0 , max_num_points_per_voxel ):
191138 input_idx = flat_voxel_idx * max_num_points_per_voxel + point_idx
192139 point_x = tl .load (dense_point_features_ptr + input_idx * 4 + 0 , mask = valid_voxel )
193140 point_y = tl .load (dense_point_features_ptr + input_idx * 4 + 1 , mask = valid_voxel )
@@ -199,53 +146,3 @@ def generate_voxels_triton_kernel( # noqa: PLR0913, D417
199146 tl .store (point_features_ptr + output_idx * 4 + 1 , point_y , mask = valid_voxel )
200147 tl .store (point_features_ptr + output_idx * 4 + 2 , point_z , mask = valid_voxel )
201148 tl .store (point_features_ptr + output_idx * 4 + 3 , point_w , mask = valid_voxel )
202-
203-
@triton.jit
def collect_point_features_triton_kernel(  # noqa: PLR0913, D417
    # input
    points_ptr: torch.Tensor,
    num_features_per_point: int,
    segment_offsets_ptr: torch.Tensor,
    num_filled_voxels: int,
    point_indices_ptr: torch.Tensor,
    # parameters
    max_num_points_per_voxel: int,
    # output
    point_features_ptr: torch.Tensor,
    capped_num_points_per_voxel_ptr: torch.Tensor,
    # Constants
    cxpr_block_size: tl.constexpr,
) -> None:
    """Gather the features of each voxel's points into a dense per-voxel layout.

    Each program instance handles ``cxpr_block_size`` consecutive voxels. For every voxel the
    point count is the difference between its segment end offset and the previous voxel's (0 for
    the first voxel), capped at ``max_num_points_per_voxel``. Slots beyond the capped count are
    written as 0.

    Args:
        points_ptr: input points tensor, shape (num_points, num_features_per_point)
        num_features_per_point: number of features copied per point.
        segment_offsets_ptr: input segment end offsets, shape (num_filled_voxels)
        num_filled_voxels: number of voxels to process; lanes at or beyond it write nothing.
        point_indices_ptr: input raw point indices, shape (num_valid_points)
        max_num_points_per_voxel: number of point slots written per voxel.
        point_features_ptr: output voxel point features, shape
            (num_filled_voxels, max_num_points_per_voxel, num_features_per_point)
        capped_num_points_per_voxel_ptr: output number of points per voxel tensor after capping,
            shape (num_filled_voxels)
        cxpr_block_size: number of voxels handled by one program instance.
    """
    program = tl.program_id(axis=0)
    voxel_idx = program * cxpr_block_size + tl.arange(0, cxpr_block_size)
    in_bounds = voxel_idx < num_filled_voxels

    # A voxel's segment starts where the previous one ended; the first voxel starts at 0.
    has_previous = in_bounds & (voxel_idx > 0)
    seg_begin = tl.load(segment_offsets_ptr + voxel_idx - 1, mask=has_previous, other=0)
    seg_end = tl.load(segment_offsets_ptr + voxel_idx, mask=in_bounds, other=0)

    # Keep at most max_num_points_per_voxel points per voxel and report the capped count.
    kept_count = tl.minimum(seg_end - seg_begin, max_num_points_per_voxel)
    tl.store(capped_num_points_per_voxel_ptr + voxel_idx, kept_count, mask=in_bounds)

    for slot in range(0, max_num_points_per_voxel, 1):
        # Out-of-bounds lanes have kept_count == 0, so this mask alone excludes them from loads.
        slot_filled = slot < kept_count

        source_point = tl.load(point_indices_ptr + seg_begin + slot, mask=slot_filled)
        source_row = points_ptr + source_point * num_features_per_point
        target_row = point_features_ptr + (voxel_idx * max_num_points_per_voxel + slot) * num_features_per_point

        for feature in range(0, num_features_per_point, 1):
            # Empty slots read as 0 and are still stored, which zero-pads the voxel.
            feature_value = tl.load(source_row + feature, mask=slot_filled, other=0)
            tl.store(target_row + feature, feature_value, mask=in_bounds)
0 commit comments