Skip to content

Commit 89a553e

Browse files
authored
fix: replica selection bias in fusedmoe router (#1638)
Signed-off-by: AlpinDale <alpindale@gmail.com>
1 parent a7bc451 commit 89a553e

1 file changed

Lines changed: 8 additions & 3 deletions

File tree

aphrodite/model_executor/layers/fused_moe/router/base_router.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@ def _eplb_map_and_record_i32_kernel(
2626
map_slots,
2727
out_size,
2828
numel,
29+
num_active_experts,
2930
BLOCK_SIZE: tl.constexpr,
3031
):
3132
pid = tl.program_id(0)
@@ -37,16 +38,18 @@ def _eplb_map_and_record_i32_kernel(
3738
safe_expert_id = tl.where(valid_expert, expert_id, 0)
3839

3940
# 1. Convert the logical expert ids to physical expert ids
40-
# Directly select a random replica for each logical expert
4141
replica_count = tl.load(
4242
logical_replica_count_ptr + safe_expert_id,
4343
mask=mask & valid_expert,
4444
other=1,
4545
)
4646
# Avoid invalid modulo/div by forcing at least 1.
4747
replica_count = tl.maximum(replica_count, 1)
48-
# Match torch.compile path: use flattened token position.
49-
replica_idx = offs % replica_count
48+
# floor(2^32 / phi), classic Knuth multiplicative hash multiplier.
49+
KNUTH_MULTIPLIER = 2654435769
50+
token_idx = (offs // num_active_experts).to(tl.int64)
51+
hashed = (token_idx * KNUTH_MULTIPLIER) & 0xFFFFFFFF
52+
replica_idx = hashed % replica_count
5053

5154
# 2. Record expert load metrics.
5255

@@ -85,6 +88,7 @@ def _eplb_map_and_record_triton(
8588
numel = topk_ids_in.numel()
8689
if numel == 0:
8790
return topk_ids
91+
num_active_experts = topk_ids_in.shape[-1]
8892
out_flat = torch.empty((numel,), device=topk_ids.device, dtype=topk_ids.dtype)
8993
grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)
9094
assert expert_load_view.is_contiguous()
@@ -99,6 +103,7 @@ def _eplb_map_and_record_triton(
99103
logical_to_physical_map.shape[1],
100104
expert_load_view.shape[0],
101105
numel,
106+
num_active_experts,
102107
BLOCK_SIZE=256,
103108
)
104109
return out_flat.reshape(topk_ids.shape)

0 commit comments

Comments
 (0)