Skip to content

Commit 5e3f012

Browse files
authored
[https://nvbugs/6143883][fix] Preserve ip:port for trtllm-serve visual-gen (#14355)
Signed-off-by: JunyiXu-nv <219237550+JunyiXu-nv@users.noreply.github.com>
1 parent af5a22e commit 5e3f012

2 files changed

Lines changed: 40 additions & 17 deletions

File tree

tensorrt_llm/commands/serve.py

Lines changed: 31 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -560,23 +560,40 @@ def launch_visual_gen_server(
560560
visual_gen_args: Optional validated VisualGenArgs for model configuration.
561561
metadata_server_cfg: Optional metadata server configuration.
562562
"""
563-
logger.info(f"Initializing VisualGen ({model})")
563+
# Reserve the listening (host, port) by binding the socket *before*
564+
# constructing the VisualGen pipeline, then hand the bound socket to
565+
# uvicorn. VisualGen initialization can take many minutes; if we deferred
566+
# the bind until uvicorn started, anything else on the host could grab the
567+
# port in that window and trtllm-serve would die at bind() time.
568+
addr_info = socket.getaddrinfo(host, port, socket.AF_UNSPEC,
569+
socket.SOCK_STREAM)
570+
address_family = socket.AF_INET6 if all(
571+
[info[0] == socket.AF_INET6 for info in addr_info]) else socket.AF_INET
572+
with socket.socket(address_family, socket.SOCK_STREAM) as s:
573+
try:
574+
s.bind((host, port))
575+
except OSError as e:
576+
raise RuntimeError(f"Failed to bind socket to {host}:{port}: {e}")
564577

565-
visual_gen_model = VisualGen(model=model, args=visual_gen_args)
578+
logger.info(f"Initializing VisualGen ({model})")
566579

567-
n_workers = visual_gen_model.args.parallel_config.n_workers
568-
logger.info(f"World size: {n_workers}")
569-
logger.info(f"CFG size: {visual_gen_model.args.parallel_config.cfg_size}")
570-
logger.info(
571-
f"Ulysses size: {visual_gen_model.args.parallel_config.ulysses_size}")
580+
visual_gen_model = VisualGen(model=model, args=visual_gen_args)
572581

573-
server = OpenAIServer(generator=visual_gen_model,
574-
model=model,
575-
server_role=ServerRole.VISUAL_GEN,
576-
metadata_server_cfg=metadata_server_cfg,
577-
tool_parser=None)
578-
_apply_fastapi_middlewares(server.app, middleware)
579-
asyncio.run(server(host, port))
582+
n_workers = visual_gen_model.args.parallel_config.n_workers
583+
logger.info(f"World size: {n_workers}")
584+
logger.info(
585+
f"CFG size: {visual_gen_model.args.parallel_config.cfg_size}")
586+
logger.info(
587+
f"Ulysses size: {visual_gen_model.args.parallel_config.ulysses_size}"
588+
)
589+
590+
server = OpenAIServer(generator=visual_gen_model,
591+
model=model,
592+
server_role=ServerRole.VISUAL_GEN,
593+
metadata_server_cfg=metadata_server_cfg,
594+
tool_parser=None)
595+
_apply_fastapi_middlewares(server.app, middleware)
596+
asyncio.run(server(host, port, sockets=[s]))
580597

581598

582599
class ChoiceWithAlias(click.Choice):

tests/unittest/_torch/visual_gen/test_trtllm_serve_e2e.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -40,8 +40,6 @@
4040
import requests
4141
import yaml
4242

43-
from tensorrt_llm._utils import get_free_port
44-
4543
# ---------------------------------------------------------------------------
4644
# Model paths
4745
# ---------------------------------------------------------------------------
@@ -69,6 +67,14 @@ def _llm_models_root() -> str:
6967
_PROJECT_ROOT = Path(__file__).resolve().parents[4] # repo root
7068
_REF_IMAGE_PATH = _PROJECT_ROOT / "examples" / "visual_gen" / "cat_piano.png"
7169

70+
# Use the CI-aware port allocator from tests/integration/defs/common.py so
71+
# parallel pytest sessions on the same OCI node fall into disjoint port
72+
# sections (CONTAINER_PORT_START / CONTAINER_PORT_NUM). It transparently falls
73+
# back to the plain free-port scan when those env vars are not set.
74+
_INTEGRATION_TESTS_DIR = _PROJECT_ROOT / "tests" / "integration"
75+
if str(_INTEGRATION_TESTS_DIR) not in sys.path:
76+
sys.path.insert(0, str(_INTEGRATION_TESTS_DIR))
77+
from defs.common import get_free_port_in_ci # noqa: E402
7278

7379
# ---------------------------------------------------------------------------
7480
# Remote server helper (follows RemoteOpenAIServer pattern)
@@ -94,7 +100,7 @@ def __init__(
94100
env: Optional[dict] = None,
95101
) -> None:
96102
self.host = host
97-
self.port = port if port is not None else get_free_port()
103+
self.port = port if port is not None else get_free_port_in_ci()
98104
self._config_file: Optional[str] = None
99105
self.proc: Optional[subprocess.Popen] = None
100106

0 commit comments

Comments
 (0)