Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 93 additions & 35 deletions src/prime_rl/trainer/models/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Literal

import prime_moe

Check failure on line 7 in src/prime_rl/trainer/models/layers/moe.py

View workflow job for this annotation

GitHub Actions / Ruff

Ruff (F401)

src/prime_rl/trainer/models/layers/moe.py:7:8: F401 `prime_moe` imported but unused
import torch
import torch.nn.functional as F

from dataclasses import dataclass
from typing import Literal
from torch import nn
from torchtitan.distributed.expert_parallel import expert_parallel

from prime_rl.configs.trainer import EPCommBackend

Check failure on line 15 in src/prime_rl/trainer/models/layers/moe.py

View workflow job for this annotation

GitHub Actions / Ruff

Ruff (I001)

src/prime_rl/trainer/models/layers/moe.py:7:1: I001 Import block is un-sorted or un-formatted


@dataclass
Expand Down Expand Up @@ -85,6 +85,9 @@
nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)

def _pack_gate_up_interleaved(w1: torch.Tensor, w3: torch.Tensor) -> torch.Tensor:
E, H, D = w1.shape
return torch.stack((w1, w3), dim=2).reshape(E, H<<1, D).contiguous()

# TODO: keeping this for-loop implementation for comparison
# and readability, may remove later
Expand Down Expand Up @@ -122,6 +125,40 @@

return out

def moe_align_block_size_simple(topk_ids: torch.Tensor, block_size_m: int, num_experts: int):
assert topk_ids.dim() == 2
device = topk_ids.device
T, top_k = topk_ids.shape
pair_ids = torch.arange(T * top_k, device=device, dtype=torch.int32)
expert_flat = topk_ids.reshape(-1).to(torch.int32)
sort_idx = torch.argsort(expert_flat)
expert_sorted = expert_flat[sort_idx]
pair_sorted = pair_ids[sort_idx]
counts = torch.bincount(expert_sorted, minlength=num_experts)
padded_counts = ((counts + block_size_m - 1) // block_size_m) * block_size_m
total_padded = int(padded_counts.sum().item())
sorted_token_ids = torch.empty((total_padded,), device=device, dtype=torch.int32)
num_blocks = int((padded_counts // block_size_m).sum().item())
expert_ids = torch.empty((num_blocks,), device=device, dtype=torch.int32)
write_ptr = 0
block_ptr = 0
read_ptr = 0
for e in range(num_experts):
c = int(counts[e].item())
pc = int(padded_counts[e].item())
if c > 0:
sorted_token_ids[write_ptr : write_ptr + c] = pair_sorted[read_ptr : read_ptr + c]
if pc > c:
sorted_token_ids[write_ptr + c : write_ptr + pc] = -1
nb = pc // block_size_m
if nb > 0:
expert_ids[block_ptr : block_ptr + nb] = e
block_ptr += nb
write_ptr += pc
read_ptr += c

num_tokens_post_padded = torch.tensor([total_padded], device=device, dtype=torch.int32)
return sorted_token_ids, expert_ids, num_tokens_post_padded

@expert_parallel
def _run_experts_for_loop(
Expand Down Expand Up @@ -705,6 +742,48 @@
routed_output = routed_outputs[0] if len(routed_outputs) == 1 else torch.cat(routed_outputs, dim=0)
return routed_output if shared_output is None else shared_output + routed_output

def _run_fused_routed_experts(
self,
x: torch.Tensor,
selected_experts_indices: torch.Tensor,
top_scores: torch.Tensor,
) -> torch.Tensor:
assert x.dim() == 2
assert selected_experts_indices.dim() == 2
assert top_scores.dim() == 2
top_k = selected_experts_indices.shape[1]
block_m = 128
bn = 32
wn = 8
stages = 1
w1_packed = _pack_gate_up_interleaved(
self.experts.w1.bfloat16(),
self.experts.w3.bfloat16(),
)
w2 = self.experts.w2.bfloat16()
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size_simple(
selected_experts_indices.to(torch.int32),
block_m,
self.experts.num_experts,
)
out = torch.zeros_like(x, dtype=torch.bfloat16)
torch.ops.prime_moe.fused_moe_bf16(
x.bfloat16(),
w1_packed,
w2,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
top_scores.float(),
out,
top_k,
block_m,
bn,
wn,
stages,
)
return out.type_as(x)

def forward(
self,
x: torch.Tensor,
Expand Down Expand Up @@ -736,11 +815,6 @@
routing_confidence_sum,
) = self.router(x, self.expert_bias, routed_experts=routed_experts)

# tokens_per_expert will be used to update the expert bias for load balancing.
# and also to count the expert usage
# Full block checkpointing can double count tokens_per_expert because it reruns the router
# in backward. The selective MoE path avoids that by checkpointing only the
# routed expert compute below.
with torch.no_grad():
self.tokens_per_expert.add_(num_tokens_per_expert)
self.routing_confidence_sum.add_(routing_confidence_sum)
Expand All @@ -749,35 +823,19 @@
routed_output = self._run_deepep_routed_experts(x, selected_experts_indices, top_scores)
return routed_output.reshape(bs, slen, dim)

# top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
# num_tokens_per_expert shape (num_experts,)
# NOTE: the reason we need to compute num_tokens_per_expert again is:
# 1st computation in router is to update self.tokens_per_expert
# which would be the same across all TP ranks.
# 2nd computation in reorderer is for the actual routing and experts computation
# which would be sharded over TP ranks if expert_tensor_parallel_degree==1.
# If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree.
(
top_scores_experts_sorted,
token_indices_experts_sorted,
num_tokens_per_expert,
) = self.reorderer(top_scores, selected_experts_indices)

routed_output = self._run_routed_experts(
if self.score_before_experts:
raise RuntimeError(
"Fused kernel expects output weighting - Set score_before_experts=False or implement pre-score weighting inside the kernel"
)
routed_output = self._run_fused_routed_experts(
x,
token_indices_experts_sorted,
num_tokens_per_expert,
top_scores_experts_sorted,
selected_experts_indices,
top_scores,

)
if self.shared_expert is not None:
out = self.shared_expert(x)
else:
out = torch.zeros_like(x)

routed_indices = token_indices_experts_sorted.reshape(-1, 1).expand(-1, dim)
out = out.scatter_add(dim=0, index=routed_indices, src=routed_output)
out = out.reshape(bs, slen, dim)
return out
routed_output = routed_output + self.shared_expert(x)
return routed_output.reshape(bs, slen, dim)

def init_weights(
self,
Expand Down
Loading