Skip to content

Commit 4d9cfa8

Browse files
lvhan028cursoragent
andcommitted
refactor(serve): restore direct sampling validation in serving checks
Validate explicit request sampling fields only, leaving default merge to build_generation_config instead of duplicating it in check_request. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 6e30d4c commit 4d9cfa8

2 files changed

Lines changed: 12 additions & 38 deletions

File tree

lmdeploy/serve/openai/serving_chat_completion.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,12 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from typing import TYPE_CHECKING
33

4-
from lmdeploy.serve.core.generation_config import build_generation_config
5-
64
from .protocol import ChatCompletionRequest
75

86
if TYPE_CHECKING:
97
from .api_server import VariableInterface
108

119

12-
def _effective_sampling(request: ChatCompletionRequest, server_context: 'VariableInterface') -> dict:
13-
gen = build_generation_config(request, server_context.default_gen_config)
14-
return {
15-
'temperature': gen.temperature,
16-
'top_p': gen.top_p,
17-
'top_k': gen.top_k,
18-
}
19-
20-
2110
def check_request(request: ChatCompletionRequest, server_context: 'VariableInterface') -> str:
2211
engine_config = server_context.get_engine_config()
2312
session_manager = server_context.get_session_manager()
@@ -43,17 +32,15 @@ def check_request(request: ChatCompletionRequest, server_context: 'VariableInter
4332
if session_manager.has(request.session_id):
4433
return f'The session_id {request.session_id!r} is occupied.'
4534

46-
sampling = _effective_sampling(request, server_context)
47-
4835
# check sampling settings
4936
if request.n <= 0:
5037
return f'The n {request.n!r} must be a positive int.'
51-
if not (0 < sampling['top_p'] <= 1):
52-
return f'The top_p {sampling["top_p"]!r} must be in (0, 1].'
53-
if sampling['top_k'] < 0:
54-
return f'The top_k {sampling["top_k"]!r} cannot be a negative integer.'
55-
if not (0 <= sampling['temperature'] <= 2):
56-
return f'The temperature {sampling["temperature"]!r} must be in [0, 2]'
38+
if request.top_p is not None and not (0 < request.top_p <= 1):
39+
return f'The top_p {request.top_p!r} must be in (0, 1].'
40+
if request.top_k is not None and request.top_k < 0:
41+
return f'The top_k {request.top_k!r} cannot be a negative integer.'
42+
if request.temperature is not None and not (0 <= request.temperature <= 2):
43+
return f'The temperature {request.temperature!r} must be in [0, 2]'
5744

5845
# Validate input_ids and image_data constraints.
5946
# messages has higher priority. input_ids and image_data are only used when
Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,12 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from typing import TYPE_CHECKING
33

4-
from lmdeploy.serve.core.generation_config import build_generation_config
5-
64
from .protocol import CompletionRequest
75

86
if TYPE_CHECKING:
97
from .api_server import VariableInterface
108

119

12-
def _effective_sampling(request: CompletionRequest, server_context: 'VariableInterface') -> dict:
13-
gen = build_generation_config(request, server_context.default_gen_config)
14-
return {
15-
'temperature': gen.temperature,
16-
'top_p': gen.top_p,
17-
'top_k': gen.top_k,
18-
}
19-
20-
2110
def check_request(request: CompletionRequest, server_context: 'VariableInterface') -> str:
2211
engine_config = server_context.get_engine_config()
2312
session_manager = server_context.get_session_manager()
@@ -35,16 +24,14 @@ def check_request(request: CompletionRequest, server_context: 'VariableInterface
3524
if session_manager.has(request.session_id):
3625
return f'The session_id {request.session_id!r} is occupied.'
3726

38-
sampling = _effective_sampling(request, server_context)
39-
4027
# check sampling settings
4128
if request.n <= 0:
4229
return f'The n {request.n!r} must be a positive int.'
43-
if not (0 < sampling['top_p'] <= 1):
44-
return f'The top_p {sampling["top_p"]!r} must be in (0, 1].'
45-
if sampling['top_k'] < 0:
46-
return f'The top_k {sampling["top_k"]!r} cannot be a negative integer.'
47-
if not (0 <= sampling['temperature'] <= 2):
48-
return f'The temperature {sampling["temperature"]!r} must be in [0, 2]'
30+
if request.top_p is not None and not (0 < request.top_p <= 1):
31+
return f'The top_p {request.top_p!r} must be in (0, 1].'
32+
if request.top_k is not None and request.top_k < 0:
33+
return f'The top_k {request.top_k!r} cannot be a negative integer.'
34+
if request.temperature is not None and not (0 <= request.temperature <= 2):
35+
return f'The temperature {request.temperature!r} must be in [0, 2]'
4936

5037
return ''

0 commit comments

Comments
 (0)