
Commit cdda8d8

nvyutwu and claude committed
feat: support multiple model names in --served_model_name
Allow specifying multiple served model names so that requests using any alias are accepted and the /v1/models endpoint returns all names.

Changes:
- serve.py: --served_model_name is now multiple=True (specify the flag multiple times); launch_server/launch_grpc_server accept Sequence[str]; passes the list to OpenAIServer.
- openai_server.py: __init__ accepts Union[str, Sequence[str]]; stores self.model (primary) and self.served_model_names (all aliases); get_model() returns a ModelCard for each name; added _resolve_model_name() to echo back the client-requested name in responses if it matches a known alias.

Usage: trtllm-serve model --served_model_name my-model --served_model_name alias1

Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
1 parent 7ee9e8b commit cdda8d8
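As a quick illustration of the behavior described in the commit message, here is a minimal client-side sketch, assuming a local trtllm-serve instance started with the usage line above and listening on http://localhost:8000 (host, port, and model names are illustrative, not part of this change):

    import requests

    BASE = "http://localhost:8000"  # assumed address of the running trtllm-serve instance

    # /v1/models now returns a ModelCard for every served name, not just the primary one.
    models = requests.get(f"{BASE}/v1/models", timeout=10).json()
    print([m["id"] for m in models["data"]])  # e.g. ["my-model", "alias1"]

    # A request addressed to an alias is accepted, and the response's "model" field
    # echoes the alias back (this is what _resolve_model_name provides).
    chat = requests.post(
        f"{BASE}/v1/chat/completions",
        json={
            "model": "alias1",
            "messages": [{"role": "user", "content": "hi"}],
            "max_tokens": 8,
        },
        timeout=60,
    ).json()
    print(chat["model"])  # expected: "alias1"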

File tree

2 files changed, +43 -23 lines changed


tensorrt_llm/commands/serve.py

Lines changed: 15 additions & 10 deletions
@@ -7,7 +7,7 @@
 import subprocess  # nosec B404
 import sys
 from pathlib import Path
-from typing import Any, Dict, Literal, Mapping, Optional, Sequence
+from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Tuple
 
 import click
 import torch
@@ -267,10 +267,13 @@ def launch_server(
                   server_role: Optional[ServerRole] = None,
                   disagg_cluster_config: Optional[DisaggClusterConfig] = None,
                   multimodal_server_config: Optional[MultimodalServerConfig] = None,
-                  served_model_name: Optional[str] = None):
+                  served_model_name: Optional[Sequence[str]] = None):
 
     backend = llm_args["backend"]
-    model = served_model_name or llm_args["model"]
+    # served_model_name may be a sequence (tuple from click's multiple=True, or list).
+    # Normalize to a list; empty sequence is treated as None.
+    served_model_names: List[str] = list(served_model_name) if served_model_name else []
+    model = served_model_names[0] if served_model_names else llm_args["model"]
     addr_info = socket.getaddrinfo(host, port, socket.AF_UNSPEC,
                                    socket.SOCK_STREAM)
     address_family = socket.AF_INET6 if all(
@@ -303,7 +306,7 @@ def launch_server(
                     param_hint="backend")
 
     server = OpenAIServer(generator=llm,
-                          model=model,
+                          model=served_model_names if served_model_names else model,
                           tool_parser=tool_parser,
                           server_role=server_role,
                           metadata_server_cfg=metadata_server_cfg,
@@ -321,7 +324,7 @@
 def launch_grpc_server(host: str,
                        port: int,
                        llm_args: dict,
-                       served_model_name: Optional[str] = None):
+                       served_model_name: Optional[Sequence[str]] = None):
     """
     Launch a gRPC server for TensorRT-LLM.
 
@@ -350,7 +353,8 @@ async def serve_grpc_async():
     logger.info("Initializing TensorRT-LLM gRPC server...")
 
     backend = llm_args.get("backend")
-    model_path = served_model_name or llm_args.get("model", "")
+    _names = list(served_model_name) if served_model_name else []
+    model_path = _names[0] if _names else llm_args.get("model", "")
 
     if backend == "pytorch":
         llm_args.pop("build_config", None)
@@ -766,11 +770,12 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
 @click.option(
     "--served_model_name",
     type=str,
+    multiple=True,
     default=None,
     help=help_info_with_stability_tag(
-        "The model name used in the API. If not specified, the model path is "
-        "used as the model name. This is useful when the model path is long or "
-        "when you want to expose a custom name to clients.", "prototype"))
+        "The model name(s) used in the API. Can be specified multiple times for aliases. "
+        "The first name is primary; additional names are aliases that the server also accepts. "
+        "If not specified, the model path is used as the model name.", "prototype"))
 @click.option("--extra_visual_gen_options",
               type=str,
               default=None,
@@ -794,7 +799,7 @@ def serve(
         enable_attention_dp: bool, disagg_cluster_uri: Optional[str],
         media_io_kwargs: Optional[str], video_pruning_rate: Optional[float],
         custom_module_dirs: list[Path], chat_template: Optional[str],
-        grpc: bool, served_model_name: Optional[str],
+        grpc: bool, served_model_name: Tuple[str, ...],
         extra_visual_gen_options: Optional[str]):
     """Running an OpenAI API compatible server
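To make the normalization in launch_server easier to follow, here is a small standalone sketch (illustrative only, not the library code itself): click's multiple=True yields a tuple, which is reduced to a primary name with a fallback to the model path.

    from typing import List, Optional, Sequence

    def pick_primary_name(served_model_name: Optional[Sequence[str]],
                          model_path: str) -> str:
        """Mirror the normalization above: an empty tuple (flag not given)
        falls back to the model path; otherwise the first name is primary."""
        names: List[str] = list(served_model_name) if served_model_name else []
        return names[0] if names else model_path

    assert pick_primary_name((), "/ckpts/my-llama") == "/ckpts/my-llama"
    assert pick_primary_name(("my-model", "alias1"), "/ckpts/my-llama") == "my-model"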

tensorrt_llm/serve/openai_server.py

Lines changed: 28 additions & 13 deletions
@@ -15,7 +15,7 @@
 from http import HTTPStatus
 from pathlib import Path
 from typing import (Annotated, Any, AsyncGenerator, AsyncIterator, List,
-                    Optional, Union)
+                    Optional, Sequence, Union)
 
 import uvicorn
 from fastapi import Body, FastAPI, Request
@@ -175,7 +175,7 @@ class OpenAIServer:
     def __init__(
         self,
         generator: Union[LLM, MultimodalEncoder, VisualGen],
-        model: str,
+        model: Union[str, Sequence[str]],
         tool_parser: Optional[str],
         server_role: Optional[ServerRole],
         metadata_server_cfg: MetadataServerConfig,
@@ -194,11 +194,20 @@ def __init__(
         self.host = None
         self.port = None
 
-        model_dir = Path(model)
-        if model_dir.exists() and model_dir.is_dir():
-            self.model = model_dir.name
+        # Normalize model names: accept either a single string or a list/tuple of strings.
+        # self.model is the primary name; self.served_model_names includes all aliases.
+        if isinstance(model, (list, tuple)):
+            names = list(model)
         else:
-            self.model = model
+            names = [model]
+        primary = names[0] if names else ""
+        model_dir = Path(primary)
+        if model_dir.exists() and model_dir.is_dir():
+            primary = model_dir.name
+        self.model = primary
+        self.served_model_names: List[str] = [primary] + [
+            n for n in names[1:] if n != primary
+        ]
         self.metrics_collector = None
         self.perf_metrics = None
         self.perf_metrics_lock = None
@@ -607,7 +616,7 @@ async def health_generate(self, raw_request: Request) -> Response:
                     "role": "user",
                     "content": "hi"
                 }],  # Minimal prompt (often > 1 token after tokenization)
-                model=self.model,
+                model=self.model,  # Use primary model name for health checks
                 max_completion_tokens=1,  # Request only 1 token out
                 stream=False,
                 temperature=0.0,  # Deterministic output
@@ -644,8 +653,14 @@ async def version(self) -> JSONResponse:
         ver = {"version": VERSION}
         return JSONResponse(content=ver)
 
+    def _resolve_model_name(self, requested: Optional[str]) -> str:
+        """Return the requested model name if it is a known alias, else the primary name."""
+        if requested and requested in self.served_model_names:
+            return requested
+        return self.model
+
     async def get_model(self) -> JSONResponse:
-        model_list = ModelList(data=[ModelCard(id=self.model)])
+        model_list = ModelList(data=[ModelCard(id=name) for name in self.served_model_names])
         return JSONResponse(content=model_list.model_dump())
 
     async def get_iteration_stats(self) -> JSONResponse:
@@ -1062,7 +1077,7 @@ async def create_mm_embedding_response(promise: RequestOutput):
             int(h["tensor_size"][0]) for h in mm_embedding_handles)
         return ChatCompletionResponse(
             id=str(promise.request_id),
-            model=self.model,
+            model=self._resolve_model_name(request.model),
             choices=[
                 ChatCompletionResponseChoice(
                     index=0,
@@ -1173,7 +1188,7 @@ def merge_completion_responses(
                 cached_tokens=num_cached_tokens, ),
         )
         merged_rsp = CompletionResponse(
-            model=self.model,
+            model=self._resolve_model_name(request.model),
             choices=all_choices,
             usage=usage_info,
             prompt_token_ids=all_prompt_token_ids,
@@ -1447,7 +1462,7 @@ async def create_response(
             generator=promise,
             request=request,
             sampling_params=args.sampling_params,
-            model_name=self.model,
+            model_name=self._resolve_model_name(request.model),
             conversation_store=self.conversation_store,
             generation_result=None,
             enable_store=self.enable_store and request.store,
@@ -1516,7 +1531,7 @@ async def create_streaming_generator(promise: RequestOutput,
         streaming_processor = ResponsesStreamingProcessor(
             request=request,
             sampling_params=sampling_params,
-            model_name=self.model,
+            model_name=self._resolve_model_name(request.model),
             conversation_store=self.conversation_store,
             enable_store=self.enable_store and request.store,
             use_harmony=self.use_harmony,
@@ -1525,7 +1540,7 @@
         )
 
         postproc_args = ResponsesAPIPostprocArgs(
-            model=self.model,
+            model=self._resolve_model_name(request.model),
             request=request,
             sampling_params=sampling_params,
             use_harmony=self.use_harmony,
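For reference, a minimal standalone sketch of the alias bookkeeping introduced in __init__ together with _resolve_model_name (a toy reduction for illustration; the class and variable names here are not from the library):

    from pathlib import Path
    from typing import List, Optional, Sequence, Union

    class ServedModelNames:
        """Toy reduction of the OpenAIServer name handling shown in the diff."""

        def __init__(self, model: Union[str, Sequence[str]]):
            names = list(model) if isinstance(model, (list, tuple)) else [model]
            primary = names[0] if names else ""
            model_dir = Path(primary)
            if model_dir.exists() and model_dir.is_dir():
                primary = model_dir.name  # a local path collapses to its directory name
            self.model = primary
            self.served_model_names: List[str] = [primary] + [
                n for n in names[1:] if n != primary
            ]

        def resolve(self, requested: Optional[str]) -> str:
            # Echo a known alias back to the client; anything else maps to the primary name.
            if requested and requested in self.served_model_names:
                return requested
            return self.model

    names = ServedModelNames(["my-model", "alias1"])
    assert names.served_model_names == ["my-model", "alias1"]
    assert names.resolve("alias1") == "alias1"       # known alias is echoed
    assert names.resolve("unknown") == "my-model"    # unknown name falls back to primary
    assert names.resolve(None) == "my-model"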
