Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,8 @@ enable_router_replay = true # this will also auto-set the inference.enable_retur
enable_return_routed_experts = true
```

During training, the router still computes gate scores so gradients flow through it. By default, router replay uses the inference-selected experts exactly. Set `trainer.router_replay_score_threshold_ratio` to a value greater than `0` to drop any replayed expert whose gate score falls below that fraction of the weakest gate score in the trainer router's own top-k for that token; the vacated slots are then filled with the trainer router's top-k candidates that are not already selected.

This, however, is not free: it adds significant overhead to the HTTP requests, as the payload can grow quite large. We recommend increasing `orchestrator.*.env.num_workers` to allow for more parallelization on the verifiers side.

Currently, this feature is also not supported with CPU KV cache offload, which can have a negative impact on inference throughput.
3 changes: 3 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,9 @@ class TrainerConfig(BaseConfig):
enable_router_replay: bool = False
"""Return routed experts in the batch so the trainer can replay routing. Requires ``enable_return_routed_experts=true`` on the vLLM server (or ``--enable-return-routed-experts``) and is only supported for custom models."""

router_replay_score_threshold_ratio: float = Field(0.0, ge=0.0, le=1.0)
"""When router replay is enabled, optionally accept a replayed expert only if its gate score is at least this fraction of the trainer router's weakest top-k gate score for that token. The default 0 disables plausibility filtering."""

memory_profiler_path: Path | None = None
"""Path to write the memory profile to."""

Expand Down
24 changes: 24 additions & 0 deletions src/prime_rl/trainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,30 @@ def apply_force_balanced_routing(model: nn.Module) -> None:
)


def configure_router_replay_filter(model: nn.Module, score_threshold_ratio: float) -> None:
    """Push the router-replay score threshold ratio onto every MoE router.

    Walks the layers of the language model, looks at each layer's ``mlp``
    (falling back to ``feed_forward``), and for every one that is an ``MoE``
    sets ``router.router_replay_score_threshold_ratio``. Logs a warning and
    changes nothing when no such router exists.

    Args:
        model: Model whose language-model layers are inspected.
        score_threshold_ratio: Value assigned to each router's
            ``router_replay_score_threshold_ratio`` attribute.
    """
    logger = get_logger()

    # Lazily resolve each layer's feed-forward module; ``mlp`` wins when it is
    # present and truthy, otherwise ``feed_forward`` is used.
    feed_forward_modules = (
        getattr(decoder_layer, "mlp", None) or getattr(decoder_layer, "feed_forward", None)
        for decoder_layer in get_language_model(model).layers
    )
    moe_routers = [module.router for module in feed_forward_modules if isinstance(module, MoE)]

    if not moe_routers:
        logger.warning("Router replay is enabled, but no token-choice MoE routers were found to configure.")
        return

    for moe_router in moe_routers:
        moe_router.router_replay_score_threshold_ratio = score_threshold_ratio

    logger.info(
        f"Configured router replay plausibility filtering on {len(moe_routers)} MoE layers "
        f"(score_threshold_ratio={score_threshold_ratio})."
    )


def is_tt_moe_model(model: nn.Module) -> bool:
    """Report whether ``model.config`` exposes ``num_experts`` or ``n_routed_experts``."""
    expert_count_fields = ("num_experts", "n_routed_experts")
    return any(hasattr(model.config, field_name) for field_name in expert_count_fields)

Expand Down
60 changes: 52 additions & 8 deletions src/prime_rl/trainer/models/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,43 @@ def __init__(
self.route_norm = route_norm
self.route_scale = route_scale
self.force_balanced = False
self.router_replay_score_threshold_ratio = 0.0

def _topk_from_scores(
self, scores: torch.Tensor, expert_bias: torch.Tensor | None
) -> tuple[torch.Tensor, torch.Tensor]:
if expert_bias is not None:
_, selected_experts_indices = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
top_scores = scores.gather(dim=1, index=selected_experts_indices)
else:
top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=1)
return top_scores, selected_experts_indices

def _filter_replayed_experts(
    self,
    scores: torch.Tensor,
    routed_experts: torch.Tensor,
    router_top_scores: torch.Tensor,
    router_selected_experts: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Drop implausible replayed experts and refill their slots from the router's picks.

    A replayed expert is kept when its gate score is at least
    ``self.router_replay_score_threshold_ratio`` times the smallest value in
    that row of ``router_top_scores``. Each rejected slot is then filled, left
    to right, with the next entry of ``router_selected_experts`` that is not
    already present in the row.

    Args:
        scores: Gate scores, indexed along dim 1 by expert id.
        routed_experts: Replayed expert ids per row (integer tensor).
        router_top_scores: Scores of the router's own top-k picks per row.
        router_selected_experts: Expert ids of the router's own top-k picks.

    Returns:
        ``(top_scores, selected_experts_indices)`` where ``top_scores`` is
        gathered from ``scores`` at the final expert ids.
    """
    weakest_router_score = router_top_scores.amin(dim=1, keepdim=True)
    min_plausible_score = weakest_router_score * self.router_replay_score_threshold_ratio
    is_plausible = scores.gather(dim=1, index=routed_experts) >= min_plausible_score

    # -1 marks a slot whose replayed expert was rejected and still needs filling.
    vacant_marker = torch.full_like(routed_experts, -1)
    chosen = torch.where(is_plausible, routed_experts, vacant_marker)

    for slot in range(self.top_k):
        fallback = router_selected_experts.select(1, slot).unsqueeze(1)
        already_chosen = (chosen == fallback).any(dim=1, keepdim=True)
        open_slots = (chosen < 0) & ~already_chosen
        # The running count equals 1 only at the left-most open slot of a row,
        # so each fallback expert is written at most once per row.
        leftmost_open = open_slots & (open_slots.cumsum(dim=1) == 1)
        chosen = torch.where(leftmost_open, fallback, chosen)

    return scores.gather(dim=1, index=chosen), chosen

def forward(
self, x: torch.Tensor, expert_bias: torch.Tensor | None = None, routed_experts: torch.Tensor | None = None
Expand Down Expand Up @@ -468,18 +505,25 @@ def forward(
# top_scores is still derived from the original scores.

if routed_experts is not None:
top_scores = scores.gather(dim=1, index=routed_experts)
selected_experts_indices = routed_experts
routed_experts = routed_experts.to(torch.long)
if self.router_replay_score_threshold_ratio > 0:
router_top_scores, router_selected_experts = self._topk_from_scores(scores, expert_bias)
top_scores, selected_experts_indices = self._filter_replayed_experts(
scores,
routed_experts,
router_top_scores,
router_selected_experts,
)
else:
top_scores = scores.gather(dim=1, index=routed_experts)
selected_experts_indices = routed_experts
elif self.force_balanced:
num_tokens = scores.shape[0]
arange = torch.arange(num_tokens * self.top_k, device=scores.device)
selected_experts_indices = (arange % self.num_experts).view(num_tokens, self.top_k)
top_scores = scores.gather(dim=1, index=selected_experts_indices)
elif expert_bias is not None:
_, selected_experts_indices = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
top_scores = scores.gather(dim=1, index=selected_experts_indices)
else:
top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=1)
top_scores, selected_experts_indices = self._topk_from_scores(scores, expert_bias)

routing_confidence_sum = _selected_probability_mass_sum(scores, top_scores, self.score_func)

Expand All @@ -490,7 +534,7 @@ def forward(

# group tokens together by expert indices from 0 to num_experts and pass that to experts forward
num_tokens_per_expert = torch.histc(
selected_experts_indices.reshape(-1),
selected_experts_indices.reshape(-1).float(),
bins=self.num_experts,
min=0,
max=self.num_experts,
Expand Down Expand Up @@ -542,7 +586,7 @@ def forward(
# group tokens together by expert indices from 0 to num_experts and pass that to experts forward
selected_experts_indices = selected_experts_indices.reshape(-1)
num_tokens_per_expert = torch.histc(
selected_experts_indices,
selected_experts_indices.float(),
bins=self.num_experts,
min=0,
max=self.num_experts,
Expand Down
3 changes: 3 additions & 0 deletions src/prime_rl/trainer/rl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
setup_model,
is_tt_moe_model,
get_load_balance_stats,
configure_router_replay_filter,
)
from prime_rl.trainer.parallel_dims import get_parallel_dims
from prime_rl.trainer.perf import get_perf_counter
Expand Down Expand Up @@ -145,6 +146,8 @@ def train(config: TrainerConfig):
logger.info(f"Initializing model ({config.model})")
loading_from_ckpt_later = config.ckpt and checkpoint_step is not None
model = setup_model(config.model, parallel_dims, loading_from_ckpt_later)
if config.enable_router_replay:
configure_router_replay_filter(model, config.router_replay_score_threshold_ratio)

logger.info(f"Initializing tokenizer ({config.tokenizer})")
tokenizer = setup_tokenizer(config.tokenizer)
Expand Down
Loading