Skip to content

Commit fdbf098

Browse files
committed
fix(serve): filter HF generation config and fix request extract merge
1 parent 1cb9465 commit fdbf098

4 files changed

Lines changed: 102 additions & 21 deletions

File tree

lmdeploy/serve/core/generation_config.py

Lines changed: 19 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,17 @@
77
import dataclasses
88
from typing import Any
99

10+
from pydantic import BaseModel
11+
1012
from lmdeploy.messages import GenerationConfig
1113
from lmdeploy.utils import get_logger
1214

1315
logger = get_logger('lmdeploy')
16+
_GENERATION_CONFIG_FIELDS = {field.name for field in dataclasses.fields(GenerationConfig)}
17+
18+
19+
def _filter_gen_config(config: dict[str, Any]) -> dict[str, Any]:
20+
return {key: value for key, value in config.items() if key in _GENERATION_CONFIG_FIELDS}
1421

1522

1623
def _load_hf_generation_config(path: str, trust_remote_code: bool) -> dict[str, Any]:
@@ -37,6 +44,7 @@ def resolve_default_gen_config(
3744
config = _load_hf_generation_config(model_path, trust_remote_code)
3845
else:
3946
config = _load_hf_generation_config(src, trust_remote_code)
47+
config = _filter_gen_config(config)
4048

4149
if config and src != 'lmdeploy':
4250
source = "the model's `generation_config.json`" if src == 'auto' else src
@@ -61,20 +69,19 @@ def merge_gen_config(
6169
return merged
6270

6371

64-
def extract_request_gen_config(request: Any) -> dict[str, Any]:
65-
"""Extract non-None GenerationConfig fields present on the request."""
66-
values: dict[str, Any] = {}
67-
for field in dataclasses.fields(GenerationConfig):
68-
if not hasattr(request, field.name):
69-
continue
70-
value = getattr(request, field.name)
71-
if value is not None:
72-
values[field.name] = value
73-
return values
72+
def extract_request_gen_config(request: BaseModel) -> dict[str, Any]:
73+
"""Extract explicit non-None GenerationConfig fields from a request."""
74+
# exclude_unset keeps client-supplied fields plus parser-updated fields,
75+
# while leaving plain Pydantic defaults available for server defaults.
76+
return {
77+
key: value
78+
for key, value in request.model_dump(exclude_unset=True).items()
79+
if key in _GENERATION_CONFIG_FIELDS and value is not None
80+
}
7481

7582

7683
def build_generation_config(
77-
request: Any,
84+
request: BaseModel,
7885
default_gen_config: dict[str, Any],
7986
*,
8087
max_new_tokens: int | None = None,
@@ -85,7 +92,7 @@ def build_generation_config(
8592
request_gen_config = extract_request_gen_config(request)
8693
for key in extra_kwargs:
8794
request_gen_config.pop(key, None)
88-
merged = merge_gen_config(request_gen_config, default_gen_config)
95+
merged = merge_gen_config(request_gen_config, _filter_gen_config(default_gen_config))
8996
merged.pop('max_new_tokens', None)
9097
merged.pop('do_sample', None)
9198
return GenerationConfig(

lmdeploy/serve/openai/protocol.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -549,12 +549,12 @@ class GenerateReqInput(BaseModel):
549549
stop: str | list[str] | None = None
550550
stop_token_ids: list[int] | None = None
551551
stream: bool | None = False
552-
temperature: float = 1.0
552+
temperature: float | None = None
553553
repetition_penalty: float | None = None
554554
ignore_eos: bool | None = False
555-
top_p: float = 1.0
556-
top_k: int = 0
557-
min_p: float = 0.0
555+
top_p: float | None = None
556+
top_k: int | None = None
557+
min_p: float | None = None
558558
skip_special_tokens: bool | None = True
559559
spaces_between_special_tokens: bool | None = True
560560
include_stop_str_in_output: bool | None = False

lmdeploy/serve/openai/serving_generate.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -35,11 +35,11 @@ def check_request(request: GenerateReqInput, server_context: 'VariableInterface'
3535
return f'The session_id {request.session_id!r} is occupied.'
3636

3737
# check sampling settings
38-
if not (0 < request.top_p <= 1):
38+
if request.top_p is not None and not (0 < request.top_p <= 1):
3939
return f'The top_p {request.top_p!r} must be in (0, 1].'
40-
if request.top_k < 0:
40+
if request.top_k is not None and request.top_k < 0:
4141
return f'The top_k {request.top_k!r} cannot be a negative integer.'
42-
if not (0 <= request.temperature <= 2):
42+
if request.temperature is not None and not (0 <= request.temperature <= 2):
4343
return f'The temperature {request.temperature!r} must be in [0, 2]'
4444

4545
return ''

tests/test_lmdeploy/serve/test_generation_config.py

Lines changed: 76 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
import warnings
23
from unittest.mock import patch
34

45
from lmdeploy.messages import GenerationConfig
@@ -8,11 +9,31 @@
89
merge_gen_config,
910
resolve_default_gen_config,
1011
)
11-
from lmdeploy.serve.openai.protocol import ChatCompletionRequest, CompletionRequest
12+
from lmdeploy.serve.openai.protocol import ChatCompletionRequest, CompletionRequest, GenerateReqInput
13+
from lmdeploy.serve.openai.serving_generate import check_request as check_generate_request
1214

1315
_DEFAULTS = GenerationConfig()
1416

1517

18+
class _FakeEngineConfig:
19+
logprobs_mode = None
20+
21+
22+
class _FakeSessionManager:
23+
24+
def has(self, session_id):
25+
return False
26+
27+
28+
class _FakeServerContext:
29+
30+
def get_engine_config(self):
31+
return _FakeEngineConfig()
32+
33+
def get_session_manager(self):
34+
return _FakeSessionManager()
35+
36+
1637
def test_merge_gen_config_priority():
1738
merged = merge_gen_config(
1839
{'temperature': 0.2},
@@ -26,7 +47,7 @@ def test_merge_gen_config_uses_server_defaults():
2647
assert merged == {'temperature': 0.5}
2748

2849

29-
def test_extract_request_gen_config_only_non_null():
50+
def test_extract_request_gen_config_only_explicit_fields():
3051
request = ChatCompletionRequest(model='test', messages='hi', temperature=0.3)
3152
values = extract_request_gen_config(request)
3253
assert values == {'temperature': 0.3}
@@ -58,12 +79,65 @@ def test_build_generation_config_uses_generation_config_defaults():
5879
assert gen_config.top_k == _DEFAULTS.top_k
5980

6081

82+
def test_build_generation_config_ignores_unsupported_defaults():
83+
request = CompletionRequest(model='test', prompt='hello')
84+
gen_config = build_generation_config(
85+
request,
86+
{
87+
'temperature': 0.6,
88+
'eos_token_id': 2,
89+
'pad_token_id': 0,
90+
'transformers_version': '5.12.1',
91+
},
92+
)
93+
assert gen_config.temperature == 0.6
94+
95+
96+
def test_completion_request_max_tokens_is_optional():
97+
request = CompletionRequest(model='test', prompt='hello')
98+
with warnings.catch_warnings():
99+
warnings.simplefilter('ignore', DeprecationWarning)
100+
assert request.max_tokens is None
101+
102+
103+
def test_generate_request_sampling_defaults_match_chat_request():
104+
chat_request = ChatCompletionRequest(model='test', messages='hello')
105+
generate_request = GenerateReqInput(prompt='hello')
106+
for name in ('temperature', 'top_p', 'top_k', 'min_p'):
107+
assert getattr(generate_request, name) == getattr(chat_request, name)
108+
109+
110+
def test_generate_request_accepts_none_sampling_defaults():
111+
request = GenerateReqInput(prompt='hello')
112+
assert check_generate_request(request, _FakeServerContext()) == ''
113+
114+
115+
def test_generate_request_sampling_merge_uses_server_defaults():
116+
request = GenerateReqInput(prompt='hello')
117+
gen_config = build_generation_config(
118+
request,
119+
{
120+
'temperature': 0.2,
121+
'top_p': 0.3,
122+
'top_k': 7,
123+
'min_p': 0.1,
124+
},
125+
max_new_tokens=request.max_tokens,
126+
)
127+
assert gen_config.temperature == 0.2
128+
assert gen_config.top_p == 0.3
129+
assert gen_config.top_k == 7
130+
assert gen_config.min_p == 0.1
131+
132+
61133
@patch('lmdeploy.serve.core.generation_config._load_hf_generation_config')
62134
def test_resolve_default_gen_config_auto(mock_load):
63135
mock_load.return_value = {
64136
'temperature': 0.6,
65137
'top_p': 0.8,
66138
'max_new_tokens': 2048,
139+
'eos_token_id': 2,
140+
'transformers_version': '5.12.1',
67141
}
68142
config = resolve_default_gen_config('auto', '/fake/model', False)
69143
assert config == {

0 commit comments

Comments
 (0)