29 changes: 18 additions & 11 deletions tensorrt_llm/commands/serve.py
@@ -7,7 +7,7 @@
import subprocess # nosec B404
import sys
from pathlib import Path
from typing import Any, Dict, Literal, Mapping, Optional, Sequence
from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Tuple

import click
import torch
@@ -267,10 +267,13 @@ def launch_server(
server_role: Optional[ServerRole] = None,
disagg_cluster_config: Optional[DisaggClusterConfig] = None,
multimodal_server_config: Optional[MultimodalServerConfig] = None,
served_model_name: Optional[str] = None):
served_model_name: Optional[Sequence[str]] = None):

backend = llm_args["backend"]
model = served_model_name or llm_args["model"]
# served_model_name may be a sequence (tuple from click's multiple=True, or list).
# Normalize to a list; an empty sequence is handled the same as None (no custom names).
served_model_names: List[str] = list(served_model_name) if served_model_name else []
model = served_model_names[0] if served_model_names else llm_args["model"]
addr_info = socket.getaddrinfo(host, port, socket.AF_UNSPEC,
socket.SOCK_STREAM)
address_family = socket.AF_INET6 if all(
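For reference, a minimal standalone sketch of the normalization above: click's `multiple=True` yields a tuple (empty when the flag is absent), the first entry becomes the primary name, and everything else falls back to the model path. The model names and path below are made up.

```python
from typing import List, Optional, Sequence

def pick_primary(served_model_name: Optional[Sequence[str]], model_path: str) -> str:
    # Same normalization as above: sequence -> list; empty or None -> fall back to the path.
    names: List[str] = list(served_model_name) if served_model_name else []
    return names[0] if names else model_path

assert pick_primary(("llama-3", "llama3-alias"), "/models/llama") == "llama-3"
assert pick_primary((), "/models/llama") == "/models/llama"   # click yields () when the flag is absent
assert pick_primary(None, "/models/llama") == "/models/llama"
```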
@@ -303,7 +306,7 @@
param_hint="backend")

server = OpenAIServer(generator=llm,
model=model,
model=served_model_names if served_model_names else model,
tool_parser=tool_parser,
server_role=server_role,
metadata_server_cfg=metadata_server_cfg,
@@ -321,7 +324,7 @@
def launch_grpc_server(host: str,
port: int,
llm_args: dict,
served_model_name: Optional[str] = None):
served_model_name: Optional[Sequence[str]] = None):
"""
Launch a gRPC server for TensorRT-LLM.

@@ -332,7 +335,9 @@ def launch_grpc_server(host: str,
host: Host to bind to
port: Port to bind to
llm_args: Arguments for LLM initialization (from get_llm_args)
served_model_name: Custom model name for API responses (defaults to model path)
served_model_name: Model name(s) for API responses (defaults to model path).
Note: the gRPC server only uses the first (primary) name. Multiple
aliases are supported by the HTTP/OpenAI server only.
"""
import grpc

@@ -350,7 +355,8 @@ async def serve_grpc_async():
logger.info("Initializing TensorRT-LLM gRPC server...")

backend = llm_args.get("backend")
model_path = served_model_name or llm_args.get("model", "")
_names = list(served_model_name) if served_model_name else []
model_path = _names[0] if _names else llm_args.get("model", "")

if backend == "pytorch":
llm_args.pop("build_config", None)
@@ -766,11 +772,12 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
@click.option(
"--served_model_name",
type=str,
multiple=True,
default=None,
help=help_info_with_stability_tag(
"The model name used in the API. If not specified, the model path is "
"used as the model name. This is useful when the model path is long or "
"when you want to expose a custom name to clients.", "prototype"))
"The model name(s) used in the API. Can be specified multiple times for aliases. "
"The first name is primary; additional names are aliases that the server also accepts. "
"If not specified, the model path is used as the model name.", "prototype"))
@click.option("--extra_visual_gen_options",
type=str,
default=None,
@@ -794,7 +801,7 @@ def serve(
enable_attention_dp: bool, disagg_cluster_uri: Optional[str],
media_io_kwargs: Optional[str], video_pruning_rate: Optional[float],
custom_module_dirs: list[Path], chat_template: Optional[str],
grpc: bool, served_model_name: Optional[str],
grpc: bool, served_model_name: Tuple[str, ...],
extra_visual_gen_options: Optional[str]):
"""Running an OpenAI API compatible server

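As a sanity check on the option change, here is a tiny self-contained click command using `multiple=True` the same way; the command name and values are illustrative only, not the real trtllm-serve CLI.

```python
import click

@click.command()
@click.option("--served_model_name", type=str, multiple=True, default=None)
def demo(served_model_name):
    # Repeated flags accumulate into a tuple; omitting the flag yields ().
    click.echo(repr(served_model_name))

if __name__ == "__main__":
    demo()
```

Running `demo --served_model_name llama-3 --served_model_name llama3-alias` would print `('llama-3', 'llama3-alias')`, which matches the `Tuple[str, ...]` that `serve()` now receives.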
46 changes: 33 additions & 13 deletions tensorrt_llm/serve/openai_server.py
@@ -15,7 +15,7 @@
from http import HTTPStatus
from pathlib import Path
from typing import (Annotated, Any, AsyncGenerator, AsyncIterator, List,
Optional, Union)
Optional, Sequence, Union)

import uvicorn
from fastapi import Body, FastAPI, Request
@@ -175,7 +175,7 @@ class OpenAIServer:
def __init__(
self,
generator: Union[LLM, MultimodalEncoder, VisualGen],
model: str,
model: Union[str, Sequence[str]],
tool_parser: Optional[str],
server_role: Optional[ServerRole],
metadata_server_cfg: MetadataServerConfig,
@@ -194,11 +194,25 @@ def __init__(
self.host = None
self.port = None

model_dir = Path(model)
if model_dir.exists() and model_dir.is_dir():
self.model = model_dir.name
# Normalize model names: accept either a single string or a list/tuple of strings.
# self.model is the primary name; self.served_model_names includes all aliases.
if isinstance(model, (list, tuple)):
names = list(model)
else:
self.model = model
names = [model]
if not names or not names[0]:
raise ValueError("At least one model name must be provided")
primary = names[0]
model_dir = Path(primary)
if model_dir.exists() and model_dir.is_dir():
primary = model_dir.name
self.model = primary
seen = {primary}
self.served_model_names: List[str] = [primary]
for n in names[1:]:
if n not in seen:
seen.add(n)
self.served_model_names.append(n)
self.metrics_collector = None
self.perf_metrics = None
self.perf_metrics_lock = None
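A detached sketch of the constructor's normalization, assuming the same semantics as the hunk above (paths and names are hypothetical): the primary name is the first entry, a local directory path collapses to its basename, and duplicate aliases are dropped while preserving order.

```python
from pathlib import Path
from typing import List, Sequence, Union

def normalize_served_names(model: Union[str, Sequence[str]]) -> List[str]:
    # Accept a single string or a list/tuple of strings.
    names = list(model) if isinstance(model, (list, tuple)) else [model]
    if not names or not names[0]:
        raise ValueError("At least one model name must be provided")
    primary = names[0]
    model_dir = Path(primary)
    if model_dir.exists() and model_dir.is_dir():
        primary = model_dir.name  # a local directory path collapses to its basename
    result, seen = [primary], {primary}
    for name in names[1:]:
        if name not in seen:  # drop duplicate aliases, keep order
            seen.add(name)
            result.append(name)
    return result

assert normalize_served_names("llama-3") == ["llama-3"]
assert normalize_served_names(["llama-3", "llama3-alias", "llama-3"]) == ["llama-3", "llama3-alias"]
```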
@@ -607,7 +621,7 @@ async def health_generate(self, raw_request: Request) -> Response:
"role": "user",
"content": "hi"
}], # Minimal prompt (often > 1 token after tokenization)
model=self.model,
model=self.model, # Use primary model name for health checks
max_completion_tokens=1, # Request only 1 token out
stream=False,
temperature=0.0, # Deterministic output
@@ -644,8 +658,14 @@ async def version(self) -> JSONResponse:
ver = {"version": VERSION}
return JSONResponse(content=ver)

def _resolve_model_name(self, requested: Optional[str]) -> str:
"""Return the requested model name if it is a known alias, else the primary name."""
if requested and requested in self.served_model_names:
return requested
return self.model

async def get_model(self) -> JSONResponse:
model_list = ModelList(data=[ModelCard(id=self.model)])
model_list = ModelList(data=[ModelCard(id=name) for name in self.served_model_names])
return JSONResponse(content=model_list.model_dump())

async def get_iteration_stats(self) -> JSONResponse:
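To make the alias behavior concrete, here is a hedged, standalone version of the resolution rule (the served names are invented): `/v1/models` now lists every alias, and per-request responses echo the requested name when it is a known alias, otherwise the primary name.

```python
from typing import List, Optional

def resolve_model_name(requested: Optional[str], served_model_names: List[str]) -> str:
    # Echo a known alias back to the client; anything else maps to the primary name.
    if requested and requested in served_model_names:
        return requested
    return served_model_names[0]

served = ["llama-3", "llama3-alias"]
assert resolve_model_name("llama3-alias", served) == "llama3-alias"
assert resolve_model_name("some-unknown-name", served) == "llama-3"
assert resolve_model_name(None, served) == "llama-3"
```

This is why the remaining hunks below swap `model=self.model` for `model=self._resolve_model_name(request.model)` in the response constructors.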
@@ -1062,7 +1082,7 @@ async def create_mm_embedding_response(promise: RequestOutput):
int(h["tensor_size"][0]) for h in mm_embedding_handles)
return ChatCompletionResponse(
id=str(promise.request_id),
model=self.model,
model=self._resolve_model_name(request.model),
choices=[
ChatCompletionResponseChoice(
index=0,
@@ -1173,7 +1193,7 @@ def merge_completion_responses(
cached_tokens=num_cached_tokens, ),
)
merged_rsp = CompletionResponse(
model=self.model,
model=self._resolve_model_name(request.model),
choices=all_choices,
usage=usage_info,
prompt_token_ids=all_prompt_token_ids,
@@ -1447,7 +1467,7 @@ async def create_response(
generator=promise,
request=request,
sampling_params=args.sampling_params,
model_name=self.model,
model_name=self._resolve_model_name(request.model),
conversation_store=self.conversation_store,
generation_result=None,
enable_store=self.enable_store and request.store,
@@ -1516,7 +1536,7 @@ async def create_streaming_generator(promise: RequestOutput,
streaming_processor = ResponsesStreamingProcessor(
request=request,
sampling_params=sampling_params,
model_name=self.model,
model_name=self._resolve_model_name(request.model),
conversation_store=self.conversation_store,
enable_store=self.enable_store and request.store,
use_harmony=self.use_harmony,
@@ -1525,7 +1545,7 @@
)

postproc_args = ResponsesAPIPostprocArgs(
model=self.model,
model=self._resolve_model_name(request.model),
request=request,
sampling_params=sampling_params,
use_harmony=self.use_harmony,