Skip to content

Commit a00abc2

Browse files
author
Developer
committed
feat: add R3 routing support for LLM inference
Add routing capture and management infrastructure to support R3-style request routing across model inference backends. Includes routing manager, request/batch extensions, API endpoint additions, and backend integration for deepseek2, mixtral, qwen3_moe, llama, and gpt_oss models.
1 parent bbdc7ba commit a00abc2

39 files changed

+590
-395
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ dist
77
.vscode
88
tmp/
99
requirements-musa.txt
10+
CLAUDE.md

lightllm/common/basemodel/basemodel.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
1313
from lightllm.common.basemodel.infer_struct import InferStateInfo
14+
from lightllm.common.basemodel.routing_manager import reset_moe_layer_counter
1415
from lightllm.common.kv_cache_mem_manager import MemoryManager
1516
from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
1617
from lightllm.common.req_manager import ReqManager
@@ -164,6 +165,7 @@ def _init_quant(self):
164165
logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}")
165166

166167
def _init_weights(self, start_layer_index=0):
168+
reset_moe_layer_counter()
167169
self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config)
168170
self.trans_layers_weight = [
169171
self.transformer_weight_class(

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num, get_env_start_args
1414
from lightllm.utils.dist_utils import get_global_world_size, get_global_rank
1515
from lightllm.utils.log_utils import init_logger
16+
from lightllm.common.basemodel.routing_manager import get_next_moe_layer_index
1617

1718
logger = init_logger(__name__)
1819

@@ -35,6 +36,7 @@ def __init__(
3536
network_config: Dict[str, Any] = None,
3637
) -> None:
3738
super().__init__(data_type=data_type)
39+
self.moe_layer_index = get_next_moe_layer_index()
3840
self.w1_weight_name = gate_proj_name
3941
self.w2_weight_name = down_proj_name
4042
self.w3_weight_name = up_proj_name
@@ -130,6 +132,7 @@ def experts(
130132
topk_group: int,
131133
num_expert_group: int,
132134
is_prefill: Optional[bool] = None,
135+
microbatch_index: int = 0,
133136
) -> torch.Tensor:
134137
"""Backward compatible method that routes to platform-specific implementation."""
135138
return self.fuse_moe_impl(
@@ -145,6 +148,8 @@ def experts(
145148
topk_group=topk_group,
146149
num_expert_group=num_expert_group,
147150
is_prefill=is_prefill,
151+
moe_layer_index=self.moe_layer_index,
152+
microbatch_index=microbatch_index,
148153
)
149154

150155
def low_latency_dispatch(

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from lightllm.common.quantization import Quantcfg
99
from lightllm.common.quantization.quantize_method import QuantizationMethod
1010
from lightllm.utils.log_utils import init_logger
11+
from lightllm.common.basemodel.routing_manager import g_routing_capture_manager
1112

1213
logger = init_logger(__name__)
1314

@@ -144,10 +145,14 @@ def experts(
144145
topk_group: int,
145146
num_expert_group: int,
146147
is_prefill: Optional[bool] = None,
148+
microbatch_index: int = 0,
147149
):
148150

149151
topk_weights, topk_ids = self._router(router_logits, top_k)
150152

153+
if g_routing_capture_manager is not None:
154+
g_routing_capture_manager.capture(self.moe_layer_index, topk_ids, microbatch_index)
155+
151156
w1, w1_scale = self.w1
152157
w2, w2_scale = self.w2
153158
use_fp8_w8a8 = self.quant_method is not None

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,5 +62,7 @@ def __call__(
6262
topk_group: int,
6363
num_expert_group: int,
6464
is_prefill: Optional[bool] = None,
65+
moe_layer_index: Optional[int] = None,
66+
microbatch_index: int = 0,
6567
) -> torch.Tensor:
6668
pass

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from lightllm.common.quantization.no_quant import WeightPack
44
from lightllm.common.quantization.quantize_method import QuantizationMethod
55
from .base_impl import FuseMoeBaseImpl
6+
from lightllm.common.basemodel.routing_manager import g_routing_capture_manager
67

78

89
class FuseMoeTriton(FuseMoeBaseImpl):
@@ -124,6 +125,8 @@ def __call__(
124125
topk_group: int,
125126
num_expert_group: int,
126127
is_prefill: Optional[bool] = None,
128+
moe_layer_index: Optional[int] = None,
129+
microbatch_index: int = 0,
127130
):
128131
topk_weights, topk_ids = self._select_experts(
129132
input_tensor=input_tensor,
@@ -136,6 +139,10 @@ def __call__(
136139
num_expert_group=num_expert_group,
137140
scoring_func=scoring_func,
138141
)
142+
143+
if g_routing_capture_manager is not None and moe_layer_index is not None:
144+
g_routing_capture_manager.capture(moe_layer_index, topk_ids, microbatch_index)
145+
139146
output = self._fused_experts(
140147
input_tensor=input_tensor,
141148
w13=w13,
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .mm_weight import (
22
MMWeightTpl,
33
)
4-
from .rowmm_weight import ROWMMWeight, KVROWNMMWeight, ROWBMMWeight, QKVROWNMMWeight
4+
from .rowmm_weight import ROWMMWeight, KVROWNMMWeight, ROWBMMWeight
55
from .colmm_weight import COLMMWeight
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
import torch
2+
import numpy as np
3+
from typing import Optional
4+
from lightllm.utils.log_utils import init_logger
5+
from lightllm.utils.dist_utils import get_current_rank_in_dp
6+
from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray
7+
from lightllm.utils.envs_utils import get_unique_server_name
8+
9+
logger = init_logger(__name__)
10+
11+
12+
class SharedRoutingConfig:
    """Shared MoE routing configuration across processes.

    Backed by a 2-element int32 shared-memory array: index 0 stores the
    MoE layer count, index 1 stores the routing top-k. Zero in either
    slot means the value has not been published yet.
    """

    def __init__(self):
        shm_name = f"{get_unique_server_name()}_routing_config"
        self._shm = SharedArray(shm_name, shape=(2,), dtype=np.int32)

    @property
    def num_moe_layers(self) -> int:
        # Slot 0: number of MoE layers in the model.
        return int(self._shm.arr[0])

    @num_moe_layers.setter
    def num_moe_layers(self, value: int):
        self._shm.arr[0] = value

    @property
    def topk(self) -> int:
        # Slot 1: experts selected per token.
        return int(self._shm.arr[1])

    @topk.setter
    def topk(self, value: int):
        self._shm.arr[1] = value

    def is_initialized(self) -> bool:
        """Return True once both fields hold positive (published) values."""
        return min(self.num_moe_layers, self.topk) > 0
37+
38+
39+
# Process-local cache of the shared routing config singleton.
_shared_routing_config: Optional[SharedRoutingConfig] = None


def get_shared_routing_config() -> SharedRoutingConfig:
    """Return the process-wide SharedRoutingConfig, creating it lazily."""
    global _shared_routing_config
    if _shared_routing_config is not None:
        return _shared_routing_config
    _shared_routing_config = SharedRoutingConfig()
    return _shared_routing_config
48+
49+
50+
# Monotonic allocator that hands each FusedMoe weight its layer index
# in construction order; reset before every weight (re)initialization.
_moe_layer_counter: int = 0


def reset_moe_layer_counter() -> None:
    """Rewind the MoE layer index allocator back to zero."""
    global _moe_layer_counter
    _moe_layer_counter = 0


def get_next_moe_layer_index() -> int:
    """Allocate and return the next sequential MoE layer index."""
    global _moe_layer_counter
    allocated = _moe_layer_counter
    _moe_layer_counter = allocated + 1
    return allocated


def get_moe_layer_count() -> int:
    """Return how many MoE layer indices have been handed out so far."""
    count = _moe_layer_counter
    return count
67+
68+
69+
class RoutingCaptureManager:
    """Captures MoE routing decisions.

    Top-k expert ids chosen for each token are staged layer-by-layer in a
    per-microbatch GPU buffer during the forward pass, then flushed
    asynchronously into a pinned CPU buffer indexed by KV-cache memory
    slot, so they can later be extracted per request.
    """

    def __init__(
        self,
        num_moe_layers: int,
        topk: int,
        num_experts: int,
        batch_max_tokens: int,
        kv_cache_size: int,
        enable_overlap: bool = False,
    ):
        """
        Args:
            num_moe_layers: number of MoE layers in the model.
            topk: experts selected per token.
            num_experts: total routed experts; decides int8 vs int16 storage.
            batch_max_tokens: max tokens per forward batch (GPU staging rows).
            kv_cache_size: number of KV-cache slots (CPU buffer rows).
            enable_overlap: allocate two staging slots so two overlapped
                decode microbatches can capture concurrently.
        """
        self.num_moe_layers = num_moe_layers
        self.topk = topk
        self.num_experts = num_experts
        self.batch_max_tokens = batch_max_tokens
        self.kv_cache_size = kv_cache_size

        # int8 suffices while every expert id fits in [-128, 127].
        self.dtype = torch.int8 if num_experts <= 127 else torch.int16
        dtype_bytes = 1 if self.dtype == torch.int8 else 2

        # One staging slot normally, two when decode microbatches overlap.
        self.num_slots = 2 if enable_overlap else 1

        gpu_buffer_size = self.num_slots * num_moe_layers * batch_max_tokens * topk * dtype_bytes
        self.gpu_buffer = torch.zeros(
            (self.num_slots, num_moe_layers, batch_max_tokens, topk),
            dtype=self.dtype,
            device="cuda",
        )

        cpu_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes
        self.cpu_buffer = torch.zeros(
            (num_moe_layers, kv_cache_size, topk),
            dtype=self.dtype,
            device="cpu",
            pin_memory=True,
        )

        # Dedicated stream/event pair per staging slot for the D2H flushes.
        self.flush_streams = [torch.cuda.Stream() for _ in range(self.num_slots)]
        self.flush_events = [torch.cuda.Event() for _ in range(self.num_slots)]

        dtype_name = "int8" if self.dtype == torch.int8 else "int16"
        logger.info(
            f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, "
            f"slots={self.num_slots}, GPU={gpu_buffer_size / 1024 / 1024:.2f}MB, "
            f"CPU={cpu_buffer_size / 1024 / 1024:.2f}MB, dtype={dtype_name}"
        )

    def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None:
        """Stage one MoE layer's top-k expert ids for the current batch.

        Fix: normalize the staging slot the same way flush_to_cpu_async
        does (modulo num_slots). Previously a microbatch_index >= num_slots
        would index past the first gpu_buffer dimension here even though
        the flush path wrapped it; for all in-range indices the behavior
        is unchanged.
        """
        slot = microbatch_index % self.num_slots
        num_tokens = topk_ids.shape[0]
        self.gpu_buffer[slot, moe_layer_index, :num_tokens, :] = topk_ids.to(self.dtype)

    def flush_to_cpu_async(self, mem_indexes: torch.Tensor, microbatch_index: int) -> None:
        """Copy the staged routing ids for this microbatch into the CPU
        buffer rows given by mem_indexes, ordered after current-stream work.

        NOTE(review): `.cpu()` is a synchronous D2H copy, so this blocks the
        host until the transfer finishes even though it runs on a side
        stream — confirm whether a non_blocking copy into pinned staging is
        intended before relying on this being truly async.
        """
        num_tokens = mem_indexes.shape[0]
        if num_tokens == 0:
            return

        slot = microbatch_index % self.num_slots
        stream = self.flush_streams[slot]
        event = self.flush_events[slot]

        # Order the flush after everything already queued on the current stream.
        stream.wait_stream(torch.cuda.current_stream())

        with torch.cuda.stream(stream):
            cpu_indexes = mem_indexes.cpu()
            self.cpu_buffer[:, cpu_indexes, :] = self.gpu_buffer[slot, :, :num_tokens, :].cpu()
            event.record()

    def sync_events(self) -> None:
        """Synchronize all flush events. Call once before batch extraction."""
        for event in self.flush_events:
            event.synchronize()

    def extract_for_request(self, mem_indexes: torch.Tensor) -> np.ndarray:
        """Return routing ids for the given KV-cache slots, waiting for any
        in-flight flushes first. Shape: (num_moe_layers, len(mem_indexes), topk)."""
        self.sync_events()
        return self.cpu_buffer[:, mem_indexes, :].numpy()

    def extract_for_request_no_sync(self, mem_indexes: torch.Tensor) -> np.ndarray:
        """Like extract_for_request, but the caller guarantees flushes have
        already been synchronized (e.g. via a prior sync_events())."""
        return self.cpu_buffer[:, mem_indexes, :].numpy()
148+
149+
150+
# Process-global capture manager; None until routing capture is enabled.
g_routing_capture_manager: Optional[RoutingCaptureManager] = None


def create_routing_capture_manager(
    num_moe_layers: int,
    topk: int,
    num_experts: int,
    batch_max_tokens: int,
    kv_cache_size: int,
    enable_overlap: bool = False,
) -> None:
    """Instantiate the process-global RoutingCaptureManager exactly once."""
    global g_routing_capture_manager
    assert g_routing_capture_manager is None, "RoutingCaptureManager already exists"
    manager_kwargs = dict(
        num_moe_layers=num_moe_layers,
        topk=topk,
        num_experts=num_experts,
        batch_max_tokens=batch_max_tokens,
        kv_cache_size=kv_cache_size,
        enable_overlap=enable_overlap,
    )
    g_routing_capture_manager = RoutingCaptureManager(**manager_kwargs)
171+
172+
173+
def init_routing_capture(model) -> None:
    """Set up MoE routing capture for *model* when the feature is enabled.

    No-op unless ``enable_return_routed_experts`` is set; only rank 0 in
    the dp group allocates capture buffers. On success, publishes the
    discovered (num_moe_layers, topk) pair through the shared config.
    """
    # Feature flag gate.
    if not getattr(model.args, "enable_return_routed_experts", False):
        return

    # Only one rank per dp group owns the capture buffers.
    if get_current_rank_in_dp() != 0:
        logger.info("Skipping routing capture initialization on non-zero rank")
        return

    # Layer count was accumulated while weights were constructed.
    num_moe_layers = get_moe_layer_count()
    if num_moe_layers == 0:
        logger.warning(
            "enable_return_routed_experts is set but no MoE layers found. Routing capture will not be enabled."
        )
        return

    # deepseek-style configs use n_routed_experts; others use num_experts.
    cfg = model.config
    num_experts = cfg.get("n_routed_experts", cfg.get("num_experts", 0))
    topk = cfg.get("num_experts_per_tok", 0)
    assert num_experts > 0 and topk > 0
    enable_overlap = getattr(model.args, "enable_decode_microbatch_overlap", False)

    logger.info(
        f"Initializing routing capture: num_moe_layers={num_moe_layers}, "
        f"topk={topk}, num_experts={num_experts}, enable_overlap={enable_overlap}"
    )

    create_routing_capture_manager(
        num_moe_layers=num_moe_layers,
        topk=topk,
        num_experts=num_experts,
        batch_max_tokens=model.max_total_token_num,
        kv_cache_size=model.mem_manager.size + 1,
        enable_overlap=enable_overlap,
    )

    # Publish the discovered sizes so other processes can size their reads.
    shared = get_shared_routing_config()
    shared.num_moe_layers = num_moe_layers
    shared.topk = topk
    logger.info(f"Shared routing config set: num_moe_layers={num_moe_layers}, topk={topk}")
211+
212+
213+
def flush_routing_capture(mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None:
    """Kick off an async GPU-to-CPU flush of captured routing ids, if active."""
    manager = g_routing_capture_manager
    if manager is None:
        return
    manager.flush_to_cpu_async(mem_indexes, microbatch_index)

0 commit comments

Comments
 (0)