@@ -26,6 +26,7 @@ def _eplb_map_and_record_i32_kernel(
2626 map_slots ,
2727 out_size ,
2828 numel ,
29+ num_active_experts ,
2930 BLOCK_SIZE : tl .constexpr ,
3031 ):
3132 pid = tl .program_id (0 )
@@ -37,16 +38,18 @@ def _eplb_map_and_record_i32_kernel(
3738 safe_expert_id = tl .where (valid_expert , expert_id , 0 )
3839
3940 # 1. Convert the logical expert ids to physical expert ids
40- # Directly select a random replica for each logical expert
4141 replica_count = tl .load (
4242 logical_replica_count_ptr + safe_expert_id ,
4343 mask = mask & valid_expert ,
4444 other = 1 ,
4545 )
4646 # Avoid invalid modulo/div by forcing at least 1.
4747 replica_count = tl .maximum (replica_count , 1 )
48- # Match torch.compile path: use flattened token position.
49- replica_idx = offs % replica_count
48+ # floor(2^32 / phi), classic Knuth multiplicative hash multiplier.
49+ KNUTH_MULTIPLIER = 2654435769
50+ token_idx = (offs // num_active_experts ).to (tl .int64 )
51+ hashed = (token_idx * KNUTH_MULTIPLIER ) & 0xFFFFFFFF
52+ replica_idx = hashed % replica_count
5053
5154 # 2. Record expert load metrics.
5255
@@ -85,6 +88,7 @@ def _eplb_map_and_record_triton(
8588 numel = topk_ids_in .numel ()
8689 if numel == 0 :
8790 return topk_ids
91+ num_active_experts = topk_ids_in .shape [- 1 ]
8892 out_flat = torch .empty ((numel ,), device = topk_ids .device , dtype = topk_ids .dtype )
8993 grid = lambda meta : (triton .cdiv (numel , meta ["BLOCK_SIZE" ]),)
9094 assert expert_load_view .is_contiguous ()
@@ -99,6 +103,7 @@ def _eplb_map_and_record_triton(
99103 logical_to_physical_map .shape [1 ],
100104 expert_load_view .shape [0 ],
101105 numel ,
106+ num_active_experts ,
102107 BLOCK_SIZE = 256 ,
103108 )
104109 return out_flat .reshape (topk_ids .shape )
0 commit comments