diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 01ac1d44f1..e11bb359ba 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -92,6 +92,7 @@ def add_parser_api_server(): # model args ArgumentHelper.revision(parser) ArgumentHelper.download_dir(parser) + ArgumentHelper.generation_config(parser) # pytorch engine args pt_group = parser.add_argument_group('PyTorch engine arguments') @@ -318,6 +319,7 @@ def api_server(args): reasoning_parser=args.reasoning_parser, tool_call_parser=args.tool_call_parser, speculative_config=speculative_config, + generation_config=args.generation_config, ) else: from lmdeploy.serve.openai.launch_server import launch_server @@ -350,6 +352,7 @@ def api_server(args): reasoning_parser=args.reasoning_parser, tool_call_parser=args.tool_call_parser, speculative_config=speculative_config, + generation_config=args.generation_config, ) @staticmethod diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 1dda62a7e8..12f726700a 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -301,6 +301,18 @@ def hf_overrides(parser): default=None, help='Extra arguments to be forwarded to the HuggingFace config.') + @staticmethod + def generation_config(parser): + """Add argument generation_config to parser.""" + return parser.add_argument( + '--generation-config', + type=str, + default='auto', + help='The folder path to the generation config. Defaults to "auto", the ' + 'generation config will be loaded from model path. If set to "lmdeploy", no ' + 'generation config is loaded, lmdeploy defaults will be used. If set to a folder ' + 'path, the generation config will be loaded from the specified folder path.') + @staticmethod def use_logn_attn(parser): """Add argument use_logn_attn to parser.""" diff --git a/lmdeploy/serve/anthropic/adapter.py b/lmdeploy/serve/anthropic/adapter.py index 6975c7c8ad..c1b2553897 100644 --- a/lmdeploy/serve/anthropic/adapter.py +++ b/lmdeploy/serve/anthropic/adapter.py @@ -9,6 +9,7 @@ import shortuuid from lmdeploy.messages import GenerationConfig +from lmdeploy.serve.core.generation_config import build_generation_config from lmdeploy.serve.openai.protocol import Tool, ToolChoice, ToolChoiceFuncName from .protocol import ( @@ -341,15 +342,15 @@ def to_lmdeploy_messages(request: MessagesRequest | CountTokensRequest) -> list[ return lm_messages -def to_generation_config(request: MessagesRequest) -> GenerationConfig: +def to_generation_config( + request: MessagesRequest, + default_gen_config: dict | None = None, +) -> GenerationConfig: """Map Anthropic messages request to LMDeploy generation config.""" - - return GenerationConfig( + return build_generation_config( + request, + default_gen_config or {}, max_new_tokens=request.max_tokens, - do_sample=True, - top_k=40 if request.top_k is None else request.top_k, - top_p=1.0 if request.top_p is None else request.top_p, - temperature=1.0 if request.temperature is None else request.temperature, stop_words=request.stop_sequences, include_stop_str_in_output=request.include_stop_str_in_output or False, skip_special_tokens=True, diff --git a/lmdeploy/serve/anthropic/endpoints/messages.py b/lmdeploy/serve/anthropic/endpoints/messages.py index 1f5b17a54b..6027fb3030 100644 --- a/lmdeploy/serve/anthropic/endpoints/messages.py +++ b/lmdeploy/serve/anthropic/endpoints/messages.py @@ -174,7 +174,10 @@ async def create_message(request: MessagesRequest, raw_request: Request): result_generator = server_context.async_engine.generate( engine_messages, session, - gen_config=to_generation_config(request), + gen_config=to_generation_config( + request, + default_gen_config=server_context.default_gen_config, + ), tools=parsed_request.tools, stream_response=True, sequence_start=True, diff --git a/lmdeploy/serve/anthropic/protocol.py b/lmdeploy/serve/anthropic/protocol.py index 03f0b1c37f..37fa0f7bd7 100644 --- a/lmdeploy/serve/anthropic/protocol.py +++ b/lmdeploy/serve/anthropic/protocol.py @@ -104,7 +104,7 @@ class MessagesRequest(BaseModel): system: str | list[ContentBlockParam] | None = None stop_sequences: list[str] | None = None stream: bool = False - temperature: float | None = 1.0 + temperature: float | None = None top_p: float | None = None top_k: int | None = None metadata: dict[str, Any] | None = None diff --git a/lmdeploy/serve/core/generation_config.py b/lmdeploy/serve/core/generation_config.py new file mode 100644 index 0000000000..82126af92c --- /dev/null +++ b/lmdeploy/serve/core/generation_config.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Server-side generation config resolution and sampling parameter merge +helpers.""" + +from __future__ import annotations + +import dataclasses +from typing import Any + +from lmdeploy.messages import GenerationConfig +from lmdeploy.utils import get_logger + +logger = get_logger('lmdeploy') + + +def _load_hf_generation_config(path: str, trust_remote_code: bool) -> dict[str, Any]: + from transformers import GenerationConfig + + try: + cfg = GenerationConfig.from_pretrained(path, trust_remote_code=trust_remote_code) + return cfg.to_diff_dict() + except OSError: + return {} + + +def resolve_default_gen_config( + generation_config: str, + model_path: str, + trust_remote_code: bool, +) -> dict[str, Any]: + """Resolve server-side default generation config from CLI flags.""" + src = generation_config + + if src == 'lmdeploy': + config: dict[str, Any] = {} + elif src == 'auto': + config = _load_hf_generation_config(model_path, trust_remote_code) + else: + config = _load_hf_generation_config(src, trust_remote_code) + + if config and src != 'lmdeploy': + source = "the model's `generation_config.json`" if src == 'auto' else src + logger.info( + f'Using default generation config from {source}: {config}. ' + 'Use `--generation-config lmdeploy` to disable.') + + return config + + +def merge_sampling_params( + request_values: dict[str, Any], + default_gen_config: dict[str, Any], +) -> dict[str, Any]: + """Merge sampling params with request > default_gen_config priority.""" + merged: dict[str, Any] = {} + for key in set(default_gen_config) | set(request_values): + if key in request_values: + merged[key] = request_values[key] + else: + merged[key] = default_gen_config[key] + return merged + + +def extract_request_sampling_values(request: Any) -> dict[str, Any]: + """Extract non-None GenerationConfig fields present on the request.""" + values: dict[str, Any] = {} + for field in dataclasses.fields(GenerationConfig): + if not hasattr(request, field.name): + continue + value = getattr(request, field.name) + if value is not None: + values[field.name] = value + return values + + +def build_generation_config( + request: Any, + default_gen_config: dict[str, Any], + *, + max_new_tokens: int | None = None, + **extra_kwargs: Any, +) -> GenerationConfig: + """Build ``GenerationConfig`` from merged sampling defaults and request + values.""" + request_values = extract_request_sampling_values(request) + merged = merge_sampling_params(request_values, default_gen_config) + merged.pop('max_new_tokens', None) + merged.pop('do_sample', None) + return GenerationConfig( + max_new_tokens=max_new_tokens, + do_sample=True, + **merged, + **extra_kwargs, + ) diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index e8a3ca6ecc..abdc7a320f 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -47,6 +47,10 @@ ) from lmdeploy.serve.anthropic import create_anthropic_router from lmdeploy.serve.core import AsyncEngine, EngineHealthMonitor +from lmdeploy.serve.core.generation_config import ( + build_generation_config, + resolve_default_gen_config, +) from lmdeploy.serve.openai.protocol import ( AbortRequest, ChatCompletionRequest, @@ -110,6 +114,7 @@ class VariableInterface: allow_terminate_by_client: bool = False enable_abort_handling: bool = False response_parser_cls: type[ResponseParser] | None = None + default_gen_config: dict = {} @classmethod def create_session(cls, user_session_id: int | None = None) -> Session: @@ -147,6 +152,19 @@ def get_engine_config(cls): return cls.async_engine.backend_config +def _build_serving_generation_config(request, **extra_kwargs) -> GenerationConfig: + """Build ``GenerationConfig`` with server and request sampling merge.""" + max_new_tokens = getattr(request, 'max_completion_tokens', None) + if max_new_tokens is None: + max_new_tokens = getattr(request, 'max_tokens', None) + return build_generation_config( + request, + VariableInterface.default_gen_config, + max_new_tokens=max_new_tokens, + **extra_kwargs, + ) + + async def _with_request_cleanup(generator, result_generators, sessions): """Yield from an API generator and cleanup when the HTTP task exits.""" session_mgr = VariableInterface.get_session_manager() @@ -486,14 +504,9 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque # (e.g. GPT-OSS clears response_format and injects the schema into messages) request = response_parser.request - gen_config = GenerationConfig( - max_new_tokens=request.max_completion_tokens, - do_sample=True, + gen_config = _build_serving_generation_config( + request, logprobs=gen_logprobs, - top_k=request.top_k, - top_p=request.top_p, - temperature=request.temperature, - repetition_penalty=request.repetition_penalty, ignore_eos=request.ignore_eos, stop_words=request.stop, include_stop_str_in_output=request.include_stop_str_in_output, @@ -501,7 +514,6 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque response_format=request.response_format, logits_processors=logits_processors, min_new_tokens=request.min_new_tokens, - min_p=request.min_p, random_seed=random_seed, spaces_between_special_tokens=request.spaces_between_special_tokens, migration_request=migration_request, @@ -828,20 +840,13 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None if isinstance(request.stop, str): request.stop = [request.stop] random_seed = request.seed if request.seed is not None else None - max_new_tokens = (request.max_completion_tokens if request.max_completion_tokens else request.max_tokens) - gen_config = GenerationConfig( - max_new_tokens=max_new_tokens, - do_sample=True, + gen_config = _build_serving_generation_config( + request, logprobs=request.logprobs, - top_k=request.top_k, - top_p=request.top_p, - temperature=request.temperature, - repetition_penalty=request.repetition_penalty, ignore_eos=request.ignore_eos, stop_words=request.stop, skip_special_tokens=request.skip_special_tokens, - min_p=request.min_p, random_seed=random_seed, spaces_between_special_tokens=request.spaces_between_special_tokens, migration_request=migration_request, @@ -1012,15 +1017,9 @@ async def generate(request: GenerateReqInput, raw_request: Request = None): prompt = [dict(role='user', content=[text_input] + image_input)] input_ids = None - gen_config = GenerationConfig( - max_new_tokens=request.max_tokens, - do_sample=True, + gen_config = _build_serving_generation_config( + request, logprobs=1 if request.return_logprob else None, - top_k=request.top_k, - top_p=request.top_p, - min_p=request.min_p, - temperature=request.temperature, - repetition_penalty=request.repetition_penalty, ignore_eos=request.ignore_eos, stop_words=request.stop, stop_token_ids=request.stop_token_ids, @@ -1546,6 +1545,7 @@ def serve(model_path: str, allow_terminate_by_client: bool = False, enable_abort_handling: bool = False, speculative_config: SpeculativeConfig | None = None, + generation_config: str = 'auto', **kwargs): """An example to perform model inference through the command line interface. @@ -1606,6 +1606,12 @@ def serve(model_path: str, VariableInterface.allow_terminate_by_client = allow_terminate_by_client VariableInterface.enable_abort_handling = enable_abort_handling + VariableInterface.default_gen_config = resolve_default_gen_config( + generation_config, + model_path, + trust_remote_code, + ) + ssl_keyfile, ssl_certfile, http_or_https = None, None, 'http' if ssl: ssl_keyfile = os.environ['SSL_KEYFILE'] diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py index 675cbf5103..40162faf15 100644 --- a/lmdeploy/serve/openai/protocol.py +++ b/lmdeploy/serve/openai/protocol.py @@ -144,8 +144,8 @@ class ChatCompletionRequest(BaseModel): model: str messages: str | list[dict[str, Any]] = Field(examples=[[{'role': 'user', 'content': 'hi'}]]) - temperature: float | None = 0.7 - top_p: float | None = 1.0 + temperature: float | None = None + top_p: float | None = None tools: list[Tool] | None = Field(default=None, examples=[None]) tool_choice: ToolChoice | AllowedToolChoice | Literal[ 'auto', 'required', 'none'] = Field(default='auto', examples=['none']) @@ -176,17 +176,17 @@ class ChatCompletionRequest(BaseModel): response_format: ResponseFormat | None = Field(default=None, examples=[None]) # additional argument of lmdeploy do_preprocess: bool | None = True - repetition_penalty: float | None = 1.0 + repetition_penalty: float | None = None repetition_ngram_size: int = Field(default=0, ge=0) repetition_ngram_threshold: int = Field(default=0, ge=0) session_id: int | None = -1 ignore_eos: bool | None = False skip_special_tokens: bool | None = True spaces_between_special_tokens: bool | None = True - top_k: int | None = 40 + top_k: int | None = None seed: int | None = None min_new_tokens: int | None = Field(default=None, examples=[None]) - min_p: float = 0.0 + min_p: float | None = None enable_thinking: bool | None = None # will be deprecated in the future return_token_ids: bool | None = False return_logprob: bool | None = False @@ -352,7 +352,7 @@ class CompletionRequest(BaseModel): model: str prompt: str | list[Any] suffix: str | None = None - temperature: float | None = 0.7 + temperature: float | None = None n: int | None = 1 logprobs: int | None = None max_completion_tokens: int | None = Field( @@ -362,29 +362,29 @@ class CompletionRequest(BaseModel): 'including visible output tokens and reasoning tokens'), ) max_tokens: int | None = Field( - default=16, - examples=[16], + default=None, + examples=[None], deprecated='max_tokens is deprecated in favor of the max_completion_tokens field', ) stop: str | list[str] | None = Field(default=None, examples=[None]) stream: bool | None = False stream_options: StreamOptions | None = Field(default=None, examples=[None]) - top_p: float | None = 1.0 + top_p: float | None = None echo: bool | None = False presence_penalty: float | None = 0.0 frequency_penalty: float | None = 0.0 user: str | None = None # additional argument of lmdeploy - repetition_penalty: float | None = 1.0 + repetition_penalty: float | None = None repetition_ngram_size: int = Field(default=0, ge=0) repetition_ngram_threshold: int = Field(default=0, ge=0) session_id: int | None = -1 ignore_eos: bool | None = False skip_special_tokens: bool | None = True spaces_between_special_tokens: bool | None = True - top_k: int | None = 40 # for opencompass + top_k: int | None = None # for opencompass seed: int | None = None - min_p: float = 0.0 + min_p: float | None = None class CompletionResponseChoice(BaseModel): @@ -550,7 +550,7 @@ class GenerateReqInput(BaseModel): stop_token_ids: list[int] | None = None stream: bool | None = False temperature: float = 1.0 - repetition_penalty: float | None = 1.0 + repetition_penalty: float | None = None ignore_eos: bool | None = False top_p: float = 1.0 top_k: int = 0 diff --git a/lmdeploy/serve/openai/responses/protocol.py b/lmdeploy/serve/openai/responses/protocol.py index ed08f16202..d02d220df7 100644 --- a/lmdeploy/serve/openai/responses/protocol.py +++ b/lmdeploy/serve/openai/responses/protocol.py @@ -82,10 +82,10 @@ class ResponsesRequest(BaseModel): presence_penalty: float | None = None frequency_penalty: float | None = None repetition_penalty: float | None = None - top_k: int | None = 40 + top_k: int | None = None stop: str | list[str] | None = None seed: int | None = None - min_p: float = 0.0 + min_p: float | None = None ignore_eos: bool | None = False skip_special_tokens: bool | None = True include_stop_str_in_output: bool | None = False diff --git a/lmdeploy/serve/openai/responses/request.py b/lmdeploy/serve/openai/responses/request.py index 2db1000532..50d6927139 100644 --- a/lmdeploy/serve/openai/responses/request.py +++ b/lmdeploy/serve/openai/responses/request.py @@ -9,6 +9,7 @@ from fastapi.responses import JSONResponse from lmdeploy.messages import GenerationConfig +from lmdeploy.serve.core.generation_config import build_generation_config from lmdeploy.serve.openai.protocol import Tool, ToolChoice, ToolChoiceFuncName from lmdeploy.serve.openai.responses.protocol import ResponsesRequest from lmdeploy.utils import get_logger @@ -264,20 +265,19 @@ def _response_format_from_text(text: Any) -> dict[str, Any] | None: raise ValueError(f'Unsupported text.format type: {format_type!r}.') -def to_generation_config(request: ResponsesRequest) -> GenerationConfig: +def to_generation_config( + request: ResponsesRequest, + default_gen_config: dict | None = None, +) -> GenerationConfig: stop_words = [request.stop] if isinstance(request.stop, str) else request.stop - return GenerationConfig( + return build_generation_config( + request, + default_gen_config or {}, max_new_tokens=request.max_output_tokens, - do_sample=True, - top_k=40 if request.top_k is None else request.top_k, - top_p=1.0 if request.top_p is None else request.top_p, - temperature=1.0 if request.temperature is None else request.temperature, stop_words=stop_words, ignore_eos=request.ignore_eos, skip_special_tokens=request.skip_special_tokens, include_stop_str_in_output=request.include_stop_str_in_output, response_format=_response_format_from_text(request.text), - min_p=request.min_p, random_seed=request.seed, - repetition_penalty=1.0 if request.repetition_penalty is None else request.repetition_penalty, ) diff --git a/lmdeploy/serve/openai/responses/serving.py b/lmdeploy/serve/openai/responses/serving.py index 2837e483c5..abcac88908 100644 --- a/lmdeploy/serve/openai/responses/serving.py +++ b/lmdeploy/serve/openai/responses/serving.py @@ -97,7 +97,10 @@ async def create_response(self, request: ResponsesRequest, raw_request: Request) except ValueError as err: return error_response(HTTPStatus.BAD_REQUEST, str(err), param='input') try: - gen_config = to_generation_config(request) + gen_config = to_generation_config( + request, + default_gen_config=self.server_context.default_gen_config, + ) except ValueError as err: return error_response(HTTPStatus.BAD_REQUEST, str(err), param='text') try: diff --git a/lmdeploy/serve/openai/serving_chat_completion.py b/lmdeploy/serve/openai/serving_chat_completion.py index 362f0bf9e5..8c44e5badd 100644 --- a/lmdeploy/serve/openai/serving_chat_completion.py +++ b/lmdeploy/serve/openai/serving_chat_completion.py @@ -1,12 +1,23 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import TYPE_CHECKING +from lmdeploy.serve.core.generation_config import build_generation_config + from .protocol import ChatCompletionRequest if TYPE_CHECKING: from .api_server import VariableInterface +def _effective_sampling(request: ChatCompletionRequest, server_context: 'VariableInterface') -> dict: + gen = build_generation_config(request, server_context.default_gen_config) + return { + 'temperature': gen.temperature, + 'top_p': gen.top_p, + 'top_k': gen.top_k, + } + + def check_request(request: ChatCompletionRequest, server_context: 'VariableInterface') -> str: engine_config = server_context.get_engine_config() session_manager = server_context.get_session_manager() @@ -32,15 +43,17 @@ def check_request(request: ChatCompletionRequest, server_context: 'VariableInter if session_manager.has(request.session_id): return f'The session_id {request.session_id!r} is occupied.' + sampling = _effective_sampling(request, server_context) + # check sampling settings if request.n <= 0: return f'The n {request.n!r} must be a positive int.' - if not (0 < request.top_p <= 1): - return f'The top_p {request.top_p!r} must be in (0, 1].' - if request.top_k < 0: - return f'The top_k {request.top_k!r} cannot be a negative integer.' - if not (0 <= request.temperature <= 2): - return f'The temperature {request.temperature!r} must be in [0, 2]' + if not (0 < sampling['top_p'] <= 1): + return f'The top_p {sampling["top_p"]!r} must be in (0, 1].' + if sampling['top_k'] < 0: + return f'The top_k {sampling["top_k"]!r} cannot be a negative integer.' + if not (0 <= sampling['temperature'] <= 2): + return f'The temperature {sampling["temperature"]!r} must be in [0, 2]' # Validate input_ids and image_data constraints. # messages has higher priority. input_ids and image_data are only used when diff --git a/lmdeploy/serve/openai/serving_completion.py b/lmdeploy/serve/openai/serving_completion.py index 759972db36..ca96f8789d 100644 --- a/lmdeploy/serve/openai/serving_completion.py +++ b/lmdeploy/serve/openai/serving_completion.py @@ -1,12 +1,23 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import TYPE_CHECKING +from lmdeploy.serve.core.generation_config import build_generation_config + from .protocol import CompletionRequest if TYPE_CHECKING: from .api_server import VariableInterface +def _effective_sampling(request: CompletionRequest, server_context: 'VariableInterface') -> dict: + gen = build_generation_config(request, server_context.default_gen_config) + return { + 'temperature': gen.temperature, + 'top_p': gen.top_p, + 'top_k': gen.top_k, + } + + def check_request(request: CompletionRequest, server_context: 'VariableInterface') -> str: engine_config = server_context.get_engine_config() session_manager = server_context.get_session_manager() @@ -24,14 +35,16 @@ def check_request(request: CompletionRequest, server_context: 'VariableInterface if session_manager.has(request.session_id): return f'The session_id {request.session_id!r} is occupied.' + sampling = _effective_sampling(request, server_context) + # check sampling settings if request.n <= 0: return f'The n {request.n!r} must be a positive int.' - if not (0 < request.top_p <= 1): - return f'The top_p {request.top_p!r} must be in (0, 1].' - if request.top_k < 0: - return f'The top_k {request.top_k!r} cannot be a negative integer.' - if not (0 <= request.temperature <= 2): - return f'The temperature {request.temperature!r} must be in [0, 2]' + if not (0 < sampling['top_p'] <= 1): + return f'The top_p {sampling["top_p"]!r} must be in (0, 1].' + if sampling['top_k'] < 0: + return f'The top_k {sampling["top_k"]!r} cannot be a negative integer.' + if not (0 <= sampling['temperature'] <= 2): + return f'The temperature {sampling["temperature"]!r} must be in [0, 2]' return '' diff --git a/tests/test_lmdeploy/serve/test_generation_config.py b/tests/test_lmdeploy/serve/test_generation_config.py new file mode 100644 index 0000000000..8946ef8a61 --- /dev/null +++ b/tests/test_lmdeploy/serve/test_generation_config.py @@ -0,0 +1,86 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest.mock import patch + +from lmdeploy.messages import GenerationConfig +from lmdeploy.serve.core.generation_config import ( + build_generation_config, + extract_request_sampling_values, + merge_sampling_params, + resolve_default_gen_config, +) +from lmdeploy.serve.openai.protocol import ChatCompletionRequest, CompletionRequest + +_DEFAULTS = GenerationConfig() + + +def test_merge_sampling_params_priority(): + merged = merge_sampling_params( + {'temperature': 0.2}, + {'temperature': 0.5, 'top_k': 10}, + ) + assert merged == {'temperature': 0.2, 'top_k': 10} + + +def test_merge_sampling_params_uses_server_defaults(): + merged = merge_sampling_params({}, {'temperature': 0.5}) + assert merged == {'temperature': 0.5} + + +def test_extract_request_sampling_values_only_non_null(): + request = ChatCompletionRequest(model='test', messages='hi', temperature=0.3) + values = extract_request_sampling_values(request) + assert values == {'temperature': 0.3} + + +def test_build_generation_config_from_merged_values(): + request = ChatCompletionRequest(model='test', messages='hi', temperature=0.2) + gen_config = build_generation_config( + request, + {'top_k': 5}, + max_new_tokens=32, + ) + assert gen_config.temperature == 0.2 + assert gen_config.top_k == 5 + assert gen_config.max_new_tokens == 32 + assert gen_config.do_sample is True + + +def test_build_generation_config_max_new_tokens_defaults_to_none(): + request = CompletionRequest(model='test', prompt='hello') + gen_config = build_generation_config(request, {}) + assert gen_config.max_new_tokens is None + + +def test_build_generation_config_uses_generation_config_defaults(): + request = CompletionRequest(model='test', prompt='hello') + gen_config = build_generation_config(request, {}) + assert gen_config.temperature == _DEFAULTS.temperature + assert gen_config.top_k == _DEFAULTS.top_k + + +@patch('lmdeploy.serve.core.generation_config._load_hf_generation_config') +def test_resolve_default_gen_config_auto(mock_load): + mock_load.return_value = { + 'temperature': 0.6, + 'top_p': 0.8, + 'max_new_tokens': 2048, + } + config = resolve_default_gen_config('auto', '/fake/model', False) + assert config == { + 'temperature': 0.6, + 'top_p': 0.8, + 'max_new_tokens': 2048, + } + mock_load.assert_called_once_with('/fake/model', False) + + +def test_resolve_default_gen_config_lmdeploy(): + config = resolve_default_gen_config('lmdeploy', '/fake/model', False) + assert config == {} + + +def test_completion_request_sampling_merge(): + request = CompletionRequest(model='test', prompt='hello') + gen_config = build_generation_config(request, {'temperature': 0.9}) + assert gen_config.temperature == 0.9 + assert gen_config.top_k == _DEFAULTS.top_k