From ee87de1d07b788290fe970e3c22f09d1b4001951 Mon Sep 17 00:00:00 2001 From: Mario Sieg Date: Fri, 22 May 2026 15:48:44 +0200 Subject: [PATCH] Integration finally done --- src/prime_rl/trainer/models/layers/moe.py | 129 ++++++++++++++++------ 1 file changed, 93 insertions(+), 36 deletions(-) diff --git a/src/prime_rl/trainer/models/layers/moe.py b/src/prime_rl/trainer/models/layers/moe.py index 69998c922d..fe942922dd 100644 --- a/src/prime_rl/trainer/models/layers/moe.py +++ b/src/prime_rl/trainer/models/layers/moe.py @@ -4,14 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass -from typing import Literal - +import prime_moe import torch import torch.nn.functional as F + +from dataclasses import dataclass +from typing import Literal from torch import nn from torchtitan.distributed.expert_parallel import expert_parallel - from prime_rl.configs.trainer import EPCommBackend @@ -84,6 +84,9 @@ def init_weights(self, init_std: float): nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std) nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std) +def _pack_gate_up_interleaved(w1: torch.Tensor, w3: torch.Tensor) -> torch.Tensor: + E, H, D = w1.shape + return torch.stack((w1, w3), dim=2).reshape(E, H<<1, D).contiguous() # TODO: keeping this for-loop implementation for comparison # and readability, may remove later @@ -121,6 +124,40 @@ def _run_experts_for_loop_impl( return out +def moe_align_block_size_simple(topk_ids: torch.Tensor, block_size_m: int, num_experts: int): + assert topk_ids.dim() == 2 + device = topk_ids.device + T, top_k = topk_ids.shape + pair_ids = torch.arange(T * top_k, device=device, dtype=torch.int32) + expert_flat = topk_ids.reshape(-1).to(torch.int32) + sort_idx = torch.argsort(expert_flat) + expert_sorted = expert_flat[sort_idx] + pair_sorted = pair_ids[sort_idx] + counts = torch.bincount(expert_sorted, minlength=num_experts) + padded_counts = ((counts + block_size_m - 1) // block_size_m) * block_size_m + total_padded = int(padded_counts.sum().item()) + sorted_token_ids = torch.empty((total_padded,), device=device, dtype=torch.int32) + num_blocks = int((padded_counts // block_size_m).sum().item()) + expert_ids = torch.empty((num_blocks,), device=device, dtype=torch.int32) + write_ptr = 0 + block_ptr = 0 + read_ptr = 0 + for e in range(num_experts): + c = int(counts[e].item()) + pc = int(padded_counts[e].item()) + if c > 0: + sorted_token_ids[write_ptr : write_ptr + c] = pair_sorted[read_ptr : read_ptr + c] + if pc > c: + sorted_token_ids[write_ptr + c : write_ptr + pc] = -1 + nb = pc // block_size_m + if nb > 0: + expert_ids[block_ptr : block_ptr + nb] = e + block_ptr += nb + write_ptr += pc + read_ptr += c + + num_tokens_post_padded = torch.tensor([total_padded], device=device, dtype=torch.int32) + return sorted_token_ids, expert_ids, num_tokens_post_padded @expert_parallel def _run_experts_for_loop( @@ -665,6 +702,48 @@ def run_pending_chunk(pending_state): routed_output = routed_outputs[0] if len(routed_outputs) == 1 else torch.cat(routed_outputs, dim=0) return routed_output if shared_output is None else shared_output + routed_output + def _run_fused_routed_experts( + self, + x: torch.Tensor, + selected_experts_indices: torch.Tensor, + top_scores: torch.Tensor, + ) -> torch.Tensor: + assert x.dim() == 2 + assert selected_experts_indices.dim() == 2 + assert top_scores.dim() == 2 + top_k = selected_experts_indices.shape[1] + block_m = 128 + bn = 32 + wn = 8 + stages = 1 + w1_packed = _pack_gate_up_interleaved( + self.experts.w1.bfloat16(), + self.experts.w3.bfloat16(), + ) + w2 = self.experts.w2.bfloat16() + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size_simple( + selected_experts_indices.to(torch.int32), + block_m, + self.experts.num_experts, + ) + out = torch.zeros_like(x, dtype=torch.bfloat16) + torch.ops.prime_moe.fused_moe_bf16( + x.bfloat16(), + w1_packed, + w2, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + top_scores.float(), + out, + top_k, + block_m, + bn, + wn, + stages, + ) + return out.type_as(x) + def forward( self, x: torch.Tensor, @@ -695,47 +774,25 @@ def forward( num_tokens_per_expert, ) = self.router(x, self.expert_bias, routed_experts=routed_experts) - # tokens_per_expert will be used to update the expert bias for load balancing. - # and also to count the expert usage - # Full block checkpointing can double count tokens_per_expert because it reruns the router - # in backward. The selective MoE path avoids that by checkpointing only the - # routed expert compute below. with torch.no_grad(): self.tokens_per_expert.add_(num_tokens_per_expert) - if self.ep_comm_backend == "deepep": routed_output = self._run_deepep_routed_experts(x, selected_experts_indices, top_scores) return routed_output.reshape(bs, slen, dim) - # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,) - # num_tokens_per_expert shape (num_experts,) - # NOTE: the reason we need to compute num_tokens_per_expert again is: - # 1st computation in router is to update self.tokens_per_expert - # which would be the same across all TP ranks. - # 2nd computation in reorderer is for the actual routing and experts computation - # which would be sharded over TP ranks if expert_tensor_parallel_degree==1. - # If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree. - ( - top_scores_experts_sorted, - token_indices_experts_sorted, - num_tokens_per_expert, - ) = self.reorderer(top_scores, selected_experts_indices) - - routed_output = self._run_routed_experts( + if self.score_before_experts: + raise RuntimeError( + "Fused kernel expects output weighting - Set score_before_experts=False or implement pre-score weighting inside the kernel" + ) + routed_output = self._run_fused_routed_experts( x, - token_indices_experts_sorted, - num_tokens_per_expert, - top_scores_experts_sorted, + selected_experts_indices, + top_scores, + ) if self.shared_expert is not None: - out = self.shared_expert(x) - else: - out = torch.zeros_like(x) - - routed_indices = token_indices_experts_sorted.reshape(-1, 1).expand(-1, dim) - out = out.scatter_add(dim=0, index=routed_indices, src=routed_output) - out = out.reshape(bs, slen, dim) - return out + routed_output = routed_output + self.shared_expert(x) + return routed_output.reshape(bs, slen, dim) def init_weights( self,