From 67b7621312e1c22181b12f008136d0d9079c225f Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Thu, 25 Jun 2026 08:47:38 +0000 Subject: [PATCH 1/4] feat(serve): add --generation-config CLI for server sampling defaults Align api_server with vLLM by loading HuggingFace generation_config.json as default sampling params, with optional override and lmdeploy fallback. Co-authored-by: Cursor --- lmdeploy/cli/serve.py | 6 + lmdeploy/cli/utils.py | 26 +++ lmdeploy/serve/anthropic/adapter.py | 20 +- .../serve/anthropic/endpoints/messages.py | 6 +- lmdeploy/serve/anthropic/protocol.py | 2 +- lmdeploy/serve/core/generation_config.py | 173 ++++++++++++++++++ lmdeploy/serve/openai/api_server.py | 66 ++++--- lmdeploy/serve/openai/protocol.py | 26 +-- lmdeploy/serve/openai/responses/protocol.py | 4 +- lmdeploy/serve/openai/responses/request.py | 21 ++- lmdeploy/serve/openai/responses/serving.py | 6 +- .../serve/openai/serving_chat_completion.py | 28 ++- lmdeploy/serve/openai/serving_completion.py | 28 ++- .../serve/test_generation_config.py | 103 +++++++++++ 14 files changed, 443 insertions(+), 72 deletions(-) create mode 100644 lmdeploy/serve/core/generation_config.py create mode 100644 tests/test_lmdeploy/serve/test_generation_config.py diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 01ac1d44f1..5d8e0174b6 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -92,6 +92,8 @@ def add_parser_api_server(): # model args ArgumentHelper.revision(parser) ArgumentHelper.download_dir(parser) + ArgumentHelper.generation_config(parser) + ArgumentHelper.override_generation_config(parser) # pytorch engine args pt_group = parser.add_argument_group('PyTorch engine arguments') @@ -318,6 +320,8 @@ def api_server(args): reasoning_parser=args.reasoning_parser, tool_call_parser=args.tool_call_parser, speculative_config=speculative_config, + generation_config=args.generation_config, + override_generation_config=args.override_generation_config, ) else: from lmdeploy.serve.openai.launch_server import launch_server @@ -350,6 +354,8 @@ def api_server(args): reasoning_parser=args.reasoning_parser, tool_call_parser=args.tool_call_parser, speculative_config=speculative_config, + generation_config=args.generation_config, + override_generation_config=args.override_generation_config, ) @staticmethod diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 1dda62a7e8..810f54524e 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -301,6 +301,32 @@ def hf_overrides(parser): default=None, help='Extra arguments to be forwarded to the HuggingFace config.') + @staticmethod + def generation_config(parser): + """Add argument generation_config to parser.""" + return parser.add_argument( + '--generation-config', + type=str, + default='auto', + help='The folder path to the generation config. Defaults to "auto", the ' + 'generation config will be loaded from model path. If set to "lmdeploy", no ' + 'generation config is loaded, lmdeploy defaults will be used. If set to a folder ' + 'path, the generation config will be loaded from the specified folder path. ' + 'If max_new_tokens is specified in generation config, then it sets a ' + 'server-wide limit on the number of output tokens for all requests.') + + @staticmethod + def override_generation_config(parser): + """Add argument override_generation_config to parser.""" + return parser.add_argument( + '--override-generation-config', + type=json.loads, + default=None, + help='Overrides or sets generation config. e.g. \'{"temperature": 0.5}\'. If ' + 'used with --generation-config auto, the override parameters will be merged ' + 'with the default config from the model. If used with --generation-config ' + 'lmdeploy, only the override parameters are used.') + @staticmethod def use_logn_attn(parser): """Add argument use_logn_attn to parser.""" diff --git a/lmdeploy/serve/anthropic/adapter.py b/lmdeploy/serve/anthropic/adapter.py index 6975c7c8ad..0474ca6580 100644 --- a/lmdeploy/serve/anthropic/adapter.py +++ b/lmdeploy/serve/anthropic/adapter.py @@ -9,6 +9,7 @@ import shortuuid from lmdeploy.messages import GenerationConfig +from lmdeploy.serve.core.generation_config import build_generation_config, extract_request_sampling_values from lmdeploy.serve.openai.protocol import Tool, ToolChoice, ToolChoiceFuncName from .protocol import ( @@ -341,15 +342,18 @@ def to_lmdeploy_messages(request: MessagesRequest | CountTokensRequest) -> list[ return lm_messages -def to_generation_config(request: MessagesRequest) -> GenerationConfig: +def to_generation_config( + request: MessagesRequest, + server_defaults: dict | None = None, + override_max_new_tokens: int | None = None, +) -> GenerationConfig: """Map Anthropic messages request to LMDeploy generation config.""" - - return GenerationConfig( - max_new_tokens=request.max_tokens, - do_sample=True, - top_k=40 if request.top_k is None else request.top_k, - top_p=1.0 if request.top_p is None else request.top_p, - temperature=1.0 if request.temperature is None else request.temperature, + request_values = extract_request_sampling_values(request) + return build_generation_config( + request_values, + server_defaults or {}, + max_tokens=request.max_tokens, + override_max_new_tokens=override_max_new_tokens, stop_words=request.stop_sequences, include_stop_str_in_output=request.include_stop_str_in_output or False, skip_special_tokens=True, diff --git a/lmdeploy/serve/anthropic/endpoints/messages.py b/lmdeploy/serve/anthropic/endpoints/messages.py index 1f5b17a54b..d390b6a868 100644 --- a/lmdeploy/serve/anthropic/endpoints/messages.py +++ b/lmdeploy/serve/anthropic/endpoints/messages.py @@ -174,7 +174,11 @@ async def create_message(request: MessagesRequest, raw_request: Request): result_generator = server_context.async_engine.generate( engine_messages, session, - gen_config=to_generation_config(request), + gen_config=to_generation_config( + request, + server_defaults=server_context.server_sampling_defaults, + override_max_new_tokens=server_context.override_max_new_tokens, + ), tools=parsed_request.tools, stream_response=True, sequence_start=True, diff --git a/lmdeploy/serve/anthropic/protocol.py b/lmdeploy/serve/anthropic/protocol.py index 03f0b1c37f..37fa0f7bd7 100644 --- a/lmdeploy/serve/anthropic/protocol.py +++ b/lmdeploy/serve/anthropic/protocol.py @@ -104,7 +104,7 @@ class MessagesRequest(BaseModel): system: str | list[ContentBlockParam] | None = None stop_sequences: list[str] | None = None stream: bool = False - temperature: float | None = 1.0 + temperature: float | None = None top_p: float | None = None top_k: int | None = None metadata: dict[str, Any] | None = None diff --git a/lmdeploy/serve/core/generation_config.py b/lmdeploy/serve/core/generation_config.py new file mode 100644 index 0000000000..b11cf26dbc --- /dev/null +++ b/lmdeploy/serve/core/generation_config.py @@ -0,0 +1,173 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Server-side generation config resolution and sampling parameter merge +helpers.""" + +from __future__ import annotations + +from typing import Any + +from lmdeploy.messages import GenerationConfig +from lmdeploy.utils import get_logger + +logger = get_logger('lmdeploy') + +PROTOCOL_FALLBACKS: dict[str, Any] = { + 'temperature': 0.7, + 'top_p': 1.0, + 'top_k': 40, + 'repetition_penalty': 1.0, + 'min_p': 0.0, + 'do_sample': True, +} + +SAMPLING_PARAM_KEYS = ( + 'temperature', + 'top_p', + 'top_k', + 'min_p', + 'repetition_penalty', + 'max_new_tokens', + 'do_sample', +) + +REQUEST_SAMPLING_FIELDS = ( + 'temperature', + 'top_p', + 'top_k', + 'min_p', + 'repetition_penalty', +) + + +def _load_hf_generation_config(path: str, trust_remote_code: bool) -> dict[str, Any]: + from transformers import GenerationConfig + + try: + cfg = GenerationConfig.from_pretrained(path, trust_remote_code=trust_remote_code) + return cfg.to_diff_dict() + except OSError: + return {} + + +def extract_sampling_params(config: dict[str, Any]) -> dict[str, Any]: + """Extract supported sampling parameters from a generation config dict.""" + return {key: config[key] for key in SAMPLING_PARAM_KEYS if key in config and config[key] is not None} + + +def resolve_server_sampling_defaults( + generation_config: str, + override: dict[str, Any] | None, + model_path: str, + trust_remote_code: bool, +) -> tuple[dict[str, Any], int | None]: + """Resolve server-side default sampling params from CLI flags. + + Returns: + A tuple of (sampling_defaults, override_max_new_tokens). + ``override_max_new_tokens`` is a server-wide cap/default when set. + """ + override = override or {} + src = generation_config + + if src == 'lmdeploy': + config: dict[str, Any] = {} + elif src == 'auto': + config = _load_hf_generation_config(model_path, trust_remote_code) + else: + config = _load_hf_generation_config(src, trust_remote_code) + + config.update(override) + sampling = extract_sampling_params(config) + + override_max_new_tokens = sampling.pop('max_new_tokens', None) + if override_max_new_tokens is not None: + override_max_new_tokens = int(override_max_new_tokens) + + if sampling and src != 'lmdeploy': + source = "the model's `generation_config.json`" if src == 'auto' else src + logger.info( + 'Using default sampling params from %s: %s. ' + 'Use `--generation-config lmdeploy` to disable.', + source, + sampling, + ) + elif sampling and override: + logger.info('Using override generation config sampling params: %s.', sampling) + + return sampling, override_max_new_tokens + + +def merge_sampling_params( + request_values: dict[str, Any], + server_defaults: dict[str, Any], + fallbacks: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Merge sampling params with request > server > protocol fallback + priority.""" + fallbacks = fallbacks or PROTOCOL_FALLBACKS + merged: dict[str, Any] = {} + all_keys = set(fallbacks) | set(server_defaults) | set(request_values) + for key in all_keys: + if key in request_values: + merged[key] = request_values[key] + elif key in server_defaults: + merged[key] = server_defaults[key] + elif key in fallbacks: + merged[key] = fallbacks[key] + return merged + + +def extract_request_sampling_values(request: Any) -> dict[str, Any]: + """Extract explicitly provided sampling fields from a request object.""" + values: dict[str, Any] = {} + for field in REQUEST_SAMPLING_FIELDS: + if not hasattr(request, field): + continue + value = getattr(request, field) + if value is not None: + values[field] = value + return values + + +def resolve_max_new_tokens( + max_completion_tokens: int | None, + max_tokens: int | None, + server_cap: int | None, +) -> int | None: + """Resolve output token limit with optional server-wide cap/default.""" + request_value = max_completion_tokens if max_completion_tokens is not None else max_tokens + if request_value is None: + return server_cap + if server_cap is not None: + return min(request_value, server_cap) + return request_value + + +def build_generation_config( + request_values: dict[str, Any], + server_defaults: dict[str, Any], + *, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + override_max_new_tokens: int | None = None, + fallbacks: dict[str, Any] | None = None, + **extra_kwargs: Any, +) -> GenerationConfig: + """Build ``GenerationConfig`` from merged sampling defaults and request + values.""" + merged = merge_sampling_params(request_values, server_defaults, fallbacks) + max_new_tokens = resolve_max_new_tokens( + max_completion_tokens, + max_tokens, + override_max_new_tokens, + ) + return GenerationConfig( + max_new_tokens=max_new_tokens, + do_sample=merged.get('do_sample', PROTOCOL_FALLBACKS['do_sample']), + top_k=merged['top_k'], + top_p=merged['top_p'], + temperature=merged['temperature'], + repetition_penalty=merged['repetition_penalty'], + min_p=merged['min_p'], + **extra_kwargs, + ) diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index e8a3ca6ecc..d39150487c 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -47,6 +47,11 @@ ) from lmdeploy.serve.anthropic import create_anthropic_router from lmdeploy.serve.core import AsyncEngine, EngineHealthMonitor +from lmdeploy.serve.core.generation_config import ( + build_generation_config, + extract_request_sampling_values, + resolve_server_sampling_defaults, +) from lmdeploy.serve.openai.protocol import ( AbortRequest, ChatCompletionRequest, @@ -110,6 +115,8 @@ class VariableInterface: allow_terminate_by_client: bool = False enable_abort_handling: bool = False response_parser_cls: type[ResponseParser] | None = None + server_sampling_defaults: dict = {} + override_max_new_tokens: int | None = None @classmethod def create_session(cls, user_session_id: int | None = None) -> Session: @@ -147,6 +154,23 @@ def get_engine_config(cls): return cls.async_engine.backend_config +def _build_serving_generation_config(request, **extra_kwargs) -> GenerationConfig: + """Build ``GenerationConfig`` with server and request sampling merge.""" + request_values = extract_request_sampling_values(request) + max_completion_tokens = getattr(request, 'max_completion_tokens', None) + max_tokens = getattr(request, 'max_tokens', None) + if max_completion_tokens is None and hasattr(request, 'max_output_tokens'): + max_completion_tokens = getattr(request, 'max_output_tokens', None) + return build_generation_config( + request_values, + VariableInterface.server_sampling_defaults, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + override_max_new_tokens=VariableInterface.override_max_new_tokens, + **extra_kwargs, + ) + + async def _with_request_cleanup(generator, result_generators, sessions): """Yield from an API generator and cleanup when the HTTP task exits.""" session_mgr = VariableInterface.get_session_manager() @@ -486,14 +510,9 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque # (e.g. GPT-OSS clears response_format and injects the schema into messages) request = response_parser.request - gen_config = GenerationConfig( - max_new_tokens=request.max_completion_tokens, - do_sample=True, + gen_config = _build_serving_generation_config( + request, logprobs=gen_logprobs, - top_k=request.top_k, - top_p=request.top_p, - temperature=request.temperature, - repetition_penalty=request.repetition_penalty, ignore_eos=request.ignore_eos, stop_words=request.stop, include_stop_str_in_output=request.include_stop_str_in_output, @@ -501,7 +520,6 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque response_format=request.response_format, logits_processors=logits_processors, min_new_tokens=request.min_new_tokens, - min_p=request.min_p, random_seed=random_seed, spaces_between_special_tokens=request.spaces_between_special_tokens, migration_request=migration_request, @@ -828,20 +846,13 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None if isinstance(request.stop, str): request.stop = [request.stop] random_seed = request.seed if request.seed is not None else None - max_new_tokens = (request.max_completion_tokens if request.max_completion_tokens else request.max_tokens) - gen_config = GenerationConfig( - max_new_tokens=max_new_tokens, - do_sample=True, + gen_config = _build_serving_generation_config( + request, logprobs=request.logprobs, - top_k=request.top_k, - top_p=request.top_p, - temperature=request.temperature, - repetition_penalty=request.repetition_penalty, ignore_eos=request.ignore_eos, stop_words=request.stop, skip_special_tokens=request.skip_special_tokens, - min_p=request.min_p, random_seed=random_seed, spaces_between_special_tokens=request.spaces_between_special_tokens, migration_request=migration_request, @@ -1012,15 +1023,9 @@ async def generate(request: GenerateReqInput, raw_request: Request = None): prompt = [dict(role='user', content=[text_input] + image_input)] input_ids = None - gen_config = GenerationConfig( - max_new_tokens=request.max_tokens, - do_sample=True, + gen_config = _build_serving_generation_config( + request, logprobs=1 if request.return_logprob else None, - top_k=request.top_k, - top_p=request.top_p, - min_p=request.min_p, - temperature=request.temperature, - repetition_penalty=request.repetition_penalty, ignore_eos=request.ignore_eos, stop_words=request.stop, stop_token_ids=request.stop_token_ids, @@ -1546,6 +1551,8 @@ def serve(model_path: str, allow_terminate_by_client: bool = False, enable_abort_handling: bool = False, speculative_config: SpeculativeConfig | None = None, + generation_config: str = 'auto', + override_generation_config: dict | None = None, **kwargs): """An example to perform model inference through the command line interface. @@ -1606,6 +1613,15 @@ def serve(model_path: str, VariableInterface.allow_terminate_by_client = allow_terminate_by_client VariableInterface.enable_abort_handling = enable_abort_handling + server_defaults, override_max_new_tokens = resolve_server_sampling_defaults( + generation_config, + override_generation_config, + model_path, + trust_remote_code, + ) + VariableInterface.server_sampling_defaults = server_defaults + VariableInterface.override_max_new_tokens = override_max_new_tokens + ssl_keyfile, ssl_certfile, http_or_https = None, None, 'http' if ssl: ssl_keyfile = os.environ['SSL_KEYFILE'] diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py index 675cbf5103..40162faf15 100644 --- a/lmdeploy/serve/openai/protocol.py +++ b/lmdeploy/serve/openai/protocol.py @@ -144,8 +144,8 @@ class ChatCompletionRequest(BaseModel): model: str messages: str | list[dict[str, Any]] = Field(examples=[[{'role': 'user', 'content': 'hi'}]]) - temperature: float | None = 0.7 - top_p: float | None = 1.0 + temperature: float | None = None + top_p: float | None = None tools: list[Tool] | None = Field(default=None, examples=[None]) tool_choice: ToolChoice | AllowedToolChoice | Literal[ 'auto', 'required', 'none'] = Field(default='auto', examples=['none']) @@ -176,17 +176,17 @@ class ChatCompletionRequest(BaseModel): response_format: ResponseFormat | None = Field(default=None, examples=[None]) # additional argument of lmdeploy do_preprocess: bool | None = True - repetition_penalty: float | None = 1.0 + repetition_penalty: float | None = None repetition_ngram_size: int = Field(default=0, ge=0) repetition_ngram_threshold: int = Field(default=0, ge=0) session_id: int | None = -1 ignore_eos: bool | None = False skip_special_tokens: bool | None = True spaces_between_special_tokens: bool | None = True - top_k: int | None = 40 + top_k: int | None = None seed: int | None = None min_new_tokens: int | None = Field(default=None, examples=[None]) - min_p: float = 0.0 + min_p: float | None = None enable_thinking: bool | None = None # will be deprecated in the future return_token_ids: bool | None = False return_logprob: bool | None = False @@ -352,7 +352,7 @@ class CompletionRequest(BaseModel): model: str prompt: str | list[Any] suffix: str | None = None - temperature: float | None = 0.7 + temperature: float | None = None n: int | None = 1 logprobs: int | None = None max_completion_tokens: int | None = Field( @@ -362,29 +362,29 @@ class CompletionRequest(BaseModel): 'including visible output tokens and reasoning tokens'), ) max_tokens: int | None = Field( - default=16, - examples=[16], + default=None, + examples=[None], deprecated='max_tokens is deprecated in favor of the max_completion_tokens field', ) stop: str | list[str] | None = Field(default=None, examples=[None]) stream: bool | None = False stream_options: StreamOptions | None = Field(default=None, examples=[None]) - top_p: float | None = 1.0 + top_p: float | None = None echo: bool | None = False presence_penalty: float | None = 0.0 frequency_penalty: float | None = 0.0 user: str | None = None # additional argument of lmdeploy - repetition_penalty: float | None = 1.0 + repetition_penalty: float | None = None repetition_ngram_size: int = Field(default=0, ge=0) repetition_ngram_threshold: int = Field(default=0, ge=0) session_id: int | None = -1 ignore_eos: bool | None = False skip_special_tokens: bool | None = True spaces_between_special_tokens: bool | None = True - top_k: int | None = 40 # for opencompass + top_k: int | None = None # for opencompass seed: int | None = None - min_p: float = 0.0 + min_p: float | None = None class CompletionResponseChoice(BaseModel): @@ -550,7 +550,7 @@ class GenerateReqInput(BaseModel): stop_token_ids: list[int] | None = None stream: bool | None = False temperature: float = 1.0 - repetition_penalty: float | None = 1.0 + repetition_penalty: float | None = None ignore_eos: bool | None = False top_p: float = 1.0 top_k: int = 0 diff --git a/lmdeploy/serve/openai/responses/protocol.py b/lmdeploy/serve/openai/responses/protocol.py index ed08f16202..d02d220df7 100644 --- a/lmdeploy/serve/openai/responses/protocol.py +++ b/lmdeploy/serve/openai/responses/protocol.py @@ -82,10 +82,10 @@ class ResponsesRequest(BaseModel): presence_penalty: float | None = None frequency_penalty: float | None = None repetition_penalty: float | None = None - top_k: int | None = 40 + top_k: int | None = None stop: str | list[str] | None = None seed: int | None = None - min_p: float = 0.0 + min_p: float | None = None ignore_eos: bool | None = False skip_special_tokens: bool | None = True include_stop_str_in_output: bool | None = False diff --git a/lmdeploy/serve/openai/responses/request.py b/lmdeploy/serve/openai/responses/request.py index 2db1000532..6a7dc9f8df 100644 --- a/lmdeploy/serve/openai/responses/request.py +++ b/lmdeploy/serve/openai/responses/request.py @@ -9,6 +9,7 @@ from fastapi.responses import JSONResponse from lmdeploy.messages import GenerationConfig +from lmdeploy.serve.core.generation_config import build_generation_config, extract_request_sampling_values from lmdeploy.serve.openai.protocol import Tool, ToolChoice, ToolChoiceFuncName from lmdeploy.serve.openai.responses.protocol import ResponsesRequest from lmdeploy.utils import get_logger @@ -264,20 +265,22 @@ def _response_format_from_text(text: Any) -> dict[str, Any] | None: raise ValueError(f'Unsupported text.format type: {format_type!r}.') -def to_generation_config(request: ResponsesRequest) -> GenerationConfig: +def to_generation_config( + request: ResponsesRequest, + server_defaults: dict | None = None, + override_max_new_tokens: int | None = None, +) -> GenerationConfig: stop_words = [request.stop] if isinstance(request.stop, str) else request.stop - return GenerationConfig( - max_new_tokens=request.max_output_tokens, - do_sample=True, - top_k=40 if request.top_k is None else request.top_k, - top_p=1.0 if request.top_p is None else request.top_p, - temperature=1.0 if request.temperature is None else request.temperature, + request_values = extract_request_sampling_values(request) + return build_generation_config( + request_values, + server_defaults or {}, + max_completion_tokens=request.max_output_tokens, + override_max_new_tokens=override_max_new_tokens, stop_words=stop_words, ignore_eos=request.ignore_eos, skip_special_tokens=request.skip_special_tokens, include_stop_str_in_output=request.include_stop_str_in_output, response_format=_response_format_from_text(request.text), - min_p=request.min_p, random_seed=request.seed, - repetition_penalty=1.0 if request.repetition_penalty is None else request.repetition_penalty, ) diff --git a/lmdeploy/serve/openai/responses/serving.py b/lmdeploy/serve/openai/responses/serving.py index 2837e483c5..64864ece80 100644 --- a/lmdeploy/serve/openai/responses/serving.py +++ b/lmdeploy/serve/openai/responses/serving.py @@ -97,7 +97,11 @@ async def create_response(self, request: ResponsesRequest, raw_request: Request) except ValueError as err: return error_response(HTTPStatus.BAD_REQUEST, str(err), param='input') try: - gen_config = to_generation_config(request) + gen_config = to_generation_config( + request, + server_defaults=self.server_context.server_sampling_defaults, + override_max_new_tokens=self.server_context.override_max_new_tokens, + ) except ValueError as err: return error_response(HTTPStatus.BAD_REQUEST, str(err), param='text') try: diff --git a/lmdeploy/serve/openai/serving_chat_completion.py b/lmdeploy/serve/openai/serving_chat_completion.py index 362f0bf9e5..7ac750a617 100644 --- a/lmdeploy/serve/openai/serving_chat_completion.py +++ b/lmdeploy/serve/openai/serving_chat_completion.py @@ -1,12 +1,26 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import TYPE_CHECKING +from lmdeploy.serve.core.generation_config import ( + PROTOCOL_FALLBACKS, + extract_request_sampling_values, + merge_sampling_params, +) + from .protocol import ChatCompletionRequest if TYPE_CHECKING: from .api_server import VariableInterface +def _effective_sampling(request: ChatCompletionRequest, server_context: 'VariableInterface') -> dict: + return merge_sampling_params( + extract_request_sampling_values(request), + server_context.server_sampling_defaults, + PROTOCOL_FALLBACKS, + ) + + def check_request(request: ChatCompletionRequest, server_context: 'VariableInterface') -> str: engine_config = server_context.get_engine_config() session_manager = server_context.get_session_manager() @@ -32,15 +46,17 @@ def check_request(request: ChatCompletionRequest, server_context: 'VariableInter if session_manager.has(request.session_id): return f'The session_id {request.session_id!r} is occupied.' + sampling = _effective_sampling(request, server_context) + # check sampling settings if request.n <= 0: return f'The n {request.n!r} must be a positive int.' - if not (0 < request.top_p <= 1): - return f'The top_p {request.top_p!r} must be in (0, 1].' - if request.top_k < 0: - return f'The top_k {request.top_k!r} cannot be a negative integer.' - if not (0 <= request.temperature <= 2): - return f'The temperature {request.temperature!r} must be in [0, 2]' + if not (0 < sampling['top_p'] <= 1): + return f'The top_p {sampling["top_p"]!r} must be in (0, 1].' + if sampling['top_k'] < 0: + return f'The top_k {sampling["top_k"]!r} cannot be a negative integer.' + if not (0 <= sampling['temperature'] <= 2): + return f'The temperature {sampling["temperature"]!r} must be in [0, 2]' # Validate input_ids and image_data constraints. # messages has higher priority. input_ids and image_data are only used when diff --git a/lmdeploy/serve/openai/serving_completion.py b/lmdeploy/serve/openai/serving_completion.py index 759972db36..3047717c1f 100644 --- a/lmdeploy/serve/openai/serving_completion.py +++ b/lmdeploy/serve/openai/serving_completion.py @@ -1,12 +1,26 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import TYPE_CHECKING +from lmdeploy.serve.core.generation_config import ( + PROTOCOL_FALLBACKS, + extract_request_sampling_values, + merge_sampling_params, +) + from .protocol import CompletionRequest if TYPE_CHECKING: from .api_server import VariableInterface +def _effective_sampling(request: CompletionRequest, server_context: 'VariableInterface') -> dict: + return merge_sampling_params( + extract_request_sampling_values(request), + server_context.server_sampling_defaults, + PROTOCOL_FALLBACKS, + ) + + def check_request(request: CompletionRequest, server_context: 'VariableInterface') -> str: engine_config = server_context.get_engine_config() session_manager = server_context.get_session_manager() @@ -24,14 +38,16 @@ def check_request(request: CompletionRequest, server_context: 'VariableInterface if session_manager.has(request.session_id): return f'The session_id {request.session_id!r} is occupied.' + sampling = _effective_sampling(request, server_context) + # check sampling settings if request.n <= 0: return f'The n {request.n!r} must be a positive int.' - if not (0 < request.top_p <= 1): - return f'The top_p {request.top_p!r} must be in (0, 1].' - if request.top_k < 0: - return f'The top_k {request.top_k!r} cannot be a negative integer.' - if not (0 <= request.temperature <= 2): - return f'The temperature {request.temperature!r} must be in [0, 2]' + if not (0 < sampling['top_p'] <= 1): + return f'The top_p {sampling["top_p"]!r} must be in (0, 1].' + if sampling['top_k'] < 0: + return f'The top_k {sampling["top_k"]!r} cannot be a negative integer.' + if not (0 <= sampling['temperature'] <= 2): + return f'The temperature {sampling["temperature"]!r} must be in [0, 2]' return '' diff --git a/tests/test_lmdeploy/serve/test_generation_config.py b/tests/test_lmdeploy/serve/test_generation_config.py new file mode 100644 index 0000000000..a5a9199a2f --- /dev/null +++ b/tests/test_lmdeploy/serve/test_generation_config.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest.mock import patch + +from lmdeploy.serve.core.generation_config import ( + PROTOCOL_FALLBACKS, + build_generation_config, + extract_request_sampling_values, + merge_sampling_params, + resolve_max_new_tokens, + resolve_server_sampling_defaults, +) +from lmdeploy.serve.openai.protocol import ChatCompletionRequest, CompletionRequest + + +def test_merge_sampling_params_priority(): + merged = merge_sampling_params( + {'temperature': 0.2}, + {'temperature': 0.5, 'top_k': 10}, + PROTOCOL_FALLBACKS, + ) + assert merged['temperature'] == 0.2 + assert merged['top_k'] == 10 + assert merged['top_p'] == PROTOCOL_FALLBACKS['top_p'] + + +def test_merge_sampling_params_uses_server_then_fallback(): + merged = merge_sampling_params({}, {'temperature': 0.5}, PROTOCOL_FALLBACKS) + assert merged['temperature'] == 0.5 + assert merged['top_k'] == PROTOCOL_FALLBACKS['top_k'] + + +def test_extract_request_sampling_values_only_non_null(): + request = ChatCompletionRequest(model='test', messages='hi', temperature=0.3) + values = extract_request_sampling_values(request) + assert values == {'temperature': 0.3} + + +def test_resolve_max_new_tokens_uses_server_default(): + assert resolve_max_new_tokens(None, None, 128) == 128 + + +def test_resolve_max_new_tokens_caps_request_value(): + assert resolve_max_new_tokens(256, None, 128) == 128 + assert resolve_max_new_tokens(None, 256, 128) == 128 + + +def test_resolve_max_new_tokens_prefers_max_completion_tokens(): + assert resolve_max_new_tokens(64, 256, None) == 64 + + +def test_build_generation_config_from_merged_values(): + gen_config = build_generation_config( + {'temperature': 0.2}, + {'top_k': 5}, + max_completion_tokens=32, + override_max_new_tokens=64, + ) + assert gen_config.temperature == 0.2 + assert gen_config.top_k == 5 + assert gen_config.max_new_tokens == 32 + assert gen_config.do_sample is True + + +@patch('lmdeploy.serve.core.generation_config._load_hf_generation_config') +def test_resolve_server_sampling_defaults_auto(mock_load): + mock_load.return_value = { + 'temperature': 0.6, + 'top_p': 0.8, + 'max_new_tokens': 2048, + } + defaults, cap = resolve_server_sampling_defaults('auto', None, '/fake/model', False) + assert defaults == {'temperature': 0.6, 'top_p': 0.8} + assert cap == 2048 + mock_load.assert_called_once_with('/fake/model', False) + + +def test_resolve_server_sampling_defaults_lmdeploy(): + defaults, cap = resolve_server_sampling_defaults('lmdeploy', None, '/fake/model', False) + assert defaults == {} + assert cap is None + + +@patch('lmdeploy.serve.core.generation_config._load_hf_generation_config') +def test_resolve_server_sampling_defaults_with_override(mock_load): + mock_load.return_value = {'temperature': 0.6, 'top_k': 20} + defaults, cap = resolve_server_sampling_defaults( + 'auto', + {'temperature': 0.5, 'max_new_tokens': 100}, + '/fake/model', + False, + ) + assert defaults == {'temperature': 0.5, 'top_k': 20} + assert cap == 100 + + +def test_completion_request_sampling_merge(): + request = CompletionRequest(model='test', prompt='hello') + gen_config = build_generation_config( + extract_request_sampling_values(request), + {'temperature': 0.9}, + ) + assert gen_config.temperature == 0.9 + assert gen_config.top_k == PROTOCOL_FALLBACKS['top_k'] From e1eac99a579d9709461a17b105190580f3b3abd9 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Wed, 1 Jul 2026 12:46:44 +0000 Subject: [PATCH 2/4] refactor(serve): simplify generation-config CLI handling Drop --override-generation-config and server max_new_tokens caps so defaults only come from --generation-config and per-request fields. Co-authored-by: Cursor --- lmdeploy/cli/serve.py | 3 -- lmdeploy/cli/utils.py | 16 +----- lmdeploy/serve/anthropic/adapter.py | 2 - .../serve/anthropic/endpoints/messages.py | 1 - lmdeploy/serve/core/generation_config.py | 52 +++---------------- lmdeploy/serve/openai/api_server.py | 8 +-- lmdeploy/serve/openai/responses/request.py | 2 - lmdeploy/serve/openai/responses/serving.py | 1 - .../serve/test_generation_config.py | 39 +++----------- 9 files changed, 15 insertions(+), 109 deletions(-) diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 5d8e0174b6..e11bb359ba 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -93,7 +93,6 @@ def add_parser_api_server(): ArgumentHelper.revision(parser) ArgumentHelper.download_dir(parser) ArgumentHelper.generation_config(parser) - ArgumentHelper.override_generation_config(parser) # pytorch engine args pt_group = parser.add_argument_group('PyTorch engine arguments') @@ -321,7 +320,6 @@ def api_server(args): tool_call_parser=args.tool_call_parser, speculative_config=speculative_config, generation_config=args.generation_config, - override_generation_config=args.override_generation_config, ) else: from lmdeploy.serve.openai.launch_server import launch_server @@ -355,7 +353,6 @@ def api_server(args): tool_call_parser=args.tool_call_parser, speculative_config=speculative_config, generation_config=args.generation_config, - override_generation_config=args.override_generation_config, ) @staticmethod diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 810f54524e..12f726700a 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -311,21 +311,7 @@ def generation_config(parser): help='The folder path to the generation config. Defaults to "auto", the ' 'generation config will be loaded from model path. If set to "lmdeploy", no ' 'generation config is loaded, lmdeploy defaults will be used. If set to a folder ' - 'path, the generation config will be loaded from the specified folder path. ' - 'If max_new_tokens is specified in generation config, then it sets a ' - 'server-wide limit on the number of output tokens for all requests.') - - @staticmethod - def override_generation_config(parser): - """Add argument override_generation_config to parser.""" - return parser.add_argument( - '--override-generation-config', - type=json.loads, - default=None, - help='Overrides or sets generation config. e.g. \'{"temperature": 0.5}\'. If ' - 'used with --generation-config auto, the override parameters will be merged ' - 'with the default config from the model. If used with --generation-config ' - 'lmdeploy, only the override parameters are used.') + 'path, the generation config will be loaded from the specified folder path.') @staticmethod def use_logn_attn(parser): diff --git a/lmdeploy/serve/anthropic/adapter.py b/lmdeploy/serve/anthropic/adapter.py index 0474ca6580..781bacfd3f 100644 --- a/lmdeploy/serve/anthropic/adapter.py +++ b/lmdeploy/serve/anthropic/adapter.py @@ -345,7 +345,6 @@ def to_lmdeploy_messages(request: MessagesRequest | CountTokensRequest) -> list[ def to_generation_config( request: MessagesRequest, server_defaults: dict | None = None, - override_max_new_tokens: int | None = None, ) -> GenerationConfig: """Map Anthropic messages request to LMDeploy generation config.""" request_values = extract_request_sampling_values(request) @@ -353,7 +352,6 @@ def to_generation_config( request_values, server_defaults or {}, max_tokens=request.max_tokens, - override_max_new_tokens=override_max_new_tokens, stop_words=request.stop_sequences, include_stop_str_in_output=request.include_stop_str_in_output or False, skip_special_tokens=True, diff --git a/lmdeploy/serve/anthropic/endpoints/messages.py b/lmdeploy/serve/anthropic/endpoints/messages.py index d390b6a868..4dbffae5ac 100644 --- a/lmdeploy/serve/anthropic/endpoints/messages.py +++ b/lmdeploy/serve/anthropic/endpoints/messages.py @@ -177,7 +177,6 @@ async def create_message(request: MessagesRequest, raw_request: Request): gen_config=to_generation_config( request, server_defaults=server_context.server_sampling_defaults, - override_max_new_tokens=server_context.override_max_new_tokens, ), tools=parsed_request.tools, stream_response=True, diff --git a/lmdeploy/serve/core/generation_config.py b/lmdeploy/serve/core/generation_config.py index b11cf26dbc..0bd2c483a7 100644 --- a/lmdeploy/serve/core/generation_config.py +++ b/lmdeploy/serve/core/generation_config.py @@ -24,10 +24,6 @@ 'temperature', 'top_p', 'top_k', - 'min_p', - 'repetition_penalty', - 'max_new_tokens', - 'do_sample', ) REQUEST_SAMPLING_FIELDS = ( @@ -56,17 +52,10 @@ def extract_sampling_params(config: dict[str, Any]) -> dict[str, Any]: def resolve_server_sampling_defaults( generation_config: str, - override: dict[str, Any] | None, model_path: str, trust_remote_code: bool, -) -> tuple[dict[str, Any], int | None]: - """Resolve server-side default sampling params from CLI flags. - - Returns: - A tuple of (sampling_defaults, override_max_new_tokens). - ``override_max_new_tokens`` is a server-wide cap/default when set. - """ - override = override or {} +) -> dict[str, Any]: + """Resolve server-side default sampling params from CLI flags.""" src = generation_config if src == 'lmdeploy': @@ -76,25 +65,15 @@ def resolve_server_sampling_defaults( else: config = _load_hf_generation_config(src, trust_remote_code) - config.update(override) sampling = extract_sampling_params(config) - override_max_new_tokens = sampling.pop('max_new_tokens', None) - if override_max_new_tokens is not None: - override_max_new_tokens = int(override_max_new_tokens) - if sampling and src != 'lmdeploy': source = "the model's `generation_config.json`" if src == 'auto' else src logger.info( - 'Using default sampling params from %s: %s. ' - 'Use `--generation-config lmdeploy` to disable.', - source, - sampling, - ) - elif sampling and override: - logger.info('Using override generation config sampling params: %s.', sampling) + f'Using default sampling params from {source}: {sampling}. ' + 'Use `--generation-config lmdeploy` to disable.') - return sampling, override_max_new_tokens + return sampling def merge_sampling_params( @@ -129,38 +108,19 @@ def extract_request_sampling_values(request: Any) -> dict[str, Any]: return values -def resolve_max_new_tokens( - max_completion_tokens: int | None, - max_tokens: int | None, - server_cap: int | None, -) -> int | None: - """Resolve output token limit with optional server-wide cap/default.""" - request_value = max_completion_tokens if max_completion_tokens is not None else max_tokens - if request_value is None: - return server_cap - if server_cap is not None: - return min(request_value, server_cap) - return request_value - - def build_generation_config( request_values: dict[str, Any], server_defaults: dict[str, Any], *, max_completion_tokens: int | None = None, max_tokens: int | None = None, - override_max_new_tokens: int | None = None, fallbacks: dict[str, Any] | None = None, **extra_kwargs: Any, ) -> GenerationConfig: """Build ``GenerationConfig`` from merged sampling defaults and request values.""" merged = merge_sampling_params(request_values, server_defaults, fallbacks) - max_new_tokens = resolve_max_new_tokens( - max_completion_tokens, - max_tokens, - override_max_new_tokens, - ) + max_new_tokens = max_completion_tokens if max_completion_tokens is not None else max_tokens return GenerationConfig( max_new_tokens=max_new_tokens, do_sample=merged.get('do_sample', PROTOCOL_FALLBACKS['do_sample']), diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index d39150487c..f8a8cbe94f 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -116,7 +116,6 @@ class VariableInterface: enable_abort_handling: bool = False response_parser_cls: type[ResponseParser] | None = None server_sampling_defaults: dict = {} - override_max_new_tokens: int | None = None @classmethod def create_session(cls, user_session_id: int | None = None) -> Session: @@ -166,7 +165,6 @@ def _build_serving_generation_config(request, **extra_kwargs) -> GenerationConfi VariableInterface.server_sampling_defaults, max_completion_tokens=max_completion_tokens, max_tokens=max_tokens, - override_max_new_tokens=VariableInterface.override_max_new_tokens, **extra_kwargs, ) @@ -1552,7 +1550,6 @@ def serve(model_path: str, enable_abort_handling: bool = False, speculative_config: SpeculativeConfig | None = None, generation_config: str = 'auto', - override_generation_config: dict | None = None, **kwargs): """An example to perform model inference through the command line interface. @@ -1613,14 +1610,11 @@ def serve(model_path: str, VariableInterface.allow_terminate_by_client = allow_terminate_by_client VariableInterface.enable_abort_handling = enable_abort_handling - server_defaults, override_max_new_tokens = resolve_server_sampling_defaults( + VariableInterface.server_sampling_defaults = resolve_server_sampling_defaults( generation_config, - override_generation_config, model_path, trust_remote_code, ) - VariableInterface.server_sampling_defaults = server_defaults - VariableInterface.override_max_new_tokens = override_max_new_tokens ssl_keyfile, ssl_certfile, http_or_https = None, None, 'http' if ssl: diff --git a/lmdeploy/serve/openai/responses/request.py b/lmdeploy/serve/openai/responses/request.py index 6a7dc9f8df..8cfc799b38 100644 --- a/lmdeploy/serve/openai/responses/request.py +++ b/lmdeploy/serve/openai/responses/request.py @@ -268,7 +268,6 @@ def _response_format_from_text(text: Any) -> dict[str, Any] | None: def to_generation_config( request: ResponsesRequest, server_defaults: dict | None = None, - override_max_new_tokens: int | None = None, ) -> GenerationConfig: stop_words = [request.stop] if isinstance(request.stop, str) else request.stop request_values = extract_request_sampling_values(request) @@ -276,7 +275,6 @@ def to_generation_config( request_values, server_defaults or {}, max_completion_tokens=request.max_output_tokens, - override_max_new_tokens=override_max_new_tokens, stop_words=stop_words, ignore_eos=request.ignore_eos, skip_special_tokens=request.skip_special_tokens, diff --git a/lmdeploy/serve/openai/responses/serving.py b/lmdeploy/serve/openai/responses/serving.py index 64864ece80..f42ec6961c 100644 --- a/lmdeploy/serve/openai/responses/serving.py +++ b/lmdeploy/serve/openai/responses/serving.py @@ -100,7 +100,6 @@ async def create_response(self, request: ResponsesRequest, raw_request: Request) gen_config = to_generation_config( request, server_defaults=self.server_context.server_sampling_defaults, - override_max_new_tokens=self.server_context.override_max_new_tokens, ) except ValueError as err: return error_response(HTTPStatus.BAD_REQUEST, str(err), param='text') diff --git a/tests/test_lmdeploy/serve/test_generation_config.py b/tests/test_lmdeploy/serve/test_generation_config.py index a5a9199a2f..e1b85a37f3 100644 --- a/tests/test_lmdeploy/serve/test_generation_config.py +++ b/tests/test_lmdeploy/serve/test_generation_config.py @@ -6,7 +6,6 @@ build_generation_config, extract_request_sampling_values, merge_sampling_params, - resolve_max_new_tokens, resolve_server_sampling_defaults, ) from lmdeploy.serve.openai.protocol import ChatCompletionRequest, CompletionRequest @@ -35,25 +34,11 @@ def test_extract_request_sampling_values_only_non_null(): assert values == {'temperature': 0.3} -def test_resolve_max_new_tokens_uses_server_default(): - assert resolve_max_new_tokens(None, None, 128) == 128 - - -def test_resolve_max_new_tokens_caps_request_value(): - assert resolve_max_new_tokens(256, None, 128) == 128 - assert resolve_max_new_tokens(None, 256, 128) == 128 - - -def test_resolve_max_new_tokens_prefers_max_completion_tokens(): - assert resolve_max_new_tokens(64, 256, None) == 64 - - def test_build_generation_config_from_merged_values(): gen_config = build_generation_config( {'temperature': 0.2}, {'top_k': 5}, max_completion_tokens=32, - override_max_new_tokens=64, ) assert gen_config.temperature == 0.2 assert gen_config.top_k == 5 @@ -61,6 +46,11 @@ def test_build_generation_config_from_merged_values(): assert gen_config.do_sample is True +def test_build_generation_config_max_new_tokens_defaults_to_none(): + gen_config = build_generation_config({}, {}) + assert gen_config.max_new_tokens is None + + @patch('lmdeploy.serve.core.generation_config._load_hf_generation_config') def test_resolve_server_sampling_defaults_auto(mock_load): mock_load.return_value = { @@ -68,29 +58,14 @@ def test_resolve_server_sampling_defaults_auto(mock_load): 'top_p': 0.8, 'max_new_tokens': 2048, } - defaults, cap = resolve_server_sampling_defaults('auto', None, '/fake/model', False) + defaults = resolve_server_sampling_defaults('auto', '/fake/model', False) assert defaults == {'temperature': 0.6, 'top_p': 0.8} - assert cap == 2048 mock_load.assert_called_once_with('/fake/model', False) def test_resolve_server_sampling_defaults_lmdeploy(): - defaults, cap = resolve_server_sampling_defaults('lmdeploy', None, '/fake/model', False) + defaults = resolve_server_sampling_defaults('lmdeploy', '/fake/model', False) assert defaults == {} - assert cap is None - - -@patch('lmdeploy.serve.core.generation_config._load_hf_generation_config') -def test_resolve_server_sampling_defaults_with_override(mock_load): - mock_load.return_value = {'temperature': 0.6, 'top_k': 20} - defaults, cap = resolve_server_sampling_defaults( - 'auto', - {'temperature': 0.5, 'max_new_tokens': 100}, - '/fake/model', - False, - ) - assert defaults == {'temperature': 0.5, 'top_k': 20} - assert cap == 100 def test_completion_request_sampling_merge(): From 46e0352f5080ace0594be96a92c34588a3597c35 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Wed, 1 Jul 2026 12:58:24 +0000 Subject: [PATCH 3/4] refactor(serve): store full HF config as default_gen_config Replace filtered server_sampling_defaults with the raw generation config dict loaded by --generation-config. Co-authored-by: Cursor --- lmdeploy/serve/anthropic/adapter.py | 4 +- .../serve/anthropic/endpoints/messages.py | 2 +- lmdeploy/serve/core/generation_config.py | 39 +++++++------------ lmdeploy/serve/openai/api_server.py | 8 ++-- lmdeploy/serve/openai/responses/request.py | 4 +- lmdeploy/serve/openai/responses/serving.py | 2 +- .../serve/openai/serving_chat_completion.py | 2 +- lmdeploy/serve/openai/serving_completion.py | 2 +- .../serve/test_generation_config.py | 18 +++++---- 9 files changed, 36 insertions(+), 45 deletions(-) diff --git a/lmdeploy/serve/anthropic/adapter.py b/lmdeploy/serve/anthropic/adapter.py index 781bacfd3f..14ee52b9a2 100644 --- a/lmdeploy/serve/anthropic/adapter.py +++ b/lmdeploy/serve/anthropic/adapter.py @@ -344,13 +344,13 @@ def to_lmdeploy_messages(request: MessagesRequest | CountTokensRequest) -> list[ def to_generation_config( request: MessagesRequest, - server_defaults: dict | None = None, + default_gen_config: dict | None = None, ) -> GenerationConfig: """Map Anthropic messages request to LMDeploy generation config.""" request_values = extract_request_sampling_values(request) return build_generation_config( request_values, - server_defaults or {}, + default_gen_config or {}, max_tokens=request.max_tokens, stop_words=request.stop_sequences, include_stop_str_in_output=request.include_stop_str_in_output or False, diff --git a/lmdeploy/serve/anthropic/endpoints/messages.py b/lmdeploy/serve/anthropic/endpoints/messages.py index 4dbffae5ac..6027fb3030 100644 --- a/lmdeploy/serve/anthropic/endpoints/messages.py +++ b/lmdeploy/serve/anthropic/endpoints/messages.py @@ -176,7 +176,7 @@ async def create_message(request: MessagesRequest, raw_request: Request): session, gen_config=to_generation_config( request, - server_defaults=server_context.server_sampling_defaults, + default_gen_config=server_context.default_gen_config, ), tools=parsed_request.tools, stream_response=True, diff --git a/lmdeploy/serve/core/generation_config.py b/lmdeploy/serve/core/generation_config.py index 0bd2c483a7..e6af816c0f 100644 --- a/lmdeploy/serve/core/generation_config.py +++ b/lmdeploy/serve/core/generation_config.py @@ -20,12 +20,6 @@ 'do_sample': True, } -SAMPLING_PARAM_KEYS = ( - 'temperature', - 'top_p', - 'top_k', -) - REQUEST_SAMPLING_FIELDS = ( 'temperature', 'top_p', @@ -45,17 +39,12 @@ def _load_hf_generation_config(path: str, trust_remote_code: bool) -> dict[str, return {} -def extract_sampling_params(config: dict[str, Any]) -> dict[str, Any]: - """Extract supported sampling parameters from a generation config dict.""" - return {key: config[key] for key in SAMPLING_PARAM_KEYS if key in config and config[key] is not None} - - -def resolve_server_sampling_defaults( +def resolve_default_gen_config( generation_config: str, model_path: str, trust_remote_code: bool, ) -> dict[str, Any]: - """Resolve server-side default sampling params from CLI flags.""" + """Resolve server-side default generation config from CLI flags.""" src = generation_config if src == 'lmdeploy': @@ -65,32 +54,30 @@ def resolve_server_sampling_defaults( else: config = _load_hf_generation_config(src, trust_remote_code) - sampling = extract_sampling_params(config) - - if sampling and src != 'lmdeploy': + if config and src != 'lmdeploy': source = "the model's `generation_config.json`" if src == 'auto' else src logger.info( - f'Using default sampling params from {source}: {sampling}. ' + f'Using default generation config from {source}: {config}. ' 'Use `--generation-config lmdeploy` to disable.') - return sampling + return config def merge_sampling_params( request_values: dict[str, Any], - server_defaults: dict[str, Any], + default_gen_config: dict[str, Any], fallbacks: dict[str, Any] | None = None, ) -> dict[str, Any]: - """Merge sampling params with request > server > protocol fallback - priority.""" + """Merge sampling params with request > default_gen_config > protocol + fallback priority.""" fallbacks = fallbacks or PROTOCOL_FALLBACKS merged: dict[str, Any] = {} - all_keys = set(fallbacks) | set(server_defaults) | set(request_values) + all_keys = set(fallbacks) | set(default_gen_config) | set(request_values) for key in all_keys: if key in request_values: merged[key] = request_values[key] - elif key in server_defaults: - merged[key] = server_defaults[key] + elif key in default_gen_config: + merged[key] = default_gen_config[key] elif key in fallbacks: merged[key] = fallbacks[key] return merged @@ -110,7 +97,7 @@ def extract_request_sampling_values(request: Any) -> dict[str, Any]: def build_generation_config( request_values: dict[str, Any], - server_defaults: dict[str, Any], + default_gen_config: dict[str, Any], *, max_completion_tokens: int | None = None, max_tokens: int | None = None, @@ -119,7 +106,7 @@ def build_generation_config( ) -> GenerationConfig: """Build ``GenerationConfig`` from merged sampling defaults and request values.""" - merged = merge_sampling_params(request_values, server_defaults, fallbacks) + merged = merge_sampling_params(request_values, default_gen_config, fallbacks) max_new_tokens = max_completion_tokens if max_completion_tokens is not None else max_tokens return GenerationConfig( max_new_tokens=max_new_tokens, diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index f8a8cbe94f..d185257c1a 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -50,7 +50,7 @@ from lmdeploy.serve.core.generation_config import ( build_generation_config, extract_request_sampling_values, - resolve_server_sampling_defaults, + resolve_default_gen_config, ) from lmdeploy.serve.openai.protocol import ( AbortRequest, @@ -115,7 +115,7 @@ class VariableInterface: allow_terminate_by_client: bool = False enable_abort_handling: bool = False response_parser_cls: type[ResponseParser] | None = None - server_sampling_defaults: dict = {} + default_gen_config: dict = {} @classmethod def create_session(cls, user_session_id: int | None = None) -> Session: @@ -162,7 +162,7 @@ def _build_serving_generation_config(request, **extra_kwargs) -> GenerationConfi max_completion_tokens = getattr(request, 'max_output_tokens', None) return build_generation_config( request_values, - VariableInterface.server_sampling_defaults, + VariableInterface.default_gen_config, max_completion_tokens=max_completion_tokens, max_tokens=max_tokens, **extra_kwargs, @@ -1610,7 +1610,7 @@ def serve(model_path: str, VariableInterface.allow_terminate_by_client = allow_terminate_by_client VariableInterface.enable_abort_handling = enable_abort_handling - VariableInterface.server_sampling_defaults = resolve_server_sampling_defaults( + VariableInterface.default_gen_config = resolve_default_gen_config( generation_config, model_path, trust_remote_code, diff --git a/lmdeploy/serve/openai/responses/request.py b/lmdeploy/serve/openai/responses/request.py index 8cfc799b38..56aed0d5f3 100644 --- a/lmdeploy/serve/openai/responses/request.py +++ b/lmdeploy/serve/openai/responses/request.py @@ -267,13 +267,13 @@ def _response_format_from_text(text: Any) -> dict[str, Any] | None: def to_generation_config( request: ResponsesRequest, - server_defaults: dict | None = None, + default_gen_config: dict | None = None, ) -> GenerationConfig: stop_words = [request.stop] if isinstance(request.stop, str) else request.stop request_values = extract_request_sampling_values(request) return build_generation_config( request_values, - server_defaults or {}, + default_gen_config or {}, max_completion_tokens=request.max_output_tokens, stop_words=stop_words, ignore_eos=request.ignore_eos, diff --git a/lmdeploy/serve/openai/responses/serving.py b/lmdeploy/serve/openai/responses/serving.py index f42ec6961c..abcac88908 100644 --- a/lmdeploy/serve/openai/responses/serving.py +++ b/lmdeploy/serve/openai/responses/serving.py @@ -99,7 +99,7 @@ async def create_response(self, request: ResponsesRequest, raw_request: Request) try: gen_config = to_generation_config( request, - server_defaults=self.server_context.server_sampling_defaults, + default_gen_config=self.server_context.default_gen_config, ) except ValueError as err: return error_response(HTTPStatus.BAD_REQUEST, str(err), param='text') diff --git a/lmdeploy/serve/openai/serving_chat_completion.py b/lmdeploy/serve/openai/serving_chat_completion.py index 7ac750a617..4aadf90353 100644 --- a/lmdeploy/serve/openai/serving_chat_completion.py +++ b/lmdeploy/serve/openai/serving_chat_completion.py @@ -16,7 +16,7 @@ def _effective_sampling(request: ChatCompletionRequest, server_context: 'VariableInterface') -> dict: return merge_sampling_params( extract_request_sampling_values(request), - server_context.server_sampling_defaults, + server_context.default_gen_config, PROTOCOL_FALLBACKS, ) diff --git a/lmdeploy/serve/openai/serving_completion.py b/lmdeploy/serve/openai/serving_completion.py index 3047717c1f..a8564eb56f 100644 --- a/lmdeploy/serve/openai/serving_completion.py +++ b/lmdeploy/serve/openai/serving_completion.py @@ -16,7 +16,7 @@ def _effective_sampling(request: CompletionRequest, server_context: 'VariableInterface') -> dict: return merge_sampling_params( extract_request_sampling_values(request), - server_context.server_sampling_defaults, + server_context.default_gen_config, PROTOCOL_FALLBACKS, ) diff --git a/tests/test_lmdeploy/serve/test_generation_config.py b/tests/test_lmdeploy/serve/test_generation_config.py index e1b85a37f3..8b198bd043 100644 --- a/tests/test_lmdeploy/serve/test_generation_config.py +++ b/tests/test_lmdeploy/serve/test_generation_config.py @@ -6,7 +6,7 @@ build_generation_config, extract_request_sampling_values, merge_sampling_params, - resolve_server_sampling_defaults, + resolve_default_gen_config, ) from lmdeploy.serve.openai.protocol import ChatCompletionRequest, CompletionRequest @@ -52,20 +52,24 @@ def test_build_generation_config_max_new_tokens_defaults_to_none(): @patch('lmdeploy.serve.core.generation_config._load_hf_generation_config') -def test_resolve_server_sampling_defaults_auto(mock_load): +def test_resolve_default_gen_config_auto(mock_load): mock_load.return_value = { 'temperature': 0.6, 'top_p': 0.8, 'max_new_tokens': 2048, } - defaults = resolve_server_sampling_defaults('auto', '/fake/model', False) - assert defaults == {'temperature': 0.6, 'top_p': 0.8} + config = resolve_default_gen_config('auto', '/fake/model', False) + assert config == { + 'temperature': 0.6, + 'top_p': 0.8, + 'max_new_tokens': 2048, + } mock_load.assert_called_once_with('/fake/model', False) -def test_resolve_server_sampling_defaults_lmdeploy(): - defaults = resolve_server_sampling_defaults('lmdeploy', '/fake/model', False) - assert defaults == {} +def test_resolve_default_gen_config_lmdeploy(): + config = resolve_default_gen_config('lmdeploy', '/fake/model', False) + assert config == {} def test_completion_request_sampling_merge(): From 2c4a2847141975ff5687b4844d10a036e49bdae8 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Wed, 1 Jul 2026 13:31:43 +0000 Subject: [PATCH 4/4] refactor(serve): simplify generation config merge and defaults Drop PROTOCOL_FALLBACKS in favor of GenerationConfig dataclass defaults, and streamline build_generation_config to merge request and HF config in one place. Co-authored-by: Cursor --- lmdeploy/serve/anthropic/adapter.py | 7 +-- lmdeploy/serve/core/generation_config.py | 60 ++++++------------- lmdeploy/serve/openai/api_server.py | 14 ++--- lmdeploy/serve/openai/responses/request.py | 7 +-- .../serve/openai/serving_chat_completion.py | 17 +++--- lmdeploy/serve/openai/serving_completion.py | 17 +++--- .../serve/test_generation_config.py | 38 ++++++------ 7 files changed, 63 insertions(+), 97 deletions(-) diff --git a/lmdeploy/serve/anthropic/adapter.py b/lmdeploy/serve/anthropic/adapter.py index 14ee52b9a2..c1b2553897 100644 --- a/lmdeploy/serve/anthropic/adapter.py +++ b/lmdeploy/serve/anthropic/adapter.py @@ -9,7 +9,7 @@ import shortuuid from lmdeploy.messages import GenerationConfig -from lmdeploy.serve.core.generation_config import build_generation_config, extract_request_sampling_values +from lmdeploy.serve.core.generation_config import build_generation_config from lmdeploy.serve.openai.protocol import Tool, ToolChoice, ToolChoiceFuncName from .protocol import ( @@ -347,11 +347,10 @@ def to_generation_config( default_gen_config: dict | None = None, ) -> GenerationConfig: """Map Anthropic messages request to LMDeploy generation config.""" - request_values = extract_request_sampling_values(request) return build_generation_config( - request_values, + request, default_gen_config or {}, - max_tokens=request.max_tokens, + max_new_tokens=request.max_tokens, stop_words=request.stop_sequences, include_stop_str_in_output=request.include_stop_str_in_output or False, skip_special_tokens=True, diff --git a/lmdeploy/serve/core/generation_config.py b/lmdeploy/serve/core/generation_config.py index e6af816c0f..82126af92c 100644 --- a/lmdeploy/serve/core/generation_config.py +++ b/lmdeploy/serve/core/generation_config.py @@ -4,6 +4,7 @@ from __future__ import annotations +import dataclasses from typing import Any from lmdeploy.messages import GenerationConfig @@ -11,23 +12,6 @@ logger = get_logger('lmdeploy') -PROTOCOL_FALLBACKS: dict[str, Any] = { - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 40, - 'repetition_penalty': 1.0, - 'min_p': 0.0, - 'do_sample': True, -} - -REQUEST_SAMPLING_FIELDS = ( - 'temperature', - 'top_p', - 'top_k', - 'min_p', - 'repetition_penalty', -) - def _load_hf_generation_config(path: str, trust_remote_code: bool) -> dict[str, Any]: from transformers import GenerationConfig @@ -66,55 +50,45 @@ def resolve_default_gen_config( def merge_sampling_params( request_values: dict[str, Any], default_gen_config: dict[str, Any], - fallbacks: dict[str, Any] | None = None, ) -> dict[str, Any]: - """Merge sampling params with request > default_gen_config > protocol - fallback priority.""" - fallbacks = fallbacks or PROTOCOL_FALLBACKS + """Merge sampling params with request > default_gen_config priority.""" merged: dict[str, Any] = {} - all_keys = set(fallbacks) | set(default_gen_config) | set(request_values) - for key in all_keys: + for key in set(default_gen_config) | set(request_values): if key in request_values: merged[key] = request_values[key] - elif key in default_gen_config: + else: merged[key] = default_gen_config[key] - elif key in fallbacks: - merged[key] = fallbacks[key] return merged def extract_request_sampling_values(request: Any) -> dict[str, Any]: - """Extract explicitly provided sampling fields from a request object.""" + """Extract non-None GenerationConfig fields present on the request.""" values: dict[str, Any] = {} - for field in REQUEST_SAMPLING_FIELDS: - if not hasattr(request, field): + for field in dataclasses.fields(GenerationConfig): + if not hasattr(request, field.name): continue - value = getattr(request, field) + value = getattr(request, field.name) if value is not None: - values[field] = value + values[field.name] = value return values def build_generation_config( - request_values: dict[str, Any], + request: Any, default_gen_config: dict[str, Any], *, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - fallbacks: dict[str, Any] | None = None, + max_new_tokens: int | None = None, **extra_kwargs: Any, ) -> GenerationConfig: """Build ``GenerationConfig`` from merged sampling defaults and request values.""" - merged = merge_sampling_params(request_values, default_gen_config, fallbacks) - max_new_tokens = max_completion_tokens if max_completion_tokens is not None else max_tokens + request_values = extract_request_sampling_values(request) + merged = merge_sampling_params(request_values, default_gen_config) + merged.pop('max_new_tokens', None) + merged.pop('do_sample', None) return GenerationConfig( max_new_tokens=max_new_tokens, - do_sample=merged.get('do_sample', PROTOCOL_FALLBACKS['do_sample']), - top_k=merged['top_k'], - top_p=merged['top_p'], - temperature=merged['temperature'], - repetition_penalty=merged['repetition_penalty'], - min_p=merged['min_p'], + do_sample=True, + **merged, **extra_kwargs, ) diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index d185257c1a..abdc7a320f 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -49,7 +49,6 @@ from lmdeploy.serve.core import AsyncEngine, EngineHealthMonitor from lmdeploy.serve.core.generation_config import ( build_generation_config, - extract_request_sampling_values, resolve_default_gen_config, ) from lmdeploy.serve.openai.protocol import ( @@ -155,16 +154,13 @@ def get_engine_config(cls): def _build_serving_generation_config(request, **extra_kwargs) -> GenerationConfig: """Build ``GenerationConfig`` with server and request sampling merge.""" - request_values = extract_request_sampling_values(request) - max_completion_tokens = getattr(request, 'max_completion_tokens', None) - max_tokens = getattr(request, 'max_tokens', None) - if max_completion_tokens is None and hasattr(request, 'max_output_tokens'): - max_completion_tokens = getattr(request, 'max_output_tokens', None) + max_new_tokens = getattr(request, 'max_completion_tokens', None) + if max_new_tokens is None: + max_new_tokens = getattr(request, 'max_tokens', None) return build_generation_config( - request_values, + request, VariableInterface.default_gen_config, - max_completion_tokens=max_completion_tokens, - max_tokens=max_tokens, + max_new_tokens=max_new_tokens, **extra_kwargs, ) diff --git a/lmdeploy/serve/openai/responses/request.py b/lmdeploy/serve/openai/responses/request.py index 56aed0d5f3..50d6927139 100644 --- a/lmdeploy/serve/openai/responses/request.py +++ b/lmdeploy/serve/openai/responses/request.py @@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse from lmdeploy.messages import GenerationConfig -from lmdeploy.serve.core.generation_config import build_generation_config, extract_request_sampling_values +from lmdeploy.serve.core.generation_config import build_generation_config from lmdeploy.serve.openai.protocol import Tool, ToolChoice, ToolChoiceFuncName from lmdeploy.serve.openai.responses.protocol import ResponsesRequest from lmdeploy.utils import get_logger @@ -270,11 +270,10 @@ def to_generation_config( default_gen_config: dict | None = None, ) -> GenerationConfig: stop_words = [request.stop] if isinstance(request.stop, str) else request.stop - request_values = extract_request_sampling_values(request) return build_generation_config( - request_values, + request, default_gen_config or {}, - max_completion_tokens=request.max_output_tokens, + max_new_tokens=request.max_output_tokens, stop_words=stop_words, ignore_eos=request.ignore_eos, skip_special_tokens=request.skip_special_tokens, diff --git a/lmdeploy/serve/openai/serving_chat_completion.py b/lmdeploy/serve/openai/serving_chat_completion.py index 4aadf90353..8c44e5badd 100644 --- a/lmdeploy/serve/openai/serving_chat_completion.py +++ b/lmdeploy/serve/openai/serving_chat_completion.py @@ -1,11 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import TYPE_CHECKING -from lmdeploy.serve.core.generation_config import ( - PROTOCOL_FALLBACKS, - extract_request_sampling_values, - merge_sampling_params, -) +from lmdeploy.serve.core.generation_config import build_generation_config from .protocol import ChatCompletionRequest @@ -14,11 +10,12 @@ def _effective_sampling(request: ChatCompletionRequest, server_context: 'VariableInterface') -> dict: - return merge_sampling_params( - extract_request_sampling_values(request), - server_context.default_gen_config, - PROTOCOL_FALLBACKS, - ) + gen = build_generation_config(request, server_context.default_gen_config) + return { + 'temperature': gen.temperature, + 'top_p': gen.top_p, + 'top_k': gen.top_k, + } def check_request(request: ChatCompletionRequest, server_context: 'VariableInterface') -> str: diff --git a/lmdeploy/serve/openai/serving_completion.py b/lmdeploy/serve/openai/serving_completion.py index a8564eb56f..ca96f8789d 100644 --- a/lmdeploy/serve/openai/serving_completion.py +++ b/lmdeploy/serve/openai/serving_completion.py @@ -1,11 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import TYPE_CHECKING -from lmdeploy.serve.core.generation_config import ( - PROTOCOL_FALLBACKS, - extract_request_sampling_values, - merge_sampling_params, -) +from lmdeploy.serve.core.generation_config import build_generation_config from .protocol import CompletionRequest @@ -14,11 +10,12 @@ def _effective_sampling(request: CompletionRequest, server_context: 'VariableInterface') -> dict: - return merge_sampling_params( - extract_request_sampling_values(request), - server_context.default_gen_config, - PROTOCOL_FALLBACKS, - ) + gen = build_generation_config(request, server_context.default_gen_config) + return { + 'temperature': gen.temperature, + 'top_p': gen.top_p, + 'top_k': gen.top_k, + } def check_request(request: CompletionRequest, server_context: 'VariableInterface') -> str: diff --git a/tests/test_lmdeploy/serve/test_generation_config.py b/tests/test_lmdeploy/serve/test_generation_config.py index 8b198bd043..8946ef8a61 100644 --- a/tests/test_lmdeploy/serve/test_generation_config.py +++ b/tests/test_lmdeploy/serve/test_generation_config.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from unittest.mock import patch +from lmdeploy.messages import GenerationConfig from lmdeploy.serve.core.generation_config import ( - PROTOCOL_FALLBACKS, build_generation_config, extract_request_sampling_values, merge_sampling_params, @@ -10,22 +10,20 @@ ) from lmdeploy.serve.openai.protocol import ChatCompletionRequest, CompletionRequest +_DEFAULTS = GenerationConfig() + def test_merge_sampling_params_priority(): merged = merge_sampling_params( {'temperature': 0.2}, {'temperature': 0.5, 'top_k': 10}, - PROTOCOL_FALLBACKS, ) - assert merged['temperature'] == 0.2 - assert merged['top_k'] == 10 - assert merged['top_p'] == PROTOCOL_FALLBACKS['top_p'] + assert merged == {'temperature': 0.2, 'top_k': 10} -def test_merge_sampling_params_uses_server_then_fallback(): - merged = merge_sampling_params({}, {'temperature': 0.5}, PROTOCOL_FALLBACKS) - assert merged['temperature'] == 0.5 - assert merged['top_k'] == PROTOCOL_FALLBACKS['top_k'] +def test_merge_sampling_params_uses_server_defaults(): + merged = merge_sampling_params({}, {'temperature': 0.5}) + assert merged == {'temperature': 0.5} def test_extract_request_sampling_values_only_non_null(): @@ -35,10 +33,11 @@ def test_extract_request_sampling_values_only_non_null(): def test_build_generation_config_from_merged_values(): + request = ChatCompletionRequest(model='test', messages='hi', temperature=0.2) gen_config = build_generation_config( - {'temperature': 0.2}, + request, {'top_k': 5}, - max_completion_tokens=32, + max_new_tokens=32, ) assert gen_config.temperature == 0.2 assert gen_config.top_k == 5 @@ -47,10 +46,18 @@ def test_build_generation_config_from_merged_values(): def test_build_generation_config_max_new_tokens_defaults_to_none(): - gen_config = build_generation_config({}, {}) + request = CompletionRequest(model='test', prompt='hello') + gen_config = build_generation_config(request, {}) assert gen_config.max_new_tokens is None +def test_build_generation_config_uses_generation_config_defaults(): + request = CompletionRequest(model='test', prompt='hello') + gen_config = build_generation_config(request, {}) + assert gen_config.temperature == _DEFAULTS.temperature + assert gen_config.top_k == _DEFAULTS.top_k + + @patch('lmdeploy.serve.core.generation_config._load_hf_generation_config') def test_resolve_default_gen_config_auto(mock_load): mock_load.return_value = { @@ -74,9 +81,6 @@ def test_resolve_default_gen_config_lmdeploy(): def test_completion_request_sampling_merge(): request = CompletionRequest(model='test', prompt='hello') - gen_config = build_generation_config( - extract_request_sampling_values(request), - {'temperature': 0.9}, - ) + gen_config = build_generation_config(request, {'temperature': 0.9}) assert gen_config.temperature == 0.9 - assert gen_config.top_k == PROTOCOL_FALLBACKS['top_k'] + assert gen_config.top_k == _DEFAULTS.top_k