Skip to content

Commit b72fb89

Browse files
MrGevaGitLab CI Bot
authored andcommitted
[None][fix] AutoDeploy: Fixed wrong dist_backend AUTO detection when using trtllm-llmapi-launch (NVIDIA#15423)
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com> Signed-off-by: GitLab CI Bot <gitlab-ci@nvidia.com>
1 parent 5f1b1c9 commit b72fb89

3 files changed

Lines changed: 70 additions & 3 deletions

File tree

tensorrt_llm/_torch/auto_deploy/custom_ops/distributed/trtllm_dist.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
The torch fallback variants are defined separately to enable multi-pattern matching.
2020
"""
2121

22+
import os
2223
from typing import List, Optional
2324

2425
import torch
@@ -295,5 +296,18 @@ def trtllm_fused_allreduce_residual_rmsnorm_out_quant_nvfp4_fake(
295296

296297

297298
def is_trtllm_op_available():
298-
"""Check if TRT-LLM ops are available and running with MPI."""
299-
return is_ompi()
299+
"""Check if TRT-LLM ops are available for AutoDeploy collectives."""
300+
if is_ompi():
301+
return True
302+
303+
# trtllm-llmapi-launch intentionally removes OMPI/SLURM variables from
304+
# the trtllm-serve child to avoid duplicate MPI initialization. It leaves
305+
# these launcher-specific variables so the child can bind to pre-spawned
306+
# LLMAPI worker ranks.
307+
if os.getenv("TLLM_SPAWN_PROXY_PROCESS") == "1":
308+
try:
309+
return int(os.getenv("tllm_mpi_size") or "1") > 1
310+
except ValueError:
311+
return False
312+
313+
return False

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1482,6 +1482,20 @@ def validate_allreduce_strategy(v):
14821482
return v # Let Pydantic handle other types
14831483

14841484

1485+
_LOGGED_DIST_BACKEND_CHOICES: set[tuple[str, str]] = set()
1486+
1487+
1488+
def _log_dist_backend_choice(configured_backend: str, resolved_backend: str):
1489+
key = (configured_backend, resolved_backend)
1490+
if key in _LOGGED_DIST_BACKEND_CHOICES:
1491+
return
1492+
_LOGGED_DIST_BACKEND_CHOICES.add(key)
1493+
ad_logger.info(
1494+
f"AutoDeploy selected distributed backend: {resolved_backend} "
1495+
f"(configured: {configured_backend})"
1496+
)
1497+
1498+
14851499
def _get_dist_ops(backend: str):
14861500
"""Get the (all_gather, all_reduce) op pair for *backend*.
14871501
@@ -1492,12 +1506,27 @@ def _get_dist_ops(backend: str):
14921506
"""
14931507
if hasattr(backend, "value"):
14941508
backend = backend.value
1509+
configured_backend = str(backend)
14951510

1496-
if backend == "trtllm" or is_trtllm_op_available():
1511+
if backend == "trtllm":
1512+
_log_dist_backend_choice(configured_backend, "trtllm")
1513+
return (
1514+
torch.ops.auto_deploy.trtllm_dist_all_gather.default,
1515+
torch.ops.auto_deploy.trtllm_dist_all_reduce.default,
1516+
)
1517+
if backend == "torch":
1518+
_log_dist_backend_choice(configured_backend, "torch")
1519+
return (
1520+
torch.ops.auto_deploy.torch_dist_all_gather.default,
1521+
torch.ops.auto_deploy.torch_dist_all_reduce.default,
1522+
)
1523+
if is_trtllm_op_available():
1524+
_log_dist_backend_choice(configured_backend, "trtllm")
14971525
return (
14981526
torch.ops.auto_deploy.trtllm_dist_all_gather.default,
14991527
torch.ops.auto_deploy.trtllm_dist_all_reduce.default,
15001528
)
1529+
_log_dist_backend_choice(configured_backend, "torch")
15011530
return (
15021531
torch.ops.auto_deploy.torch_dist_all_gather.default,
15031532
torch.ops.auto_deploy.torch_dist_all_reduce.default,

tests/unittest/auto_deploy/multigpu/transformations/library/test_dist_backend.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
import torch.nn as nn
2222
import torch.nn.functional as F
2323

24+
from tensorrt_llm._torch.auto_deploy.custom_ops.distributed.trtllm_dist import (
25+
is_trtllm_op_available,
26+
)
2427
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
2528
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
2629
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
@@ -155,6 +158,27 @@ def test_dist_backend_auto_and_default(dist_backend):
155158
_check_dist_ops(gm_transformed, expected_backend="any")
156159

157160

161+
def test_trtllm_ops_available_with_llmapi_launch_env(monkeypatch):
162+
"""LLMAPI launcher strips OMPI env but still provides TRT-LLM worker ranks."""
163+
monkeypatch.delenv("OMPI_COMM_WORLD_SIZE", raising=False)
164+
monkeypatch.setenv("TLLM_SPAWN_PROXY_PROCESS", "1")
165+
monkeypatch.setenv("tllm_mpi_size", "2")
166+
167+
assert is_trtllm_op_available()
168+
169+
170+
@pytest.mark.parametrize("dist_backend", ["auto", None])
171+
def test_dist_backend_auto_uses_trtllm_with_llmapi_launch_env(monkeypatch, dist_backend):
172+
"""AUTO should select TRT-LLM ops under trtllm-llmapi-launch."""
173+
monkeypatch.delenv("OMPI_COMM_WORLD_SIZE", raising=False)
174+
monkeypatch.setenv("TLLM_SPAWN_PROXY_PROCESS", "1")
175+
monkeypatch.setenv("tllm_mpi_size", "2")
176+
177+
model = SimpleMLP()
178+
gm_transformed = _create_and_transform_model(model, dist_backend=dist_backend, world_size=2)
179+
_check_dist_ops(gm_transformed, expected_backend="trtllm")
180+
181+
158182
@pytest.mark.parametrize("dist_backend", ["torch", "trtllm"])
159183
def test_dist_backend_all_gather(dist_backend):
160184
"""Test dist_backend with all_gather operations (column sharding with single Linear)."""

0 commit comments

Comments
 (0)