Skip to content

Commit d6e3b3e

Browse files
committed
feat: auto-calculate max_new_tokens to align with vLLM behavior

When max_new_tokens is not specified (None or -1), automatically calculate it
as max_req_total_len - prompt_tokens. This aligns with vLLM's behavior, where
max_tokens defaults to the remaining context length.

Changes:
- sampling_params.py: default max_new_tokens changed from 16384 to -1
- py_sampling_params.py: default max_new_tokens changed from 16384 to None
- manager.py: add auto-calculation logic in _check_and_repair_length
1 parent 391d2ea commit d6e3b3e

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

lightllm/server/core/objs/py_sampling_params.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(
3838
top_k: int = None, # -1 is for all
3939
ignore_eos: bool = False,
4040
image_max_patch_num: int = -1,
41-
max_new_tokens: int = 16384,
41+
max_new_tokens: int = None,
4242
min_new_tokens: int = 1,
4343
stop_sequences: Optional[Union[str, List[str], List[List[int]]]] = None, # 停止句子条件
4444
skip_special_tokens: bool = True, # whether to skip special tokens when decoding
@@ -141,11 +141,11 @@ def verify(self):
141141
raise ValueError(f"top_p must in (0.0, 1.0], got {self.top_p}")
142142
if self.top_k < -1 or self.top_k == 0:
143143
raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
144-
if self.max_new_tokens < 1:
144+
if self.max_new_tokens is not None and self.max_new_tokens < 1:
145145
raise ValueError(f"max_new_tokens must be at least 1, got {self.max_new_tokens}.")
146146
if self.min_new_tokens < 1:
147147
raise ValueError(f"min_new_tokens must be at least 1, got {self.min_new_tokens}.")
148-
if self.min_new_tokens > self.max_new_tokens:
148+
if self.max_new_tokens is not None and self.min_new_tokens > self.max_new_tokens:
149149
raise ValueError(
150150
f"min_new_tokens must <= max_new_tokens, but got min {self.min_new_tokens}, max {self.max_new_tokens}."
151151
)

lightllm/server/core/objs/sampling_params.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def init(self, tokenizer, **kwargs):
345345
self.top_k = kwargs.get("top_k", SamplingParams._top_k)
346346
self.ignore_eos = kwargs.get("ignore_eos", False)
347347
self.image_max_patch_num = kwargs.get("image_max_patch_num", -1)
348-
self.max_new_tokens = kwargs.get("max_new_tokens", 16384)
348+
self.max_new_tokens = kwargs.get("max_new_tokens", -1)
349349
self.min_new_tokens = kwargs.get("min_new_tokens", 1)
350350
self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY)
351351
self.group_request_id = kwargs.get("group_request_id", -1)
@@ -439,11 +439,11 @@ def verify(self):
439439
raise ValueError(f"top_p must be in (0.0, 1.0], got {self.top_p}")
440440
if self.top_k < -1 or self.top_k == 0:
441441
raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
442-
if self.max_new_tokens < 1:
442+
if self.max_new_tokens != -1 and self.max_new_tokens < 1:
443443
raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.")
444444
if self.min_new_tokens < 1:
445445
raise ValueError(f"min_new_tokens must be at least 1 , got {self.min_new_tokens}.")
446-
if self.min_new_tokens > self.max_new_tokens:
446+
if self.max_new_tokens != -1 and self.min_new_tokens > self.max_new_tokens:
447447
raise ValueError(
448448
f"min_new_tokens must <= max_new_tokens, but got min {self.min_new_tokens}, max {self.max_new_tokens}."
449449
)

lightllm/server/httpserver/manager.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,24 @@ async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params:
477477
if not prompt_ids:
478478
raise ValueError("prompt_ids is empty")
479479
prompt_tokens = len(prompt_ids)
480+
481+
if sampling_params.max_new_tokens is None or sampling_params.max_new_tokens == -1:
482+
sampling_params.max_new_tokens = self.max_req_total_len - prompt_tokens
483+
if sampling_params.max_new_tokens < 1:
484+
raise ValueError(
485+
f"the input prompt token len {prompt_tokens} >= max_req_total_len {self.max_req_total_len}, "
486+
f"no space for output tokens"
487+
)
488+
if sampling_params.min_new_tokens > sampling_params.max_new_tokens:
489+
raise ValueError(
490+
f"min_new_tokens ({sampling_params.min_new_tokens}) > auto-calculated max_new_tokens "
491+
f"({sampling_params.max_new_tokens}), consider reducing input length or min_new_tokens"
492+
)
493+
logger.debug(
494+
f"max_new_tokens is unset, auto-calculate to {sampling_params.max_new_tokens} "
495+
f"(max_req_total_len {self.max_req_total_len} - prompt_tokens {prompt_tokens})"
496+
)
497+
480498
if prompt_tokens + sampling_params.max_new_tokens > self.max_req_total_len:
481499
# use long_truncation_mode to truncate long input len req.
482500
if self.args.long_truncation_mode is None:

0 commit comments

Comments (0)