Skip to content

Commit 3eab5a7

Browse files
committed
fix return_routed_experts
1 parent bf83078 commit 3eab5a7

3 files changed

Lines changed: 33 additions & 24 deletions

File tree

lightllm/common/basemodel/routing_manager.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414

1515
def routing_dtype_id_to_np(dtype_id: int):
1616
if dtype_id == 1:
17-
return np.int8
17+
return np.uint8
1818
elif dtype_id == 2:
1919
return np.int16
2020
return np.int32
@@ -39,8 +39,8 @@ def __init__(
3939
self.num_experts = num_experts
4040
self.kv_cache_size = kv_cache_size
4141

42-
self.dtype = torch.int8 if num_experts <= 127 else torch.int16
43-
dtype_bytes = 1 if self.dtype == torch.int8 else 2
42+
self.dtype = torch.uint8 if num_experts <= 255 else torch.int16
43+
dtype_bytes = 1 if self.dtype == torch.uint8 else 2
4444

4545
# Shape: (num_moe_layers, kv_cache_size, topk) — on CPU to save GPU memory.
4646
# Written after forward() via flush_to_routing_buffer(), read on request finish.
@@ -57,7 +57,7 @@ def __init__(
5757
torch.zeros((max_capture_tokens, num_moe_layers, topk), dtype=self.dtype, device="cuda") for _ in range(2)
5858
]
5959

60-
dtype_name = "int8" if self.dtype == torch.int8 else "int16"
60+
dtype_name = "uint8" if self.dtype == torch.uint8 else "int16"
6161
logger.info(
6262
f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, "
6363
f"routing_buffer(cpu)={routing_buffer_size / 1024 / 1024:.2f}MB, "
@@ -66,11 +66,11 @@ def __init__(
6666

6767
@property
6868
def np_dtype(self):
69-
return np.int8 if self.dtype == torch.int8 else np.int16
69+
return np.uint8 if self.dtype == torch.uint8 else np.int16
7070

7171
@property
7272
def dtype_id(self) -> int:
73-
return 1 if self.dtype == torch.int8 else 2
73+
return 1 if self.dtype == torch.uint8 else 2
7474

7575
def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None:
7676
num_tokens = topk_ids.shape[0]

lightllm/server/api_lightllm.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,9 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
3535
prompt = request_dict.pop("inputs")
3636
sample_params_dict = request_dict["parameters"]
3737
return_details = sample_params_dict.pop("return_details", False)
38+
return_routed_experts = sample_params_dict.pop(
39+
"return_routed_experts", httpserver_manager.args.enable_return_routed_experts
40+
)
3841
sampling_params = SamplingParams()
3942
sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict)
4043
sampling_params.verify()
@@ -105,7 +108,7 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
105108
ret["prompt_logprobs"] = prompt_logprobs
106109
if input_usage is not None:
107110
ret["input_usage"] = input_usage
108-
if routed_experts_data is not None:
111+
if return_routed_experts and routed_experts_data is not None:
109112
ret["routed_experts"] = routed_experts_data
110113

111114
return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8"))
@@ -117,6 +120,7 @@ async def lightllm_generate_stream(request: Request, httpserver_manager: HttpSer
117120
prompt = request_dict.pop("inputs")
118121
sample_params_dict = request_dict["parameters"]
119122
_ = sample_params_dict.pop("return_details", False)
123+
_ = sample_params_dict.pop("return_routed_experts", None)
120124
sampling_params = SamplingParams()
121125
sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict)
122126
sampling_params.verify()

lightllm/server/core/objs/sampling_params.py

Lines changed: 22 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -357,23 +357,28 @@ class SamplingParams(ctypes.Structure):
357357

358358
def init(self, tokenizer, **kwargs):
359359
super().__init__()
360-
self.best_of = kwargs.get("best_of", 1)
361-
self.n = kwargs.get("n", self.best_of)
362-
self.do_sample = kwargs.get("do_sample", SamplingParams._do_sample)
363-
self.presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty)
364-
self.frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty)
365-
self.repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty)
366-
self.temperature = kwargs.get("temperature", SamplingParams._temperature)
367-
self.top_p = kwargs.get("top_p", SamplingParams._top_p)
368-
self.top_k = kwargs.get("top_k", SamplingParams._top_k)
369-
self.ignore_eos = kwargs.get("ignore_eos", False)
370-
self.min_pixels = kwargs.get("min_pixels", -1)
371-
self.max_pixels = kwargs.get("max_pixels", -1)
372-
self.max_new_tokens = kwargs.get("max_new_tokens", 16)
373-
self.min_new_tokens = kwargs.get("min_new_tokens", 1)
374-
self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY)
375-
self.group_request_id = kwargs.get("group_request_id", -1)
376-
self.suggested_dp_index = kwargs.get("suggested_dp_index", -1)
360+
361+
def _get(key, default):
362+
v = kwargs.get(key)
363+
return v if v is not None else default
364+
365+
self.best_of = _get("best_of", 1)
366+
self.n = _get("n", self.best_of)
367+
self.do_sample = _get("do_sample", SamplingParams._do_sample)
368+
self.presence_penalty = _get("presence_penalty", SamplingParams._presence_penalty)
369+
self.frequency_penalty = _get("frequency_penalty", SamplingParams._frequency_penalty)
370+
self.repetition_penalty = _get("repetition_penalty", SamplingParams._repetition_penalty)
371+
self.temperature = _get("temperature", SamplingParams._temperature)
372+
self.top_p = _get("top_p", SamplingParams._top_p)
373+
self.top_k = _get("top_k", SamplingParams._top_k)
374+
self.ignore_eos = _get("ignore_eos", False)
375+
self.min_pixels = _get("min_pixels", -1)
376+
self.max_pixels = _get("max_pixels", -1)
377+
self.max_new_tokens = _get("max_new_tokens", 16)
378+
self.min_new_tokens = _get("min_new_tokens", 1)
379+
self.input_penalty = _get("input_penalty", DEFAULT_INPUT_PENALTY)
380+
self.group_request_id = _get("group_request_id", -1)
381+
self.suggested_dp_index = _get("suggested_dp_index", -1)
377382

378383
self.skip_special_tokens = kwargs.get("skip_special_tokens", SKIP_SPECIAL_TOKENS)
379384
self.disable_prompt_cache = kwargs.get("disable_prompt_cache", False)

0 commit comments

Comments (0)