Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lmdeploy/cli/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def add_parser_api_server():
# model args
ArgumentHelper.revision(parser)
ArgumentHelper.download_dir(parser)
ArgumentHelper.generation_config(parser)

# pytorch engine args
pt_group = parser.add_argument_group('PyTorch engine arguments')
Expand Down Expand Up @@ -318,6 +319,7 @@ def api_server(args):
reasoning_parser=args.reasoning_parser,
tool_call_parser=args.tool_call_parser,
speculative_config=speculative_config,
generation_config=args.generation_config,
)
else:
from lmdeploy.serve.openai.launch_server import launch_server
Expand Down Expand Up @@ -350,6 +352,7 @@ def api_server(args):
reasoning_parser=args.reasoning_parser,
tool_call_parser=args.tool_call_parser,
speculative_config=speculative_config,
generation_config=args.generation_config,
)

@staticmethod
Expand Down
12 changes: 12 additions & 0 deletions lmdeploy/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,18 @@ def hf_overrides(parser):
default=None,
help='Extra arguments to be forwarded to the HuggingFace config.')

@staticmethod
def generation_config(parser):
    """Register the ``--generation-config`` option on ``parser``.

    The option is a string defaulting to ``'auto'``. The object returned by
    ``parser.add_argument`` is handed back to the caller.
    """
    help_text = ('The folder path to the generation config. Defaults to "auto", the '
                 'generation config will be loaded from model path. If set to "lmdeploy", no '
                 'generation config is loaded, lmdeploy defaults will be used. If set to a folder '
                 'path, the generation config will be loaded from the specified folder path.')
    options = dict(type=str, default='auto', help=help_text)
    return parser.add_argument('--generation-config', **options)

@staticmethod
def use_logn_attn(parser):
"""Add argument use_logn_attn to parser."""
Expand Down
15 changes: 8 additions & 7 deletions lmdeploy/serve/anthropic/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import shortuuid

from lmdeploy.messages import GenerationConfig
from lmdeploy.serve.core.generation_config import build_generation_config
from lmdeploy.serve.openai.protocol import Tool, ToolChoice, ToolChoiceFuncName

