Skip to content

Commit ca0c282

Browse files
committed
draft
1 parent b1094e7 commit ca0c282

File tree

20 files changed

+152
-218
lines changed

20 files changed

+152
-218
lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 12 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111

1212
from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
1313
from lightllm.common.basemodel.infer_struct import InferStateInfo
14-
from lightllm.common.basemodel.routing_manager import (
15-
create_routing_capture_manager,
16-
reset_moe_layer_counter,
17-
get_moe_layer_count,
18-
)
14+
from lightllm.common.basemodel.routing_manager import reset_moe_layer_counter
1915
from lightllm.common.kv_cache_mem_manager import MemoryManager
2016
from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
2117
from lightllm.common.req_manager import ReqManager
@@ -282,45 +278,16 @@ def _init_prefill_cuda_graph(self):
282278
self.prefill_graph.warmup(self)
283279

284280
def _init_custom(self):
    """Model-specific initialization hook; subclasses override as needed."""
    return None
303283

304-
topk = self.config.get("num_experts_per_tok", 1)
305-
num_experts = n_routed_experts
284+
def _post_forward(self, model_input: ModelInput, microbatch_index: int = 0) -> None:
    """Post-forward hook; subclasses override to run per-batch post-processing."""
    return None
306287

307-
# Check if overlap mode is enabled
308-
enable_overlap = getattr(self.args, "enable_decode_microbatch_overlap", False)
309-
310-
logger.info(
311-
f"Initializing routing capture: num_moe_layers={num_moe_layers}, "
312-
f"topk={topk}, num_experts={num_experts}, enable_overlap={enable_overlap}"
313-
)
314-
315-
create_routing_capture_manager(
316-
num_moe_layers=num_moe_layers,
317-
topk=topk,
318-
num_experts=num_experts,
319-
batch_max_tokens=self.max_total_token_num,
320-
kv_cache_size=self.mem_manager.size,
321-
enable_overlap=enable_overlap,
322-
)
323-
return
288+
def _post_forward_dual(self, model_input0: ModelInput, model_input1: ModelInput) -> None:
    """Post-forward hook for dual-microbatch overlap; subclasses override."""
    return None
324291

325292
@torch.no_grad()
326293
def forward(self, model_input: ModelInput):
@@ -332,7 +299,7 @@ def forward(self, model_input: ModelInput):
332299
else:
333300
result = self._decode(model_input)
334301

335-
# Note: flush is now handled by backend layer (ChunkedPrefill, DP, etc.)
302+
self._post_forward(model_input)
336303
return result
337304

338305
def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0):
@@ -726,6 +693,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
726693
dist_group_manager.clear_deepep_buffer()
727694
model_output0.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
728695
model_output1.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
696+
self._post_forward_dual(model_input0, model_input1)
729697
return model_output0, model_output1
730698

731699
@torch.no_grad()
@@ -819,6 +787,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
819787
infer_state1.init_att_state()
820788

821789
model_output0, model_output1 = self._overlap_tpsp_token_forward(infer_state0, infer_state1=infer_state1)
790+
self._post_forward_dual(model_input0, model_input1)
822791
return model_output0, model_output1
823792

