From 7e7f36fadeee80e51021c2dabca5ac606ea9988e Mon Sep 17 00:00:00 2001 From: faresobeid Date: Thu, 4 Jun 2026 05:50:45 +0530 Subject: [PATCH] extend router replay --- docs/inference.md | 2 + .../src/prime_rl/configs/trainer.py | 3 + src/prime_rl/trainer/model.py | 24 ++++++++ src/prime_rl/trainer/models/layers/moe.py | 60 ++++++++++++++++--- src/prime_rl/trainer/rl/train.py | 3 + 5 files changed, 84 insertions(+), 8 deletions(-) diff --git a/docs/inference.md b/docs/inference.md index 0552c8f245..b7d4439d8a 100644 --- a/docs/inference.md +++ b/docs/inference.md @@ -264,6 +264,8 @@ enable_router_replay = true # this will also auto-set the inference.enable_retur enable_return_routed_experts = true ``` +During training, the router still computes gate scores so gradients flow through it. By default, router replay uses the inference-selected experts exactly. Set `trainer.router_replay_score_threshold_ratio` to a value greater than `0` to drop any replayed expert whose gate score falls below that fraction of the weakest gate score in the trainer router's own top-k for that token; each vacated slot is then filled with an expert from the trainer router's own top-k that is not already selected for that token. + This however is not free, it adds a significant overhead to the HTTP requests as this payload can grow quite large. We reccomend increasing `orchestrator.*.env.num_workers` to allow for more parallelization on the verifiers side. Currently this feature is also not supported with CPU KV cache offload, which can have negative impact on the inference throughput. diff --git a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py index 00f4e07deb..ffb66f3ad2 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py @@ -538,6 +538,9 @@ class TrainerConfig(BaseConfig): enable_router_replay: bool = False """Return routed experts in the batch so the trainer can replay routing. 
Requires ``enable_return_routed_experts=true`` on the vLLM server (or ``--enable-return-routed-experts``) and is only supported for custom models.""" + router_replay_score_threshold_ratio: float = Field(0.0, ge=0.0, le=1.0) + """When router replay is enabled, optionally accept a replayed expert only if its gate score is at least this fraction of the trainer router's weakest top-k gate score for that token. The default 0 disables plausibility filtering.""" + memory_profiler_path: Path | None = None """Path to write the memory profile to.""" diff --git a/src/prime_rl/trainer/model.py b/src/prime_rl/trainer/model.py index d169aeedbc..b08ce1fca7 100644 --- a/src/prime_rl/trainer/model.py +++ b/src/prime_rl/trainer/model.py @@ -371,6 +371,30 @@ def apply_force_balanced_routing(model: nn.Module) -> None: ) +def configure_router_replay_filter(model: nn.Module, score_threshold_ratio: float) -> None: + """Configure trainer-side plausibility filtering for replayed MoE routes.""" + logger = get_logger() + language_model = get_language_model(model) + routers = [] + + for layer in language_model.layers: + mlp = getattr(layer, "mlp", None) or getattr(layer, "feed_forward", None) + if isinstance(mlp, MoE): + routers.append(mlp.router) + + if not routers: + logger.warning("Router replay is enabled, but no token-choice MoE routers were found to configure.") + return + + for router in routers: + router.router_replay_score_threshold_ratio = score_threshold_ratio + + logger.info( + f"Configured router replay plausibility filtering on {len(routers)} MoE layers " + f"(score_threshold_ratio={score_threshold_ratio})." 
+ ) + + def is_tt_moe_model(model: nn.Module) -> bool: return hasattr(model.config, "num_experts") or hasattr(model.config, "n_routed_experts") diff --git a/src/prime_rl/trainer/models/layers/moe.py b/src/prime_rl/trainer/models/layers/moe.py index 14d46b2f89..461275866e 100644 --- a/src/prime_rl/trainer/models/layers/moe.py +++ b/src/prime_rl/trainer/models/layers/moe.py @@ -427,6 +427,43 @@ def __init__( self.route_norm = route_norm self.route_scale = route_scale self.force_balanced = False + self.router_replay_score_threshold_ratio = 0.0 + + def _topk_from_scores( + self, scores: torch.Tensor, expert_bias: torch.Tensor | None + ) -> tuple[torch.Tensor, torch.Tensor]: + if expert_bias is not None: + _, selected_experts_indices = torch.topk(scores + expert_bias, k=self.top_k, dim=1) + top_scores = scores.gather(dim=1, index=selected_experts_indices) + else: + top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=1) + return top_scores, selected_experts_indices + + def _filter_replayed_experts( + self, + scores: torch.Tensor, + routed_experts: torch.Tensor, + router_top_scores: torch.Tensor, + router_selected_experts: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + replayed_scores = scores.gather(dim=1, index=routed_experts) + threshold = router_top_scores.amin(dim=1, keepdim=True) * self.router_replay_score_threshold_ratio + keep_replayed = replayed_scores >= threshold + + selected_experts_indices = torch.where(keep_replayed, routed_experts, torch.full_like(routed_experts, -1)) + for candidate_slot in range(self.top_k): + candidate = router_selected_experts[:, candidate_slot] + candidate_unused = (selected_experts_indices != candidate.unsqueeze(1)).all(dim=1) + fillable_slots = (selected_experts_indices < 0) & candidate_unused.unsqueeze(1) + first_fillable_slot = fillable_slots & (fillable_slots.cumsum(dim=1) == 1) + selected_experts_indices = torch.where( + first_fillable_slot, + candidate.unsqueeze(1), + selected_experts_indices, 
+ ) + + top_scores = scores.gather(dim=1, index=selected_experts_indices) + return top_scores, selected_experts_indices def forward( self, x: torch.Tensor, expert_bias: torch.Tensor | None = None, routed_experts: torch.Tensor | None = None @@ -468,18 +505,25 @@ def forward( # top_scores is still derived from the original scores. if routed_experts is not None: - top_scores = scores.gather(dim=1, index=routed_experts) - selected_experts_indices = routed_experts + routed_experts = routed_experts.to(torch.long) + if self.router_replay_score_threshold_ratio > 0: + router_top_scores, router_selected_experts = self._topk_from_scores(scores, expert_bias) + top_scores, selected_experts_indices = self._filter_replayed_experts( + scores, + routed_experts, + router_top_scores, + router_selected_experts, + ) + else: + top_scores = scores.gather(dim=1, index=routed_experts) + selected_experts_indices = routed_experts elif self.force_balanced: num_tokens = scores.shape[0] arange = torch.arange(num_tokens * self.top_k, device=scores.device) selected_experts_indices = (arange % self.num_experts).view(num_tokens, self.top_k) top_scores = scores.gather(dim=1, index=selected_experts_indices) - elif expert_bias is not None: - _, selected_experts_indices = torch.topk(scores + expert_bias, k=self.top_k, dim=1) - top_scores = scores.gather(dim=1, index=selected_experts_indices) else: - top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=1) + top_scores, selected_experts_indices = self._topk_from_scores(scores, expert_bias) routing_confidence_sum = _selected_probability_mass_sum(scores, top_scores, self.score_func) @@ -490,7 +534,7 @@ def forward( # group tokens together by expert indices from 0 to num_experts and pass that to experts forward num_tokens_per_expert = torch.histc( - selected_experts_indices.reshape(-1), + selected_experts_indices.reshape(-1).float(), bins=self.num_experts, min=0, max=self.num_experts, @@ -542,7 +586,7 @@ def forward( # group tokens 
together by expert indices from 0 to num_experts and pass that to experts forward selected_experts_indices = selected_experts_indices.reshape(-1) num_tokens_per_expert = torch.histc( - selected_experts_indices, + selected_experts_indices.float(), bins=self.num_experts, min=0, max=self.num_experts, diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index 83afa666dc..434df54df6 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -42,6 +42,7 @@ setup_model, is_tt_moe_model, get_load_balance_stats, + configure_router_replay_filter, ) from prime_rl.trainer.parallel_dims import get_parallel_dims from prime_rl.trainer.perf import get_perf_counter @@ -145,6 +146,8 @@ def train(config: TrainerConfig): logger.info(f"Initializing model ({config.model})") loading_from_ckpt_later = config.ckpt and checkpoint_step is not None model = setup_model(config.model, parallel_dims, loading_from_ckpt_later) + if config.enable_router_replay: + configure_router_replay_filter(model, config.router_replay_score_threshold_ratio) logger.info(f"Initializing tokenizer ({config.tokenizer})") tokenizer = setup_tokenizer(config.tokenizer)