@@ -67,13 +67,15 @@ def _fused_routing_from_topk_hist_kernel(
6767 # Clamp the offset for masked-out lanes to 0 so the pointer arithmetic
6868 # below stays within the allocated buffers.
6969 safe_item = tl .where (item_mask , item_offs , 0 )
70- global_expt = tl .load (topk_ids_ptr + safe_item , mask = item_mask , other = 0 ).to (tl .int32 )
70+ global_expt = tl .load (topk_ids_ptr + safe_item , mask = item_mask , other = 0 ).to (
71+ tl .int32
72+ )
7173 if HAS_EXPERT_MAP :
7274 map_mask = item_mask & (global_expt >= 0 ) & (global_expt < expert_map_numel )
7375 safe_global_expt = tl .where (map_mask , global_expt , 0 )
74- local_expt = tl .load (expert_map_ptr + safe_global_expt , mask = map_mask , other = - 1 ). to (
75- tl . int32
76- )
76+ local_expt = tl .load (
77+ expert_map_ptr + safe_global_expt , mask = map_mask , other = - 1
78+ ). to ( tl . int32 )
7779 # Match reference semantics: invalid experts are redirected to bucket 0
7880 # and later zeroed in gate_scal.
7981 expt = tl .where (local_expt >= 0 , local_expt , 0 )
@@ -141,14 +143,16 @@ def _fused_routing_from_topk_place_kernel(
141143 item_offs = tl .arange (0 , BLOCK_NK )
142144 item_mask = item_offs < NK
143145 safe_item = tl .where (item_mask , item_offs , 0 )
144- global_expt = tl .load (topk_ids_ptr + safe_item , mask = item_mask , other = 0 ).to (tl .int32 )
146+ global_expt = tl .load (topk_ids_ptr + safe_item , mask = item_mask , other = 0 ).to (
147+ tl .int32
148+ )
145149 weights = tl .load (topk_weights_ptr + safe_item , mask = item_mask , other = 0.0 )
146150 if HAS_EXPERT_MAP :
147151 map_mask = item_mask & (global_expt >= 0 ) & (global_expt < expert_map_numel )
148152 safe_global_expt = tl .where (map_mask , global_expt , 0 )
149- local_expt = tl .load (expert_map_ptr + safe_global_expt , mask = map_mask , other = - 1 ). to (
150- tl . int32
151- )
153+ local_expt = tl .load (
154+ expert_map_ptr + safe_global_expt , mask = map_mask , other = - 1
155+ ). to ( tl . int32 )
152156 invalid = local_expt < 0
153157 expt = tl .where (invalid , 0 , local_expt )
154158 weights = tl .where (invalid , 0.0 , weights )
0 commit comments