diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py
index fe9b8254a19b..82b3f96290ab 100644
--- a/tensorrt_llm/commands/serve.py
+++ b/tensorrt_llm/commands/serve.py
@@ -7,7 +7,7 @@
 import subprocess  # nosec B404
 import sys
 from pathlib import Path
-from typing import Any, Dict, Literal, Mapping, Optional, Sequence
+from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Tuple
 
 import click
 import torch
@@ -267,10 +267,13 @@ def launch_server(
         server_role: Optional[ServerRole] = None,
         disagg_cluster_config: Optional[DisaggClusterConfig] = None,
         multimodal_server_config: Optional[MultimodalServerConfig] = None,
-        served_model_name: Optional[str] = None):
+        served_model_name: Optional[Sequence[str]] = None):
     backend = llm_args["backend"]
-    model = served_model_name or llm_args["model"]
+    # served_model_name may be a sequence (tuple from click's multiple=True, or list).
+    # Normalize to a list; an empty sequence is treated the same as None.
+    served_model_names: List[str] = list(served_model_name) if served_model_name else []
+    model = served_model_names[0] if served_model_names else llm_args["model"]
 
     addr_info = socket.getaddrinfo(host, port, socket.AF_UNSPEC,
                                    socket.SOCK_STREAM)
     address_family = socket.AF_INET6 if all(
@@ -303,7 +306,7 @@
                     param_hint="backend")
 
     server = OpenAIServer(generator=llm,
-                          model=model,
+                          model=served_model_names if served_model_names else model,
                           tool_parser=tool_parser,
                           server_role=server_role,
                           metadata_server_cfg=metadata_server_cfg,
@@ -321,7 +324,7 @@
 def launch_grpc_server(host: str,
                        port: int,
                        llm_args: dict,
-                       served_model_name: Optional[str] = None):
+                       served_model_name: Optional[Sequence[str]] = None):
     """
     Launch a gRPC server for TensorRT-LLM.
 
@@ -332,7 +335,9 @@
         host: Host to bind to
        port: Port to bind to
         llm_args: Arguments for LLM initialization (from get_llm_args)
-        served_model_name: Custom model name for API responses (defaults to model path)
+        served_model_name: Model name(s) for API responses (defaults to model path).
+            Note: the gRPC server uses only the first (primary) name; additional
+            aliases are supported only by the HTTP/OpenAI server.
     """
     import grpc
 
@@ -350,7 +355,8 @@ async def serve_grpc_async():
         logger.info("Initializing TensorRT-LLM gRPC server...")
 
         backend = llm_args.get("backend")
-        model_path = served_model_name or llm_args.get("model", "")
+        _names = list(served_model_name) if served_model_name else []
+        model_path = _names[0] if _names else llm_args.get("model", "")
 
         if backend == "pytorch":
             llm_args.pop("build_config", None)
@@ -766,11 +772,12 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
 @click.option(
     "--served_model_name",
     type=str,
+    multiple=True,
     default=None,
     help=help_info_with_stability_tag(
-        "The model name used in the API. If not specified, the model path is "
-        "used as the model name. This is useful when the model path is long or "
-        "when you want to expose a custom name to clients.", "prototype"))
+        "The model name(s) used in the API. Can be specified multiple times for aliases. "
+        "The first name is primary; additional names are aliases that the server also accepts. "
+        "If not specified, the model path is used as the model name.", "prototype"))
 @click.option("--extra_visual_gen_options",
               type=str,
               default=None,
@@ -794,7 +801,7 @@ def serve(
         enable_attention_dp: bool, disagg_cluster_uri: Optional[str],
         media_io_kwargs: Optional[str], video_pruning_rate: Optional[float],
         custom_module_dirs: list[Path], chat_template: Optional[str],
-        grpc: bool, served_model_name: Optional[str],
+        grpc: bool, served_model_name: Tuple[str, ...],
         extra_visual_gen_options: Optional[str]):
     """Running an OpenAI API compatible server
 
diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index bebf897bf3d2..8ce5c7ea4286 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -15,7 +15,7 @@
 from http import HTTPStatus
 from pathlib import Path
 from typing import (Annotated, Any, AsyncGenerator, AsyncIterator, List,
-                    Optional, Union)
+                    Optional, Sequence, Union)
 
 import uvicorn
 from fastapi import Body, FastAPI, Request
@@ -175,7 +175,7 @@ class OpenAIServer:
     def __init__(
         self,
         generator: Union[LLM, MultimodalEncoder, VisualGen],
-        model: str,
+        model: Union[str, Sequence[str]],
         tool_parser: Optional[str],
         server_role: Optional[ServerRole],
         metadata_server_cfg: MetadataServerConfig,
@@ -194,11 +194,25 @@ def __init__(
         self.host = None
         self.port = None
 
-        model_dir = Path(model)
-        if model_dir.exists() and model_dir.is_dir():
-            self.model = model_dir.name
+        # Normalize model names: accept either a single string or a list/tuple of strings.
+        # self.model is the primary name; self.served_model_names includes all aliases.
+        if isinstance(model, (list, tuple)):
+            names = list(model)
         else:
-            self.model = model
+            names = [model]
+        if not names or not names[0]:
+            raise ValueError("At least one model name must be provided")
+        primary = names[0]
+        model_dir = Path(primary)
+        if model_dir.exists() and model_dir.is_dir():
+            primary = model_dir.name
+        self.model = primary
+        seen = {primary}
+        self.served_model_names: List[str] = [primary]
+        for n in names[1:]:
+            if n not in seen:
+                seen.add(n)
+                self.served_model_names.append(n)
         self.metrics_collector = None
         self.perf_metrics = None
         self.perf_metrics_lock = None
@@ -607,7 +621,7 @@ async def health_generate(self, raw_request: Request) -> Response:
                     "role": "user",
                     "content": "hi"
                 }],  # Minimal prompt (often > 1 token after tokenization)
-                model=self.model,
+                model=self.model,  # Use primary model name for health checks
                 max_completion_tokens=1,  # Request only 1 token out
                 stream=False,
                 temperature=0.0,  # Deterministic output
@@ -644,8 +658,14 @@ async def version(self) -> JSONResponse:
         ver = {"version": VERSION}
         return JSONResponse(content=ver)
 
+    def _resolve_model_name(self, requested: Optional[str]) -> str:
+        """Return the requested model name if it is a known alias, else the primary name."""
+        if requested and requested in self.served_model_names:
+            return requested
+        return self.model
+
     async def get_model(self) -> JSONResponse:
-        model_list = ModelList(data=[ModelCard(id=self.model)])
+        model_list = ModelList(data=[ModelCard(id=name) for name in self.served_model_names])
         return JSONResponse(content=model_list.model_dump())
 
     async def get_iteration_stats(self) -> JSONResponse:
@@ -1062,7 +1082,7 @@ async def create_mm_embedding_response(promise: RequestOutput):
                 int(h["tensor_size"][0]) for h in mm_embedding_handles)
             return ChatCompletionResponse(
                 id=str(promise.request_id),
-                model=self.model,
+                model=self._resolve_model_name(request.model),
                 choices=[
                     ChatCompletionResponseChoice(
                         index=0,
@@ -1173,7 +1193,7 @@ def merge_completion_responses(
                     cached_tokens=num_cached_tokens, ),
             )
             merged_rsp = CompletionResponse(
-                model=self.model,
+                model=self._resolve_model_name(request.model),
                 choices=all_choices,
                 usage=usage_info,
                 prompt_token_ids=all_prompt_token_ids,
@@ -1447,7 +1467,7 @@ async def create_response(
             generator=promise,
             request=request,
             sampling_params=args.sampling_params,
-            model_name=self.model,
+            model_name=self._resolve_model_name(request.model),
             conversation_store=self.conversation_store,
             generation_result=None,
             enable_store=self.enable_store and request.store,
@@ -1516,7 +1536,7 @@ async def create_streaming_generator(promise: RequestOutput,
             streaming_processor = ResponsesStreamingProcessor(
                 request=request,
                 sampling_params=sampling_params,
-                model_name=self.model,
+                model_name=self._resolve_model_name(request.model),
                 conversation_store=self.conversation_store,
                 enable_store=self.enable_store and request.store,
                 use_harmony=self.use_harmony,
@@ -1525,7 +1545,7 @@ async def create_streaming_generator(promise: RequestOutput,
             )
 
             postproc_args = ResponsesAPIPostprocArgs(
-                model=self.model,
+                model=self._resolve_model_name(request.model),
                 request=request,
                 sampling_params=sampling_params,
                 use_harmony=self.use_harmony,
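For reference, a minimal client-side sketch of the behavior this patch intends (not part of the diff). It assumes the trtllm-serve entry point, the default HTTP port 8000, and the stock openai Python client; the alias names below are made up for illustration:

# Server started with (hypothetical names):
#   trtllm-serve /path/to/llama-ckpt \
#       --served_model_name llama --served_model_name llama-prod
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

# GET /v1/models now returns one ModelCard per served name, primary first.
print([card.id for card in client.models.list().data])
# -> ["llama", "llama-prod"]

# Any listed alias is accepted in the request, and _resolve_model_name()
# echoes it back in the response; unknown names fall back to the primary.
rsp = client.chat.completions.create(
    model="llama-prod",
    messages=[{"role": "user", "content": "hi"}],
    max_tokens=8,
)
print(rsp.model)  # -> "llama-prod"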