Skip to content

Commit 059de9c

Browse files
authored
[None][fix] PyExecutor Hang in Disagg TP Prefill (#14020)
Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
1 parent 06456e1 commit 059de9c

2 files changed

Lines changed: 46 additions & 14 deletions

File tree

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
get_global_profiler, host_profiler_context)
4343

4444
from ..distributed import Distributed
45+
from ..distributed.communicator import ReduceOp
4546
from ..expert_statistic import ExpertStatistic
4647
from ..models.modeling_llama import Llama4ForConditionalGeneration
4748
from ..models.modeling_utils import DecoderModelForCausalLM
@@ -1926,16 +1927,26 @@ def _executor_loop_pp(self):
19261927
and req.py_disaggregated_params.schedule_style ==
19271928
DisaggScheduleStyle.GENERATION_FIRST
19281929
for req in self.active_requests)
1929-
if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
1930-
if not all_gen_first:
1930+
# [disagg-ctx-deadlock-fix] Mirror of the OR-gated entry in
1931+
# _executor_loop: ensure every TP rank either calls
1932+
# _check_disagg_ctx_cache_transfer_status together or skips
1933+
# it together, so the internal allgather in
1934+
# CacheTransceiver::checkContextTransferStatus always has
1935+
# full quorum. With PP > 1 the schedule is broadcast from
1936+
# rank 0 so num_fitting_reqs should already be uniform, but
1937+
# has_any_inflight_requests is rank-local and could
1938+
# otherwise diverge.
1939+
local_need_check = (num_fitting_reqs == 0 and
1940+
not fitting_disagg_gen_init_requests)
1941+
any_need_check = self.dist.allreduce(int(local_need_check),
1942+
op=ReduceOp.MAX)
1943+
if any_need_check > 0:
1944+
if local_need_check and not all_gen_first:
19311945
logger.warning(
19321946
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
19331947
)
19341948
self._check_disagg_ctx_cache_transfer_status(1)
1935-
elif self.async_transfer_manager.has_any_inflight_requests(
1936-
):
1937-
# Non-blocking cleanup of completed/timed-out
1938-
# transfers to free KV blocks (see _executor_loop).
1949+
else:
19391950
self._check_disagg_ctx_cache_transfer_status(0)
19401951

19411952
self.num_scheduled_requests = scheduled_batch.batch_size
@@ -2471,18 +2482,36 @@ def _prepare_and_schedule_batch(self):
24712482
req.py_disaggregated_params and req.py_disaggregated_params.
24722483
schedule_style == DisaggScheduleStyle.GENERATION_FIRST
24732484
for req in self.active_requests)
2474-
if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
2475-
if not all_gen_first:
2485+
# [disagg-ctx-deadlock-fix] _check_disagg_ctx_cache_transfer_status
2486+
# internally invokes a TP-wide allgather inside
2487+
# CacheTransceiver::checkContextTransferStatus. Gating the call on
2488+
# rank-local `num_fitting_reqs` (which can drift between ranks by
2489+
# one block due to per-rank UCX/CUDA-event-sync timing variance)
2490+
# lets some ranks enter the allgather while others skip ahead to
2491+
# model_forward / kv_connector → cross-rank collective-mismatch
2492+
# deadlock. OR the decision across TP ranks: if ANY rank wants the
2493+
# call, ALL ranks call it. Ranks that don't locally need it use the
2494+
# non-blocking variant so the collective stays in sync without
2495+
# holding any individual rank.
2496+
local_need_check = (num_fitting_reqs == 0
2497+
and not fitting_disagg_gen_init_requests)
2498+
any_need_check = self.dist.allreduce(int(local_need_check),
2499+
op=ReduceOp.MAX)
2500+
if any_need_check > 0:
2501+
if local_need_check and not all_gen_first:
24762502
logger.warning(
24772503
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
24782504
)
2505+
# Local conditions warrant a blocking wait for at least one
2506+
# in-flight transfer to complete so KV blocks can be freed.
24792507
self._check_disagg_ctx_cache_transfer_status(1)
2480-
elif self.async_transfer_manager.has_any_inflight_requests():
2481-
# Non-blocking cleanup of completed/timed-out transfers
2482-
# to free KV blocks. We avoid the blocking check because
2483-
# gen-first requests may be waiting for peer info (which
2484-
# would block indefinitely), but completed transfers must
2485-
# still be reaped so that KV cache can be reclaimed.
2508+
else:
2509+
# Either (a) a peer rank needed the call but we didn't, or
2510+
# (b) all active requests are gen-first so we don't
2511+
# actively block. In both cases the non-blocking variant
2512+
# still runs the internal allgather (keeping all ranks in
2513+
# sync) and reaps any already-completed transfers without
2514+
# blocking on un-finished ones.
24862515
self._check_disagg_ctx_cache_transfer_status(0)
24872516

24882517
# In gen-only benchmark mode, all requests must fit in KV cache

tests/unittest/_torch/executor/test_benchmark_disagg.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,7 @@ def test_fetch_called_once_even_in_benchmark_disagg(self):
505505
ex.enable_attention_dp = False
506506
ex.num_fetch_requests = 0
507507
ex.dist = Mock(rank=0, tp_size=1)
508+
ex.dist.allreduce = Mock(side_effect=lambda v, op=None: v)
508509
ex.is_shutdown = False
509510
ex._is_warmup = False
510511
ex.enable_iter_perf_stats = False
@@ -894,6 +895,7 @@ def _make_executor(
894895
ex.enable_attention_dp = False
895896
ex.num_fetch_requests = num_fetch_requests
896897
ex.dist = Mock(rank=0, tp_size=1)
898+
ex.dist.allreduce.return_value = 0
897899
ex.is_shutdown = False
898900
ex._is_warmup = False
899901
ex.enable_iter_perf_stats = False
@@ -1040,6 +1042,7 @@ def _make_executor(self):
10401042
ex.num_fetch_requests = 0
10411043
ex.max_num_active_requests = self.MAX_BATCH_SIZE
10421044
ex.dist = Mock(rank=0, tp_size=self.TP_SIZE)
1045+
ex.dist.allreduce.return_value = 0
10431046
ex.is_shutdown = False
10441047
ex._is_warmup = False
10451048
ex.enable_iter_perf_stats = False

0 commit comments

Comments
 (0)