Skip to content

Commit 6e30d4c

Browse files
lvhan028cursoragent
andcommitted
refactor(serve): rename generation config helper functions
Rename extract_request_sampling_values to extract_request_gen_config and merge_sampling_params to merge_gen_config for clearer naming. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 2c4a284 commit 6e30d4c

2 files changed

Lines changed: 13 additions & 13 deletions

File tree

lmdeploy/serve/core/generation_config.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ def resolve_default_gen_config(
4747
return config
4848

4949

50-
def merge_sampling_params(
50+
def merge_gen_config(
5151
request_values: dict[str, Any],
5252
default_gen_config: dict[str, Any],
5353
) -> dict[str, Any]:
54-
"""Merge sampling params with request > default_gen_config priority."""
54+
"""Merge generation config with request > default_gen_config priority."""
5555
merged: dict[str, Any] = {}
5656
for key in set(default_gen_config) | set(request_values):
5757
if key in request_values:
@@ -61,7 +61,7 @@ def merge_sampling_params(
6161
return merged
6262

6363

64-
def extract_request_sampling_values(request: Any) -> dict[str, Any]:
64+
def extract_request_gen_config(request: Any) -> dict[str, Any]:
6565
"""Extract non-None GenerationConfig fields present on the request."""
6666
values: dict[str, Any] = {}
6767
for field in dataclasses.fields(GenerationConfig):
@@ -82,8 +82,8 @@ def build_generation_config(
8282
) -> GenerationConfig:
8383
"""Build ``GenerationConfig`` from merged sampling defaults and request
8484
values."""
85-
request_values = extract_request_sampling_values(request)
86-
merged = merge_sampling_params(request_values, default_gen_config)
85+
request_values = extract_request_gen_config(request)
86+
merged = merge_gen_config(request_values, default_gen_config)
8787
merged.pop('max_new_tokens', None)
8888
merged.pop('do_sample', None)
8989
return GenerationConfig(

tests/test_lmdeploy/serve/test_generation_config.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,31 @@
44
from lmdeploy.messages import GenerationConfig
55
from lmdeploy.serve.core.generation_config import (
66
build_generation_config,
7-
extract_request_sampling_values,
8-
merge_sampling_params,
7+
extract_request_gen_config,
8+
merge_gen_config,
99
resolve_default_gen_config,
1010
)
1111
from lmdeploy.serve.openai.protocol import ChatCompletionRequest, CompletionRequest
1212

1313
_DEFAULTS = GenerationConfig()
1414

1515

16-
def test_merge_sampling_params_priority():
17-
merged = merge_sampling_params(
16+
def test_merge_gen_config_priority():
17+
merged = merge_gen_config(
1818
{'temperature': 0.2},
1919
{'temperature': 0.5, 'top_k': 10},
2020
)
2121
assert merged == {'temperature': 0.2, 'top_k': 10}
2222

2323

24-
def test_merge_sampling_params_uses_server_defaults():
25-
merged = merge_sampling_params({}, {'temperature': 0.5})
24+
def test_merge_gen_config_uses_server_defaults():
25+
merged = merge_gen_config({}, {'temperature': 0.5})
2626
assert merged == {'temperature': 0.5}
2727

2828

29-
def test_extract_request_sampling_values_only_non_null():
29+
def test_extract_request_gen_config_only_non_null():
3030
request = ChatCompletionRequest(model='test', messages='hi', temperature=0.3)
31-
values = extract_request_sampling_values(request)
31+
values = extract_request_gen_config(request)
3232
assert values == {'temperature': 0.3}
3333

3434

0 commit comments

Comments
 (0)