824793
@final
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""Mixin for MoE (Mixture of Experts) models.
2+
3+
Provides R3 (Rollout Router Replay) routing capture functionality for MoE models.
4+
MoE models that want R3 support should inherit from this mixin and call
5+
`_init_routing_capture()` in their `_init_custom()` method.
6+
"""
7+
8+
from lightllm.common.basemodel.batch_objs import ModelInput
9+
from lightllm.common.basemodel.routing_manager import (
10+
create_routing_capture_manager,
11+
get_moe_layer_count,
12+
flush_routing_capture,
13+
flush_routing_capture_dual,
14+
)
15+
from lightllm.utils.log_utils import init_logger
16+
17+
logger = init_logger(__name__)
18+
19+
20+
class MoeModelMixin:
    """Mixin adding R3 (Rollout Router Replay) routing capture to MoE models.

    Usage:
        class MyMoeModel(MoeModelMixin, LlamaTpPartModel):
            def _init_custom(self):
                super()._init_custom()
                self._init_routing_capture()  # Enable R3 if flag is set
    """

    def _init_routing_capture(self) -> None:
        """Enable R3 routing capture when --enable_return_routed_experts is set.

        Call from the model's _init_custom() once weights are loaded, since the
        MoE layer counter is populated during weight initialization.
        NOTE(review): earlier docs called this idempotent, but nothing visible
        here guards against a second create_routing_capture_manager() call —
        confirm before invoking more than once.
        """
        if not getattr(self.args, "enable_return_routed_experts", False):
            return

        # The counter is filled while weights load; zero means no MoE layer was seen.
        moe_layer_count = get_moe_layer_count()
        if moe_layer_count == 0:
            logger.warning(
                "enable_return_routed_experts is set but no MoE layers found. Routing capture will not be enabled."
            )
            return

        # Model families store the expert count under different config keys.
        expert_count = self.config.get("n_routed_experts", self.config.get("num_experts", 0))
        if expert_count == 0:
            logger.warning(
                "enable_return_routed_experts is set but n_routed_experts=0. Routing capture will not be enabled."
            )
            return

        experts_per_token = self.config.get("num_experts_per_tok", 1)
        overlap_enabled = getattr(self.args, "enable_decode_microbatch_overlap", False)

        logger.info(
            f"Initializing routing capture: num_moe_layers={moe_layer_count}, "
            f"topk={experts_per_token}, num_experts={expert_count}, enable_overlap={overlap_enabled}"
        )

        create_routing_capture_manager(
            num_moe_layers=moe_layer_count,
            topk=experts_per_token,
            num_experts=expert_count,
            batch_max_tokens=self.max_total_token_num,
            kv_cache_size=self.mem_manager.size,
            enable_overlap=overlap_enabled,
        )

    def _post_forward(self, model_input: ModelInput, microbatch_index: int = 0) -> None:
        """Flush captured routing data (GPU -> CPU buffer) after a forward pass.

        No-op when R3 capture is not enabled.
        """
        flush_routing_capture(model_input.mem_indexes, microbatch_index)

    def _post_forward_dual(self, model_input0: ModelInput, model_input1: ModelInput) -> None:
        """Flush captured routing data for both microbatches after an
        overlapped forward pass. No-op when R3 capture is not enabled.
        """
        flush_routing_capture_dual(model_input0.mem_indexes, model_input1.mem_indexes)

lightllm/common/basemodel/routing_manager.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,31 @@ def create_routing_capture_manager(
182182
def get_routing_capture_manager() -> Optional[RoutingCaptureManager]:
183183
"""Get the global routing capture manager."""
184184
return g_routing_capture_manager
185+
186+
187+
def flush_routing_capture(mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None:
    """Flush captured routing data to the CPU buffer when capture is active.

    Call after a forward pass completes; does nothing if R3 capture is off.

    Args:
        mem_indexes: KV cache slot indices for the batch
        microbatch_index: Microbatch index (0 for single batch, 0/1 for overlap)
    """
    manager = g_routing_capture_manager
    if manager is None:
        return
    manager.flush_to_cpu_async(mem_indexes, microbatch_index)
198+
199+
200+
def flush_routing_capture_dual(mem_indexes0: torch.Tensor, mem_indexes1: torch.Tensor) -> None:
    """Flush captured routing data for dual-microbatch overlap mode.

    Call after the forward pass completes for both microbatches;
    does nothing if R3 capture is off.

    Args:
        mem_indexes0: KV cache slot indices for microbatch 0
        mem_indexes1: KV cache slot indices for microbatch 1
    """
    manager = g_routing_capture_manager
    if manager is None:
        return
    # Microbatch index matches the order of the supplied tensors.
    for batch_idx, indexes in enumerate((mem_indexes0, mem_indexes1)):
        manager.flush_to_cpu_async(indexes, microbatch_index=batch_idx)

lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,6 @@ def _load_mlp(self, mlp_prefix):
246246

247247
def _init_moe(self):
248248
moe_intermediate_size = self.network_config_["moe_intermediate_size"]
249-
250249
self.moe_gate = ROWMMWeight(
251250
weight_names=f"model.layers.{self.layer_num_}.mlp.gate.weight",
252251
data_type=self.data_type_,

lightllm/models/deepseek2/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
77
from lightllm.models.llama.model import LlamaTpPartModel
88
from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
9+
from lightllm.common.basemodel.moe_model_mixin import MoeModelMixin
910
from lightllm.utils.log_utils import init_logger
1011
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num
1112
from lightllm.distributed.communication_op import dist_group_manager
@@ -15,7 +16,7 @@
1516

1617

1718
@ModelRegistry(["deepseek_v2", "deepseek_v3"])
18-
class Deepseek2TpPartModel(LlamaTpPartModel):
19+
class Deepseek2TpPartModel(MoeModelMixin, LlamaTpPartModel):
1920
# weight class
2021
transformer_weight_class = Deepseek2TransformerLayerWeight
2122

@@ -48,6 +49,7 @@ def _init_some_value(self):
4849
def _init_custom(self):
    # DeepSeek-v2/v3 specific init: YaRN rotary embeddings, a DeepEP
    # communication group sized from the expert/hidden config, and
    # R3 routing capture (no-op unless --enable_return_routed_experts).
    self._init_to_get_yarn_rotary()
    dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"])
    self._init_routing_capture()  # R3 routing capture for MoE
5153

5254
def _verify_params(self):
    # Delegates parameter verification to the parent class unchanged.
    return super()._verify_params()

lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ def _gpt_oss_rmsnorm(self, hidden_states, weight, eps=1e-6):
4242
def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -> torch.Tensor:
4343
hidden_states = input.view(-1, self.embed_dim_)
4444
num_tokens, hidden_dim = hidden_states.shape
45-
4645
router_logits = layer_weight.moe_gate.mm(hidden_states)
4746
hidden_states = layer_weight.experts.experts(
4847
hidden_states,

lightllm/models/gpt_oss/model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight
33
from lightllm.models.llama.model import LlamaTpPartModel
44
from lightllm.models.registry import ModelRegistry
5+
from lightllm.common.basemodel.moe_model_mixin import MoeModelMixin
56

67
from lightllm.utils.envs_utils import get_env_start_args
78
from lightllm.utils.log_utils import init_logger
@@ -10,7 +11,7 @@
1011

1112

1213
@ModelRegistry("gpt_oss")
13-
class GptOssTpPartModel(LlamaTpPartModel):
14+
class GptOssTpPartModel(MoeModelMixin, LlamaTpPartModel):
1415
# weight class
1516
transformer_weight_class = GptOssTransformerLayerWeight
1617

@@ -25,3 +26,7 @@ def __init__(self, kvargs):
2526
assert (
2627
get_env_start_args().llm_decode_att_backend[0] == "fa3"
2728
), "For now GPT-OSS type model only support flashattention-3"
29+
30+
def _init_custom(self):
    # Run base-class initialization first, then enable R3 routing capture
    # (no-op unless --enable_return_routed_experts is set).
    super()._init_custom()
    self._init_routing_capture()  # R3 routing capture for MoE

lightllm/models/mixtral/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
from lightllm.models.registry import ModelRegistry
44
from lightllm.common.basemodel.basemodel import TpPartBaseModel
5+
from lightllm.common.basemodel.moe_model_mixin import MoeModelMixin
56
from lightllm.common.kv_cache_mem_manager import MemoryManager
67
from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
78
from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
@@ -16,7 +17,7 @@
1617

1718

1819
@ModelRegistry("mixtral")
19-
class MixtralTpPartModel(TpPartBaseModel):
20+
class MixtralTpPartModel(MoeModelMixin, TpPartBaseModel):
2021
# weight class
2122
pre_and_post_weight_class = LlamaPreAndPostLayerWeight
2223
transformer_weight_class = MixtralTransformerLayerWeight
@@ -45,6 +46,7 @@ def _verify_params(self):
4546

4647
def _init_custom(self):
    # Mixtral-specific init: rotary embeddings plus R3 routing capture
    # (no-op unless --enable_return_routed_experts is set).
    self._init_to_get_rotary()
    self._init_routing_capture()  # R3 routing capture for MoE
    return
4951

5052
def _init_mem_manager(self):

lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ def _moe_ffn(
131131

132132
hidden_states = input.view(-1, self.embed_dim_)
133133
num_tokens, hidden_dim = hidden_states.shape
134-
135134
router_logits = layer_weight.moe_gate.mm(hidden_states)
136135
layer_weight.experts.experts(
137136
hidden_states,

lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def _init_moe(self):
6060
tp_rank=0,
6161
tp_world_size=1,
6262
)
63-
6463
moe_mode = os.getenv("MOE_MODE", "TP")
6564
assert moe_mode in ["EP", "TP"]
6665
if moe_mode == "TP":

0 commit comments

Comments
 (0)