11# Copyright (c) OpenMMLab. All rights reserved.
2+ import warnings
23from unittest .mock import patch
34
45from lmdeploy .messages import GenerationConfig
89 merge_gen_config ,
910 resolve_default_gen_config ,
1011)
11- from lmdeploy .serve .openai .protocol import ChatCompletionRequest , CompletionRequest
12+ from lmdeploy .serve .openai .protocol import ChatCompletionRequest , CompletionRequest , GenerateReqInput
13+ from lmdeploy .serve .openai .serving_generate import check_request as check_generate_request
1214
# Baseline GenerationConfig built with no arguments; tests compare merged
# configs against its attribute values.
_DEFAULTS = GenerationConfig()
1416
1517
class _FakeEngineConfig:
    """Minimal engine-config stand-in exposing only ``logprobs_mode``."""

    # Left unset in this stub.
    logprobs_mode = None
20+
21+
class _FakeSessionManager:
    """Session-manager stub that reports every session id as unknown."""

    def has(self, session_id):
        """Return ``False`` regardless of ``session_id``."""
        return False
26+
27+
class _FakeServerContext:
    """Server-context stub that hands out the fake config and session manager."""

    def get_engine_config(self):
        """Return a new ``_FakeEngineConfig`` instance."""
        return _FakeEngineConfig()

    def get_session_manager(self):
        """Return a new ``_FakeSessionManager`` instance."""
        return _FakeSessionManager()
35+
36+
1637def test_merge_gen_config_priority ():
1738 merged = merge_gen_config (
1839 {'temperature' : 0.2 },
@@ -26,7 +47,7 @@ def test_merge_gen_config_uses_server_defaults():
2647 assert merged == {'temperature' : 0.5 }
2748
2849
29- def test_extract_request_gen_config_only_non_null ():
50+ def test_extract_request_gen_config_only_explicit_fields ():
3051 request = ChatCompletionRequest (model = 'test' , messages = 'hi' , temperature = 0.3 )
3152 values = extract_request_gen_config (request )
3253 assert values == {'temperature' : 0.3 }
@@ -58,12 +79,65 @@ def test_build_generation_config_uses_generation_config_defaults():
5879 assert gen_config .top_k == _DEFAULTS .top_k
5980
6081
def test_build_generation_config_ignores_unsupported_defaults():
    """Building a config from defaults with extra keys still applies ``temperature``.

    The defaults dict mixes a sampling value with entries such as
    ``eos_token_id``, ``pad_token_id`` and ``transformers_version``; the call
    must succeed and the resulting temperature must come from the dict.
    """
    request = CompletionRequest(model='test', prompt='hello')
    gen_config = build_generation_config(
        request,
        {
            'temperature': 0.6,
            'eos_token_id': 2,
            'pad_token_id': 0,
            'transformers_version': '5.12.1',
        },
    )
    assert gen_config.temperature == 0.6
94+
95+
def test_completion_request_max_tokens_is_optional():
    """A ``CompletionRequest`` built without ``max_tokens`` reports it as ``None``."""
    request = CompletionRequest(model='test', prompt='hello')
    # DeprecationWarning is silenced while the attribute is read
    # (presumably accessing max_tokens can emit one; confirm against the protocol).
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', DeprecationWarning)
        assert request.max_tokens is None
101+
102+
def test_generate_request_sampling_defaults_match_chat_request():
    """Default sampling fields agree between ``GenerateReqInput`` and ``ChatCompletionRequest``."""
    chat_request = ChatCompletionRequest(model='test', messages='hello')
    generate_request = GenerateReqInput(prompt='hello')
    for name in ('temperature', 'top_p', 'top_k', 'min_p'):
        # The field name is the assertion message so a failure says which one diverged.
        assert getattr(generate_request, name) == getattr(chat_request, name), name
108+
109+
def test_generate_request_accepts_none_sampling_defaults():
    """``check_generate_request`` returns an empty string for a prompt-only request."""
    request = GenerateReqInput(prompt='hello')
    assert check_generate_request(request, _FakeServerContext()) == ''
113+
114+
def test_generate_request_sampling_merge_uses_server_defaults():
    """A request with no sampling fields set picks up every server default.

    The request only carries a prompt, so ``temperature``, ``top_p``,
    ``top_k`` and ``min_p`` on the built config must equal the values in the
    defaults dict passed to ``build_generation_config``.
    """
    request = GenerateReqInput(prompt='hello')
    gen_config = build_generation_config(
        request,
        {
            'temperature': 0.2,
            'top_p': 0.3,
            'top_k': 7,
            'min_p': 0.1,
        },
        max_new_tokens=request.max_tokens,
    )
    assert gen_config.temperature == 0.2
    assert gen_config.top_p == 0.3
    assert gen_config.top_k == 7
    assert gen_config.min_p == 0.1
131+
132+
61133@patch ('lmdeploy.serve.core.generation_config._load_hf_generation_config' )
62134def test_resolve_default_gen_config_auto (mock_load ):
63135 mock_load .return_value = {
64136 'temperature' : 0.6 ,
65137 'top_p' : 0.8 ,
66138 'max_new_tokens' : 2048 ,
139+ 'eos_token_id' : 2 ,
140+ 'transformers_version' : '5.12.1' ,
67141 }
68142 config = resolve_default_gen_config ('auto' , '/fake/model' , False )
69143 assert config == {
0 commit comments