from .protocol import (
Expand Down Expand Up @@ -341,15 +342,15 @@ def to_lmdeploy_messages(request: MessagesRequest | CountTokensRequest) -> list[
return lm_messages


def to_generation_config(request: MessagesRequest) -> GenerationConfig:
def to_generation_config(
request: MessagesRequest,
default_gen_config: dict | None = None,
) -> GenerationConfig:
"""Map Anthropic messages request to LMDeploy generation config."""

return GenerationConfig(
return build_generation_config(
request,
default_gen_config or {},
max_new_tokens=request.max_tokens,
do_sample=True,
top_k=40 if request.top_k is None else request.top_k,
top_p=1.0 if request.top_p is None else request.top_p,
temperature=1.0 if request.temperature is None else request.temperature,
stop_words=request.stop_sequences,
include_stop_str_in_output=request.include_stop_str_in_output or False,
skip_special_tokens=True,
Expand Down
5 changes: 4 additions & 1 deletion lmdeploy/serve/anthropic/endpoints/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,10 @@ async def create_message(request: MessagesRequest, raw_request: Request):
result_generator = server_context.async_engine.generate(
engine_messages,
session,
gen_config=to_generation_config(request),
gen_config=to_generation_config(
request,
default_gen_config=server_context.default_gen_config,
),
tools=parsed_request.tools,
stream_response=True,
sequence_start=True,
Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/serve/anthropic/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class MessagesRequest(BaseModel):
system: str | list[ContentBlockParam] | None = None
stop_sequences: list[str] | None = None
stream: bool = False
temperature: float | None = 1.0
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
metadata: dict[str, Any] | None = None
Expand Down
94 changes: 94 additions & 0 deletions lmdeploy/serve/core/generation_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Server-side generation config resolution and sampling parameter merge
helpers."""

from __future__ import annotations

import dataclasses
from typing import Any

from lmdeploy.messages import GenerationConfig
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')


def _load_hf_generation_config(path: str, trust_remote_code: bool) -> dict[str, Any]:
    """Load a HuggingFace generation config from ``path`` as a plain dict.

    Args:
        path: Location handed to ``GenerationConfig.from_pretrained``.
        trust_remote_code: Forwarded unchanged to ``from_pretrained``.

    Returns:
        The result of ``to_diff_dict()`` on the loaded config, or an empty
        dict when loading raises ``OSError``.
    """
    # Aliased so the name does not shadow lmdeploy's ``GenerationConfig``
    # that this module imports at top level.
    from transformers import GenerationConfig as HFGenerationConfig

    try:
        hf_config = HFGenerationConfig.from_pretrained(path, trust_remote_code=trust_remote_code)
    except OSError as err:
        # Best-effort: a failed load is treated as "no defaults available",
        # but leave a trace so the fallback can be diagnosed.
        logger.debug(f'No generation config loaded from {path}: {err}')
        return {}
    return hf_config.to_diff_dict()


def resolve_default_gen_config(
    generation_config: str,
    model_path: str,
    trust_remote_code: bool,
) -> dict[str, Any]:
    """Resolve server-side default generation config from CLI flags.

    Args:
        generation_config: ``'lmdeploy'`` to load nothing, ``'auto'`` to load
            from ``model_path``, or any other string, which is used as the
            location to load from.
        model_path: Location used when ``generation_config`` is ``'auto'``.
        trust_remote_code: Forwarded to the loader.

    Returns:
        The loaded defaults as a dict; empty when nothing is loaded.
    """
    if generation_config == 'lmdeploy':
        return {}

    from_model = generation_config == 'auto'
    load_path = model_path if from_model else generation_config
    config = _load_hf_generation_config(load_path, trust_remote_code)

    if config:
        source = "the model's `generation_config.json`" if from_model else generation_config
        logger.info(
            f'Using default generation config from {source}: {config}. '
            'Use `--generation-config lmdeploy` to disable.')
    elif not from_model:
        # The user explicitly pointed at a location; do not fall back to
        # lmdeploy defaults without telling them.
        logger.warning(
            f'No generation config could be loaded from {generation_config}; '
            'lmdeploy defaults will be used.')

    return config


def merge_sampling_params(
    request_values: dict[str, Any],
    default_gen_config: dict[str, Any],
) -> dict[str, Any]:
    """Merge sampling params with request > default_gen_config priority.

    Args:
        request_values: Values supplied by the request; they win on conflict.
        default_gen_config: Server-side defaults used for keys the request
            does not provide.

    Returns:
        A new dict; neither input is modified. Keys keep a stable order:
        defaults first, then request-only keys.
    """
    # Later unpacking overrides earlier keys, so request values take priority.
    return {**default_gen_config, **request_values}


def extract_request_sampling_values(request: Any) -> dict[str, Any]:
    """Collect the ``GenerationConfig`` fields that ``request`` sets.

    Each ``GenerationConfig`` dataclass field is looked up by name on
    ``request``. Fields the request lacks, or holds as ``None``, are left out.
    """
    looked_up = (
        (spec.name, getattr(request, spec.name, None))
        for spec in dataclasses.fields(GenerationConfig)
    )
    return {name: found for name, found in looked_up if found is not None}


def build_generation_config(
    request: Any,
    default_gen_config: dict[str, Any],
    *,
    max_new_tokens: int | None = None,
    **extra_kwargs: Any,
) -> GenerationConfig:
    """Build ``GenerationConfig`` from merged sampling defaults and request
    values.

    Priority, highest first: explicit ``extra_kwargs``, then non-None
    request fields, then ``default_gen_config``.

    Args:
        request: Object whose attributes named like ``GenerationConfig``
            fields are used as request-level values.
        default_gen_config: Server-side defaults; keys that are not
            ``GenerationConfig`` fields are ignored.
        max_new_tokens: Always used as-is; any ``max_new_tokens`` coming from
            the request or the defaults is discarded.
        **extra_kwargs: Passed to ``GenerationConfig``; they override merged
            values of the same name.

    Returns:
        The constructed ``GenerationConfig``. ``do_sample`` is ``True`` unless
        ``extra_kwargs`` sets it.
    """
    request_values = extract_request_sampling_values(request)
    merged = merge_sampling_params(request_values, default_gen_config)

    # The defaults dict is passed through unfiltered, so it can hold keys this
    # dataclass does not accept; splatting those would raise TypeError.
    accepted = {field.name for field in dataclasses.fields(GenerationConfig)}
    config_kwargs = {key: value for key, value in merged.items() if key in accepted}

    config_kwargs.pop('max_new_tokens', None)
    config_kwargs['do_sample'] = True
    # Merge instead of splatting two dicts: a field present both on the
    # request and in ``extra_kwargs`` would otherwise raise "got multiple
    # values for keyword argument". The caller's explicit value wins.
    config_kwargs.update(extra_kwargs)
    config_kwargs['max_new_tokens'] = max_new_tokens
    return GenerationConfig(**config_kwargs)
56 changes: 31 additions & 25 deletions lmdeploy/serve/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@
)
from lmdeploy.serve.anthropic import create_anthropic_router
from lmdeploy.serve.core import AsyncEngine, EngineHealthMonitor
from lmdeploy.serve.core.generation_config import (
build_generation_config,
resolve_default_gen_config,
)
from lmdeploy.serve.openai.protocol import (
AbortRequest,
ChatCompletionRequest,
Expand Down Expand Up @@ -110,6 +114,7 @@ class VariableInterface:
allow_terminate_by_client: bool = False
enable_abort_handling: bool = False
response_parser_cls: type[ResponseParser] | None = None
default_gen_config: dict = {}

@classmethod
def create_session(cls, user_session_id: int | None = None) -> Session:
Expand Down Expand Up @@ -147,6 +152,19 @@ def get_engine_config(cls):
return cls.async_engine.backend_config


def _build_serving_generation_config(request, **extra_kwargs) -> GenerationConfig:
    """Create the ``GenerationConfig`` for a serving request.

    The token budget is ``request.max_completion_tokens`` when that is
    present and not ``None``, otherwise ``request.max_tokens`` (``None`` if
    the request has neither). The server defaults stored on
    ``VariableInterface`` and any ``extra_kwargs`` are forwarded to
    ``build_generation_config``.
    """
    token_budget = None
    # Checked in priority order; stop at the first attribute with a value.
    for attr_name in ('max_completion_tokens', 'max_tokens'):
        token_budget = getattr(request, attr_name, None)
        if token_budget is not None:
            break
    server_defaults = VariableInterface.default_gen_config
    return build_generation_config(request, server_defaults, max_new_tokens=token_budget, **extra_kwargs)


async def _with_request_cleanup(generator, result_generators, sessions):
"""Yield from an API generator and cleanup when the HTTP task exits."""
session_mgr = VariableInterface.get_session_manager()
Expand Down Expand Up @@ -486,22 +504,16 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
# (e.g. GPT-OSS clears response_format and injects the schema into messages)
request = response_parser.request

gen_config = GenerationConfig(
max_new_tokens=request.max_completion_tokens,
do_sample=True,
gen_config = _build_serving_generation_config(
request,
logprobs=gen_logprobs,
top_k=request.top_k,
top_p=request.top_p,
temperature=request.temperature,
repetition_penalty=request.repetition_penalty,
ignore_eos=request.ignore_eos,
stop_words=request.stop,
include_stop_str_in_output=request.include_stop_str_in_output,
skip_special_tokens=request.skip_special_tokens,
response_format=request.response_format,
logits_processors=logits_processors,
min_new_tokens=request.min_new_tokens,
min_p=request.min_p,
random_seed=random_seed,
spaces_between_special_tokens=request.spaces_between_special_tokens,
migration_request=migration_request,
Expand Down Expand Up @@ -828,20 +840,13 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None
if isinstance(request.stop, str):
request.stop = [request.stop]
random_seed = request.seed if request.seed is not None else None
max_new_tokens = (request.max_completion_tokens if request.max_completion_tokens else request.max_tokens)

gen_config = GenerationConfig(
max_new_tokens=max_new_tokens,
do_sample=True,
gen_config = _build_serving_generation_config(
request,
logprobs=request.logprobs,
top_k=request.top_k,
top_p=request.top_p,
temperature=request.temperature,
repetition_penalty=request.repetition_penalty,
ignore_eos=request.ignore_eos,
stop_words=request.stop,
skip_special_tokens=request.skip_special_tokens,
min_p=request.min_p,
random_seed=random_seed,
spaces_between_special_tokens=request.spaces_between_special_tokens,
migration_request=migration_request,
Expand Down Expand Up @@ -1012,15 +1017,9 @@ async def generate(request: GenerateReqInput, raw_request: Request = None):
prompt = [dict(role='user', content=[text_input] + image_input)]
input_ids = None

gen_config = GenerationConfig(
max_new_tokens=request.max_tokens,
do_sample=True,
gen_config = _build_serving_generation_config(
request,
logprobs=1 if request.return_logprob else None,
top_k=request.top_k,
top_p=request.top_p,
min_p=request.min_p,
temperature=request.temperature,
repetition_penalty=request.repetition_penalty,
ignore_eos=request.ignore_eos,
stop_words=request.stop,
stop_token_ids=request.stop_token_ids,
Expand Down Expand Up @@ -1546,6 +1545,7 @@ def serve(model_path: str,
allow_terminate_by_client: bool = False,
enable_abort_handling: bool = False,
speculative_config: SpeculativeConfig | None = None,
generation_config: str = 'auto',
**kwargs):
"""An example to perform model inference through the command line
interface.
Expand Down Expand Up @@ -1606,6 +1606,12 @@ def serve(model_path: str,
VariableInterface.allow_terminate_by_client = allow_terminate_by_client
VariableInterface.enable_abort_handling = enable_abort_handling

VariableInterface.default_gen_config = resolve_default_gen_config(
generation_config,
model_path,
trust_remote_code,
)

ssl_keyfile, ssl_certfile, http_or_https = None, None, 'http'
if ssl:
ssl_keyfile = os.environ['SSL_KEYFILE']
Expand Down
Loading
Loading