Skip to content

Commit 2e1deba

Browse files
committed
Guard NCCL EP lazy init during graph capture
Fail fast if the lazy NCCL EP context or dispatch handle would be initialized while CUDA graph capture is active, and cover those guard paths in the communication factory tests. Signed-off-by: Ludwig Schneider <lschneider@nvidia.com>
1 parent 3c9e1a0 commit 2e1deba

2 files changed

Lines changed: 45 additions & 0 deletions

File tree

tensorrt_llm/_torch/modules/fused_moe/communication/nccl_ep.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ def supports_post_quant_dispatch(self) -> bool:
114114

115115
def _get_context(self):
116116
if self._ctx is None:
117+
if torch.cuda.is_current_stream_capturing():
118+
raise RuntimeError(
119+
"NcclEP context must be initialized before CUDA graph capture. "
120+
"Run an eager warmup forward before enabling or capturing CUDA graphs."
121+
)
117122
from nccl.ep import Layout
118123

119124
from tensorrt_llm._torch.modules.fused_moe.nccl_ep_utils import (
@@ -134,6 +139,11 @@ def _get_context(self):
134139
def _setup_handle(self, ctx, topk_nd, stream):
135140
"""Ensure self._handle exists; rebind topk via handle.update on subsequent calls."""
136141
if self._handle is None:
142+
if torch.cuda.is_current_stream_capturing():
143+
raise RuntimeError(
144+
"NcclEP dispatch handle must be initialized before CUDA graph capture. "
145+
"Run an eager warmup forward before enabling or capturing CUDA graphs."
146+
)
137147
self._handle = ctx.ep_group.create_handle(
138148
ctx.layout,
139149
topk_nd,

tests/unittest/_torch/modules/moe/test_communication_factory.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from tensorrt_llm._torch.modules.fused_moe.communication.allgather_reducescatter import (
2525
AllGatherReduceScatter,
2626
)
27+
from tensorrt_llm._torch.modules.fused_moe.communication.nccl_ep import NcclEP
2728

2829

2930
def _make_model_config(
@@ -181,3 +182,37 @@ def fail_if_called(*args, **kwargs):
181182
)
182183

183184
assert isinstance(strategy, AllGatherReduceScatter)
185+
186+
187+
def test_nccl_ep_context_init_rejects_cuda_graph_capture(
188+
monkeypatch: pytest.MonkeyPatch,
189+
):
190+
strategy = object.__new__(NcclEP)
191+
strategy._ctx = None
192+
monkeypatch.setattr(torch.cuda, "is_current_stream_capturing", lambda: True)
193+
194+
with pytest.raises(
195+
RuntimeError, match="context must be initialized before CUDA graph capture"
196+
):
197+
strategy._get_context()
198+
199+
200+
def test_nccl_ep_handle_init_rejects_cuda_graph_capture(
201+
monkeypatch: pytest.MonkeyPatch,
202+
):
203+
strategy = object.__new__(NcclEP)
204+
strategy._handle = None
205+
monkeypatch.setattr(torch.cuda, "is_current_stream_capturing", lambda: True)
206+
207+
def fail_create_handle(*args, **kwargs):
208+
raise AssertionError("create_handle should not run during CUDA graph capture")
209+
210+
ctx = SimpleNamespace(
211+
ep_group=SimpleNamespace(create_handle=fail_create_handle),
212+
layout=object(),
213+
)
214+
215+
with pytest.raises(
216+
RuntimeError, match="dispatch handle must be initialized before CUDA graph capture"
217+
):
218+
strategy._setup_handle(ctx, object(), 0)

0 commit comments

Comments
 (0)