Commit 68954b0
feat: add MoE expert routing capture for R3 rollout replay
1 parent 02078ad
36 files changed
Lines changed: 781 additions & 81 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@ dist
 .vscode
 tmp/
 requirements-musa.txt
+CLAUDE.md

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py

Lines changed: 5 additions & 0 deletions
@@ -33,6 +33,7 @@ def __init__(
         num_fused_shared_experts: int = 0,
         layer_num: int = 0,
         network_config: Dict[str, Any] = None,
+        moe_layer_index: int = 0,
     ) -> None:
         super().__init__(data_type=data_type)
         self.w1_weight_name = gate_proj_name
@@ -50,6 +51,7 @@ def __init__(
         self.enable_ep_moe = get_env_start_args().enable_ep_moe
         self.n_routed_experts = n_routed_experts
         self.num_fused_shared_experts = num_fused_shared_experts
+        self.moe_layer_index = moe_layer_index
         self._init_config(network_config)
         self._init_redundancy_expert_params()
         self._init_parallel_params()
@@ -130,6 +132,7 @@ def experts(
         topk_group: int,
         num_expert_group: int,
         is_prefill: Optional[bool] = None,
+        microbatch_index: int = 0,
     ) -> torch.Tensor:
         """Backward compatible method that routes to platform-specific implementation."""
         return self.fuse_moe_impl(
@@ -145,6 +148,8 @@ def experts(
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             is_prefill=is_prefill,
+            moe_layer_index=self.moe_layer_index,
+            microbatch_index=microbatch_index,
         )

     def low_latency_dispatch(

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py

Lines changed: 8 additions & 0 deletions
@@ -8,6 +8,7 @@
 from lightllm.common.quantization import Quantcfg
 from lightllm.common.quantization.quantize_method import QuantizationMethod
 from lightllm.utils.log_utils import init_logger
+from lightllm.common.basemodel import routing_manager as _routing_mgr

 logger = init_logger(__name__)

@@ -46,6 +47,7 @@ def __init__(
         num_fused_shared_experts: int = 0,
         layer_num: int = 0,
         network_config: Dict[str, Any] = None,
+        moe_layer_index: int = 0,
     ) -> None:
         network_config["norm_topk_prob"] = None
         super().__init__(
@@ -62,6 +64,7 @@ def __init__(
             num_fused_shared_experts=num_fused_shared_experts,
             layer_num=layer_num,
             network_config=network_config,
+            moe_layer_index=moe_layer_index,
         )

         self.hidden_size = network_config["hidden_size"]
@@ -144,10 +147,15 @@ def experts(
         topk_group: int,
         num_expert_group: int,
         is_prefill: Optional[bool] = None,
+        microbatch_index: int = 0,
     ):

         topk_weights, topk_ids = self._router(router_logits, top_k)

+        # Rollout router replay
+        if _routing_mgr.g_routing_capture_manager is not None:
+            _routing_mgr.g_routing_capture_manager.capture(self.moe_layer_index, topk_ids, microbatch_index)
+
         w1, w1_scale = self.w1
         w2, w2_scale = self.w2
         use_fp8_w8a8 = self.quant_method is not None
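
The hook above only fires when the global capture manager has been created; `capture()` then records `topk_ids` of shape (num_tokens, top_k) for this MoE layer into the per-microbatch GPU buffer. A minimal sketch of that guard pattern as a free-standing helper; the helper name `maybe_capture_routing` and its usage are illustrative assumptions, not part of this commit:

import torch
from lightllm.common.basemodel import routing_manager as _routing_mgr


def maybe_capture_routing(moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None:
    # topk_ids: (num_tokens, top_k) expert ids chosen by the router for one MoE layer.
    # No-op when routing capture was never initialized (the module-level global stays None).
    mgr = _routing_mgr.g_routing_capture_manager
    if mgr is not None:
        mgr.capture(moe_layer_index, topk_ids, microbatch_index)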

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py

Lines changed: 2 additions & 0 deletions
@@ -62,5 +62,7 @@ def __call__(
         topk_group: int,
         num_expert_group: int,
         is_prefill: Optional[bool] = None,
+        moe_layer_index: Optional[int] = None,
+        microbatch_index: int = 0,
     ) -> torch.Tensor:
         pass

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py

Lines changed: 7 additions & 0 deletions
@@ -3,6 +3,7 @@
 from lightllm.common.quantization.no_quant import WeightPack
 from lightllm.common.quantization.quantize_method import QuantizationMethod
 from .base_impl import FuseMoeBaseImpl
+from lightllm.common.basemodel import routing_manager as _routing_mgr


 class FuseMoeTriton(FuseMoeBaseImpl):
@@ -124,6 +125,8 @@ def __call__(
         topk_group: int,
         num_expert_group: int,
         is_prefill: Optional[bool] = None,
+        moe_layer_index: Optional[int] = None,
+        microbatch_index: int = 0,
     ):
         topk_weights, topk_ids = self._select_experts(
             input_tensor=input_tensor,
@@ -136,6 +139,10 @@ def __call__(
             num_expert_group=num_expert_group,
             scoring_func=scoring_func,
         )
+
+        if _routing_mgr.g_routing_capture_manager is not None and moe_layer_index is not None:
+            _routing_mgr.g_routing_capture_manager.capture(moe_layer_index, topk_ids, microbatch_index)
+
         output = self._fused_experts(
             input_tensor=input_tensor,
             w13=w13,
lightllm/common/basemodel/routing_manager.py

Lines changed: 224 additions & 0 deletions
@@ -0,0 +1,224 @@
+import atexit
+import torch
+import numpy as np
+from multiprocessing import shared_memory
+from typing import Optional
+from lightllm.utils.log_utils import init_logger
+from lightllm.utils.dist_utils import get_current_rank_in_dp
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray
+from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.utils.shm_utils import create_or_link_shm
+
+logger = init_logger(__name__)
+
+
+def routing_dtype_id_to_np(dtype_id: int):
+    if dtype_id == 1:
+        return np.int8
+    elif dtype_id == 2:
+        return np.int16
+    return np.int32
+
+
+def get_routing_config_shm() -> SharedArray:
+    service_name = get_unique_server_name()
+    return SharedArray(f"{service_name}_routing_config", shape=(4,), dtype=np.int32)
+
+
+class RoutingCaptureManager:
+    def __init__(
+        self,
+        num_moe_layers: int,
+        topk: int,
+        num_experts: int,
+        kv_cache_size: int,
+        max_capture_tokens: int,
+    ):
+        self.num_moe_layers = num_moe_layers
+        self.topk = topk
+        self.num_experts = num_experts
+        self.kv_cache_size = kv_cache_size
+
+        self.dtype = torch.int8 if num_experts <= 127 else torch.int16
+        dtype_bytes = 1 if self.dtype == torch.int8 else 2
+
+        # Shape: (num_moe_layers, kv_cache_size, topk) — on CPU to save GPU memory.
+        # Written after forward() via flush_to_routing_buffer(), read on request finish.
+        routing_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes
+        self.routing_buffer = torch.zeros(
+            (num_moe_layers, kv_cache_size, topk),
+            dtype=self.dtype,
+            device="cpu",
+        )
+
+        # Capture buffers: simple contiguous tensors written to during forward().
+        capture_buf_size = max_capture_tokens * num_moe_layers * topk * dtype_bytes
+        self._capture_buffer = [
+            torch.zeros((max_capture_tokens, num_moe_layers, topk), dtype=self.dtype, device="cuda") for _ in range(2)
+        ]
+
+        dtype_name = "int8" if self.dtype == torch.int8 else "int16"
+        logger.info(
+            f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, "
+            f"routing_buffer(cpu)={routing_buffer_size / 1024 / 1024:.2f}MB, "
+            f"capture_buffer={capture_buf_size / 1024 / 1024:.2f}MB x2, dtype={dtype_name}"
+        )
+
+    @property
+    def np_dtype(self):
+        return np.int8 if self.dtype == torch.int8 else np.int16
+
+    @property
+    def dtype_id(self) -> int:
+        return 1 if self.dtype == torch.int8 else 2
+
+    def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None:
+        num_tokens = topk_ids.shape[0]
+        self._capture_buffer[microbatch_index][:num_tokens, moe_layer_index, :] = topk_ids.to(self.dtype)
+
+    def flush_to_routing_buffer(self, mem_indexes: torch.Tensor, num_tokens: int, microbatch_index: int = 0) -> None:
+        buf = self._capture_buffer[microbatch_index][:num_tokens]  # (num_tokens, num_moe_layers, topk)
+        buf_t = buf.permute(1, 0, 2).cpu()
+        self.routing_buffer[:, mem_indexes[:num_tokens].cpu(), :] = buf_t
+
+    def extract_routing_data(self, mem_indexes: torch.Tensor) -> np.ndarray:
+        cpu_indexes = mem_indexes.cpu() if mem_indexes.is_cuda else mem_indexes
+        return self.routing_buffer[:, cpu_indexes, :].numpy()
+
+
+g_routing_capture_manager: Optional[RoutingCaptureManager] = None
+
+
+def create_routing_capture_manager(
+    num_moe_layers: int,
+    topk: int,
+    num_experts: int,
+    kv_cache_size: int,
+    max_capture_tokens: int,
+) -> None:
+    global g_routing_capture_manager
+    assert g_routing_capture_manager is None, "RoutingCaptureManager already exists"
+    g_routing_capture_manager = RoutingCaptureManager(
+        num_moe_layers=num_moe_layers,
+        topk=topk,
+        num_experts=num_experts,
+        kv_cache_size=kv_cache_size,
+        max_capture_tokens=max_capture_tokens,
+    )
+
+
+def preallocate_routing_shm_pool(max_req_num: int, num_moe_layers: int, max_tokens: int, topk: int, np_dtype) -> None:
+    """Pre-allocate POSIX SHM segments for all request slots.
+
+    Each segment is sized for the maximum possible routing data so it can be
+    reused across requests without create/destroy overhead.
+    """
+    dtype_bytes = np.dtype(np_dtype).itemsize
+    segment_size = num_moe_layers * max_tokens * topk * dtype_bytes
+    service_name = get_unique_server_name()
+
+    for i in range(max_req_num):
+        name = f"{service_name}_shm_routing_{i}"
+        shm = create_or_link_shm(name, segment_size, auto_cleanup=True)
+        shm.close()  # close handle; SHM persists in /dev/shm
+
+    logger.info(
+        f"Pre-allocated {max_req_num} routing SHM segments, "
+        f"each {segment_size / 1024:.1f} KB (total {max_req_num * segment_size / 1024 / 1024:.1f} MB)"
+    )
+
+
+def cleanup_routing_shm_pool() -> None:
+    """Unlink all pre-allocated routing SHM segments. Called at server shutdown."""
+    try:
+        from lightllm.utils.envs_utils import get_env_start_args
+
+        args = get_env_start_args()
+    except Exception:
+        return
+
+    service_name = get_unique_server_name()
+
+    for i in range(args.running_max_req_size):
+        name = f"{service_name}_shm_routing_{i}"
+        try:
+            shm = shared_memory.SharedMemory(name=name)
+            shm.close()
+            shm.unlink()
+        except Exception:
+            pass
+
+    config_name = f"{service_name}_routing_config"
+    try:
+        shm = shared_memory.SharedMemory(name=config_name)
+        shm.close()
+        shm.unlink()
+    except Exception:
+        pass
+
+
+def init_routing_capture(model, num_moe_layers: int) -> None:
+    dp_rank = get_current_rank_in_dp()
+    logger.info(f"init_routing_capture called: num_moe_layers={num_moe_layers}, dp_rank={dp_rank}")
+    if dp_rank != 0:
+        logger.info(f"Skipping routing capture initialization on dp_rank={dp_rank}")
+        return
+
+    if num_moe_layers == 0:
+        logger.warning(
+            "enable_return_routed_experts is set but no MoE layers found. Routing capture will not be enabled."
+        )
+        return
+
+    num_experts = model.config.get("n_routed_experts", model.config.get("num_experts", 0))
+    topk = model.config.get("num_experts_per_tok", 0)
+    assert num_experts > 0 and topk > 0
+
+    from lightllm.utils.envs_utils import get_env_start_args
+
+    args = get_env_start_args()
+
+    # Capture buffer must fit the max tokens in any single forward call.
+    # For prefill that's batch_max_tokens; for decode it's graph_max_batch_size.
+    batch_max_tokens = args.batch_max_tokens or args.max_req_total_len or 8192
+    max_capture_tokens = max(batch_max_tokens, args.graph_max_batch_size)
+
+    logger.info(
+        f"Initializing routing capture: num_moe_layers={num_moe_layers}, "
+        f"topk={topk}, num_experts={num_experts}, max_capture_tokens={max_capture_tokens}"
+    )
+
+    create_routing_capture_manager(
+        num_moe_layers=num_moe_layers,
+        topk=topk,
+        num_experts=num_experts,
+        kv_cache_size=model.mem_manager.size + 1,
+        max_capture_tokens=max_capture_tokens,
+    )
+
+    mgr = g_routing_capture_manager
+    np_dtype = mgr.np_dtype
+    dtype_id = mgr.dtype_id
+
+    max_req_total_len = args.max_req_total_len
+
+    # Write config to cross-process SHM
+    shm = get_routing_config_shm()
+    shm.arr[0] = num_moe_layers
+    shm.arr[1] = topk
+    shm.arr[2] = dtype_id
+    shm.arr[3] = max_req_total_len
+    logger.info(
+        f"Shared routing config set: num_moe_layers={num_moe_layers}, topk={topk}, "
+        f"dtype_id={dtype_id}, max_tokens={max_req_total_len}"
+    )
+
+    preallocate_routing_shm_pool(
+        max_req_num=args.running_max_req_size,
+        num_moe_layers=num_moe_layers,
+        max_tokens=max_req_total_len,
+        topk=topk,
+        np_dtype=np_dtype,
+    )
+
+    atexit.register(cleanup_routing_shm_pool)
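
Taken together, the new module defines a three-step lifecycle: capture() records per-layer top-k expert ids during forward(), flush_to_routing_buffer() copies them into the CPU buffer keyed by KV-cache slot indices, and extract_routing_data() reads them back when a request finishes; a small config SHM segment plus the per-request `{service_name}_shm_routing_{i}` segments let other processes decode the result. A hedged end-to-end sketch with made-up dimensions (the 4-layer/8-slot sizes, the fake `mem_indexes`, and the assumption of running in a CUDA-capable worker process are all illustrative):

import torch
from lightllm.common.basemodel import routing_manager as rm

# Tiny illustrative dimensions: 4 MoE layers, top-2 routing, 64 experts,
# 8 KV-cache slots, at most 16 tokens per forward call.
rm.create_routing_capture_manager(
    num_moe_layers=4, topk=2, num_experts=64, kv_cache_size=8, max_capture_tokens=16
)
mgr = rm.g_routing_capture_manager

# During forward(): each MoE layer records the router's top-k ids for its tokens.
num_tokens = 3
for layer_idx in range(4):
    fake_topk_ids = torch.randint(0, 64, (num_tokens, 2), device="cuda")
    mgr.capture(layer_idx, fake_topk_ids, microbatch_index=0)

# After forward(): flush into the CPU routing buffer at the tokens' KV-cache slots.
mem_indexes = torch.tensor([5, 6, 7], device="cuda")
mgr.flush_to_routing_buffer(mem_indexes, num_tokens, microbatch_index=0)

# On request finish: pull out a (num_moe_layers, num_tokens, topk) array of expert ids.
routing = mgr.extract_routing_data(mem_indexes)
print(routing.shape, routing.dtype)  # (4, 3, 2) int8

# Cross-process layout: in the real flow init_routing_capture() writes
# [num_moe_layers, topk, dtype_id, max_tokens] into the config SHM, and a
# reader maps the dtype back with routing_dtype_id_to_np(dtype_id).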

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 0 deletions
@@ -312,6 +312,7 @@ def _moe_ffn(
             use_grouped_topk=self.n_group,
             topk_group=self.topk_group,
             num_expert_group=self.n_group,
+            microbatch_index=infer_state.microbatch_index,
         )

         if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0:
@@ -339,6 +340,7 @@ def _moe_ffn_edp(
             topk_group=self.topk_group,
             num_expert_group=self.n_group,
             is_prefill=infer_state.is_prefill,
+            microbatch_index=infer_state.microbatch_index,
         )

         if self.n_shared_experts is not None:

lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py

Lines changed: 4 additions & 0 deletions
@@ -242,6 +242,9 @@ def _init_moe(self):
         # When == 0, there are no fused shared experts; the shared experts are loaded and run separately.
         if self.num_fused_shared_experts == 0:
             self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", is_shared_experts=True)
+        first_moe = self.network_config_["first_k_dense_replace"]
+        freq = self.network_config_.get("moe_layer_freq", 1)
+        moe_layer_index = (self.layer_num_ - first_moe) // freq
         self.experts = FusedMoeWeight(
             gate_proj_name="gate_proj",
             down_proj_name="down_proj",
@@ -256,6 +259,7 @@ def _init_moe(self):
             num_fused_shared_experts=self.num_fused_shared_experts,
             layer_num=self.layer_num_,
             network_config=self.network_config_,
+            moe_layer_index=moe_layer_index,
         )

     def _init_ffn(self):
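
The derived `moe_layer_index` is a compact 0-based index over MoE layers only, which is what the capture buffer's layer axis expects, whereas `layer_num_` counts every transformer layer including the leading dense ones. A worked example with hypothetical config values (first_k_dense_replace=3 and moe_layer_freq=1 are assumptions for illustration, not taken from this diff):

# Hypothetical DeepSeek-style config: first 3 layers are dense, every later layer is MoE.
network_config = {"first_k_dense_replace": 3, "moe_layer_freq": 1}

first_moe = network_config["first_k_dense_replace"]
freq = network_config.get("moe_layer_freq", 1)

for layer_num in range(3, 8):  # dense layers 0..2 never reach the MoE branch
    moe_layer_index = (layer_num - first_moe) // freq
    print(layer_num, "->", moe_layer_index)  # 3->0, 4->1, 5->2, 6->3, 7->4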

lightllm/models/deepseek2/model.py

Lines changed: 4 additions & 0 deletions
@@ -6,6 +6,7 @@
 from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
 from lightllm.models.llama.model import LlamaTpPartModel
 from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
+from lightllm.common.basemodel.routing_manager import init_routing_capture
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num
 from lightllm.distributed.communication_op import dist_group_manager
@@ -49,6 +50,9 @@ def _init_some_value(self):
     def _init_custom(self):
         self._init_to_get_yarn_rotary()
         dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"])
+        if self.args.enable_return_routed_experts:
+            num_moe_layers = sum(1 for w in self.trans_layers_weight if w.is_moe)
+            init_routing_capture(self, num_moe_layers)

     def _verify_params(self):
         return super()._verify_params()

lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -
             use_grouped_topk=False,
             topk_group=None,
             num_expert_group=None,
+            microbatch_index=infer_state.microbatch_index,
         )
         return hidden_states.view(num_tokens, hidden_dim)

