Skip to content

Commit f96369a

Browse files
committed
[None][fix] Post-rebase fixes + coverage tests for DeepSeek-V4 disaggregation
Squash of the post-cherry-pick work layered on top of the 8 DeepSeek-V4 disaggregation cherry-picks. Fixes: - ADP disagg error path: restore per-request hang signal (_event_loop_error), scan all candidates + prefer CTX role for mixed-batch dummy padding, and keep charge_budget=False on KV-transfer timeouts so they don't exhaust the global error budget and shut down the executor. - _count_schedulable_active_requests: revert the GENERATION_TO_COMPLETE exclusion that #13900 added (upstream has no such exclusion). Under V1 ADP it undercounted -> a spurious ADP dummy was inserted on top of a real request -> the batch exceeded max_batch_size=1 and tripped the mamba dummy-mask assert / "No free slots". Restores upstream's exact method (fixes test_ptp_quickstart_advanced_deepseek_v3_lite_4gpus_adp_balance). - _handle_responses KV-timeout: gate the tp_allgather on disagg (kv_cache_transceiver is not None), not just enable_attention_dp. py_kv_transfer_timed_out is disagg-only, so in non-disagg ADP this added a spurious per-iteration collective that desynced adp_router's tp_allgather (gather_all_rank_states received a bool -> TypeError). Verified by 4-GPU DeepSeek-V3-Lite adp_balance e2e A/B (buggy: timeout hang; fixed: completes). - transceiver: only short-circuit the tp_allgather skip when pp_size==1 (_ctx_need_pp_sync) -- the PP>1 path asymmetrically flips send/recv markers across pipeline stages and deadlocks the _ctx_consensus pp_allgather. - py_executor: restore main's immediate benchmark fail-fast guard. - resource_manager: do NOT narrow trim_to_history's except (resize() can raise non-ValueError under v2 SWA + uneven-PP; narrowing leaked KV blocks). Tests (added to existing files): - test_py_executor.py: disagg cache-error sync + ADP no-op paths; ADP dummy-role behavior; GENERATION_TO_COMPLETE counts as active (adp_balance regression); CTX dummy padding for disagg idle ranks (incl. awaiting-KV-transfer-only). - test_kv_cache_v2_scheduler.py: trim_to_history. - test_cache_reuse_adapter.py: trim-to-prompt-history + transceiver ctx mgr; _create_kv_slice TokenRange (stub sets py_beam_width for the #14876 path). - test_router.py: finish_request explicit-session forwarding. - test_agent.py: BindingsNixlTransferStatus + shutdown idempotency (#14137). - transferAgentTest.cpp: status-outlives-agent (weak_ptr UAF safety) + concurrent submitTransferRequests. Signed-off-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
1 parent f02b4c2 commit f96369a

8 files changed

Lines changed: 566 additions & 104 deletions

File tree

cpp/tests/unit_tests/executor/transferAgentTest.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
#include <gmock/gmock.h>
2222
#include <gtest/gtest.h>
2323

24+
#include <atomic>
2425
#include <filesystem>
26+
#include <thread>
2527
#include <vector>
2628

2729
namespace fs = std::filesystem;
@@ -376,6 +378,107 @@ TEST_P(TransferAgentTest, SyncMessage)
376378
xferAgent1->invalidateRemoteAgent(agent0);
377379
}
378380

381+
// Status must survive destruction of its owning agent (#14137 UAF-safety): the
382+
// status holds a weak_ptr<nixlAgent>; once the agent is reset the weak_ptr expires
383+
// and orphaned queries must report failure rather than dereference a dangling agent.
384+
TEST_P(TransferAgentTest, StatusOutlivesAgent)
385+
{
386+
std::string const agent0{"agent0"}, agent1{"agent1"};
387+
BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
388+
auto xferAgent0 = makeTransferAgent(config0);
389+
auto xferAgent1 = makeTransferAgent(config1);
390+
TLLM_CHECK(xferAgent0);
391+
TLLM_CHECK(xferAgent1);
392+
393+
std::vector<char> memory0(100, 10);
394+
std::vector<char> memory1(100, 1);
395+
396+
// RegisteredHostMemory holds a raw agent pointer and deregisters in its
397+
// dtor, so it must NOT outlive its agent. Scope it (and the transfer) so it
398+
// deregisters while both agents are alive; only `status` is kept past here.
399+
std::unique_ptr<TransferStatus> status;
400+
{
401+
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, xferAgent0.get());
402+
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, xferAgent1.get());
403+
404+
auto connectionInfo = xferAgent1->getLocalConnectionInfo();
405+
xferAgent0->loadRemoteAgent(agent1, connectionInfo);
406+
while (!xferAgent0->checkRemoteDescs(agent1, regMem1.getDescs()))
407+
{
408+
}
409+
410+
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem1.getDescs(), agent1};
411+
status = xferAgent0->submitTransferRequests(writeReq);
412+
TLLM_CHECK(status->wait() == TransferState::kSUCCESS);
413+
}
414+
415+
// Destroy the owning agent BEFORE the status. shutdown() resets the
416+
// shared_ptr<nixlAgent>, expiring the status's weak_ptr.
417+
xferAgent0.reset();
418+
419+
// Orphaned queries are safe and report failure (no use-after-free):
420+
// wait()/isCompleted() see mWeakAgent.lock() == nullptr and return
421+
// kFAILURE/false instead of dereferencing the freed agent.
422+
EXPECT_FALSE(status->isCompleted());
423+
EXPECT_EQ(status->wait(0), TransferState::kFAILURE);
424+
// `status` destructor runs at scope exit: weak_ptr.lock() == nullptr ->
425+
// early return (no releaseXferReq on a dangling agent).
426+
}
427+
428+
// Concurrent submitTransferRequests (#14137): submit holds a std::shared_lock and
429+
// copies reqParams per-request, so many threads can submit at once without racing
430+
// a shared mExtraParams. All concurrently-submitted transfers must still succeed.
431+
TEST_P(TransferAgentTest, ConcurrentSubmit)
432+
{
433+
std::string const agent0{"agent0"}, agent1{"agent1"};
434+
BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
435+
auto xferAgent0 = makeTransferAgent(config0);
436+
auto xferAgent1 = makeTransferAgent(config1);
437+
TLLM_CHECK(xferAgent0);
438+
TLLM_CHECK(xferAgent1);
439+
440+
std::vector<char> memory0(100, 10);
441+
std::vector<char> memory1(100, 1);
442+
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, xferAgent0.get());
443+
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, xferAgent1.get());
444+
445+
auto connectionInfo = xferAgent1->getLocalConnectionInfo();
446+
xferAgent0->loadRemoteAgent(agent1, connectionInfo);
447+
while (!xferAgent0->checkRemoteDescs(agent1, regMem1.getDescs()))
448+
{
449+
}
450+
451+
constexpr int kNumThreads = 8;
452+
std::vector<std::thread> threads;
453+
std::vector<std::unique_ptr<TransferStatus>> statuses(kNumThreads);
454+
std::atomic<int> ready{0};
455+
for (int i = 0; i < kNumThreads; ++i)
456+
{
457+
threads.emplace_back(
458+
[&, i]()
459+
{
460+
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem1.getDescs(), agent1};
461+
ready.fetch_add(1);
462+
while (ready.load() < kNumThreads)
463+
{
464+
// Align thread starts to maximize submit contention.
465+
}
466+
statuses[i] = xferAgent0->submitTransferRequests(writeReq);
467+
});
468+
}
469+
for (auto& t : threads)
470+
{
471+
t.join();
472+
}
473+
for (auto& status : statuses)
474+
{
475+
TLLM_CHECK(status);
476+
EXPECT_EQ(status->wait(), TransferState::kSUCCESS);
477+
}
478+
TLLM_CHECK(memory0 == memory1);
479+
xferAgent0->invalidateRemoteAgent(agent1);
480+
}
481+
379482
INSTANTIATE_TEST_SUITE_P(AvailableBackends, TransferAgentTest, ::testing::ValuesIn(getAvailableBackends()),
380483
[](::testing::TestParamInfo<TransferAgentTest::ParamType> const& info) { return info.param; });
381484

tensorrt_llm/_torch/disaggregation/transceiver.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -517,8 +517,7 @@ def request_and_receive_async(self, req: LlmRequest):
517517
def check_context_transfer_status(
518518
self, at_least_request_num: Optional[int], mark_complete: bool = False
519519
):
520-
# Skip the tp_allgather in _ctx_consensus when this transceiver never sends (pure GEN role).
521-
if not self._ever_had_send_session:
520+
if not self._ever_had_send_session and not self._ctx_need_pp_sync:
522521
return [], []
523522
block_all = at_least_request_num is None
524523
wait_num = at_least_request_num if not block_all else 0
@@ -573,8 +572,7 @@ def check_context_transfer_status(
573572
return completed, failed
574573

575574
def check_gen_transfer_status(self, at_least_request_num: Optional[int]):
576-
# Skip the allgather in _gen_consensus when this transceiver never receives (pure CTX role).
577-
if not self._ever_had_recv_session:
575+
if not self._ever_had_recv_session and not self._ctx_need_pp_sync:
578576
return [], [], []
579577
block_all = at_least_request_num is None
580578
wait_num = at_least_request_num if not block_all else 0

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 57 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -517,8 +517,6 @@ def __init__(
517517
self.num_scheduled_requests: int = 0
518518
self.benchmark_req_queues_size = int(
519519
os.environ.get("TLLM_BENCHMARK_REQ_QUEUES_SIZE", 0))
520-
self.benchmark_fill_stall_timeout_s = float(
521-
os.environ.get("TLLM_BENCHMARK_FILL_STALL_TIMEOUT_S", 60.0))
522520

523521
# list of requests in each PP micro batch
524522
self.num_micro_batches = max(self.dist.pp_size,
@@ -588,6 +586,9 @@ def __init__(
588586
def on_detected():
589587
logger.error(
590588
f"Hang detected on rank {self.global_rank} in PyExecutor.")
589+
if self._event_loop_error is None:
590+
self._event_loop_error = RuntimeError(
591+
f"Hang detected on rank {self.global_rank} in PyExecutor.")
591592
self.shutdown_event.set()
592593
self.is_shutdown = True
593594

@@ -2740,67 +2741,30 @@ def _prepare_and_schedule_batch(self):
27402741
# scheduler could not allocate KV for any of them, the benchmark
27412742
# will hang forever because in-progress generation requests won't
27422743
# release their KV cache.
2743-
#
2744-
# Only watch during the fill phase: once fill completes the count
2745-
# stays at its target value through the entire decode, which would
2746-
# otherwise look like a stall. With ADP, requests are sharded
2747-
# across TP ranks so the comparison must use the global count
2748-
# (allgather) against the global target.
2749-
if (self.is_benchmark_disagg and self._benchmark_fill_phase_active
2750-
and not self.is_warmup):
2751-
# NOTE: keep the gate condition free of any per-rank state
2752-
# (e.g. `fitting_disagg_gen_init_requests`). The
2753-
# `tp_allgather` below is a collective and every ADP rank
2754-
# must participate together; otherwise ranks desync and a
2755-
# later allgather mixes payload shapes (list[int] from
2756-
# gather_all_rank_states vs int from the gate's
2757-
# _is_benchmark_disagg_fill_complete), producing TypeErrors
2758-
# like "argument after * must be an iterable, not int" or
2759-
# "unsupported operand type(s) for +: 'int' and 'list'".
2760-
# The per-rank "still has fitting requests" hint is folded
2761-
# into the same allgather so we can suppress the stall
2762-
# check globally when any rank is still making progress.
2763-
local_ready_gen = sum(
2764-
1 for req in self.active_requests if req.state in (
2765-
LlmRequestState.DISAGG_GENERATION_TRANS_COMPLETE,
2766-
LlmRequestState.GENERATION_IN_PROGRESS,
2767-
))
2768-
local_has_fitting = 1 if fitting_disagg_gen_init_requests else 0
2769-
if self.enable_attention_dp:
2770-
responses = self.dist.tp_allgather(
2771-
[local_ready_gen, local_has_fitting])
2772-
total_ready_gen = sum(r[0] for r in responses)
2773-
any_rank_has_fitting = any(r[1] for r in responses)
2774-
else:
2775-
total_ready_gen = local_ready_gen
2776-
any_rank_has_fitting = bool(local_has_fitting)
2777-
2778-
if not any_rank_has_fitting:
2779-
now = time.time()
2780-
last_count = getattr(self, "_bench_disagg_last_gen_count",
2781-
None)
2782-
last_change_time = getattr(
2783-
self, "_bench_disagg_last_gen_count_time", None)
2784-
if (last_count != total_ready_gen
2785-
or last_change_time is None):
2786-
self._bench_disagg_last_gen_count = total_ready_gen
2787-
self._bench_disagg_last_gen_count_time = now
2788-
elif (now - last_change_time
2789-
> self.benchmark_fill_stall_timeout_s
2790-
and total_ready_gen < self.benchmark_req_queues_size):
2791-
error_msg = (
2792-
f"Benchmark gen request count stalled at "
2793-
f"{total_ready_gen} "
2794-
f"for {now - last_change_time:.0f}s "
2795-
f"(target {self.benchmark_req_queues_size}, "
2796-
f"fetched={self.num_fetch_requests}). "
2797-
f"Likely causes: KV transfer stuck, KV cache pool "
2798-
f"too small, or transceiver deadlock. Aborting all "
2799-
f"active requests.")
2800-
logger.error(error_msg)
2801-
self._handle_errors(error_msg,
2802-
requests=self.active_requests)
2803-
return None, None
2744+
if (self.benchmark_req_queues_size > 0 and not self.is_warmup
2745+
and not fitting_disagg_gen_init_requests):
2746+
stuck_init_requests = [
2747+
req for req in self.active_requests
2748+
if req.is_disagg_generation_init_state
2749+
]
2750+
# Only fail once all benchmark requests have been fetched
2751+
# so that _handle_errors covers every request and every
2752+
# client receives an error response.
2753+
if (stuck_init_requests and self.num_fetch_requests
2754+
>= self.benchmark_req_queues_size):
2755+
error_msg = (
2756+
f"Insufficient KV cache for gen-only benchmark mode: "
2757+
f"{len(stuck_init_requests)} request(s) are waiting for "
2758+
f"KV cache allocation but the scheduler could not fit "
2759+
f"any of them. Increase free_gpu_memory_fraction or "
2760+
f"reduce TLLM_BENCHMARK_REQ_QUEUES_SIZE (currently "
2761+
f"{self.benchmark_req_queues_size}).")
2762+
logger.error(error_msg)
2763+
# Fail all active and waiting requests so every
2764+
# client receives an error instead of hanging.
2765+
self._handle_errors(error_msg,
2766+
requests=self.active_requests)
2767+
return None, None
28042768

28052769
self.num_scheduled_requests = scheduled_batch.batch_size
28062770
logger.debug(
@@ -4400,27 +4364,26 @@ def _check_disagg_ctx_schedulable_status(self,
44004364
gen_first_ctx_requests)
44014365

44024366
def _count_schedulable_active_requests(self) -> int:
4403-
"""Count active requests eligible for scheduling.
4404-
4405-
Excludes GENERATION_TO_COMPLETE (V2 scheduler skips state
4406-
>= GENERATION_TO_COMPLETE) and, in disaggregated mode, requests
4407-
still awaiting KV cache transfer.
4408-
"""
4367+
"""Count active requests that are ready for scheduling.
44094368
4410-
def _is_to_complete(req) -> bool:
4411-
return req.state == LlmRequestState.GENERATION_TO_COMPLETE
4369+
In non-disaggregated mode, all active requests are schedulable.
4370+
In disaggregated mode, requests still waiting for KV cache
4371+
transfer (in INIT or transmission-in-progress state) are
4372+
excluded because they cannot participate in the forward pass
4373+
until transfer completes.
44124374
4375+
Returns:
4376+
The number of active requests eligible for scheduling.
4377+
"""
44134378
if self.kv_cache_transceiver is None:
4414-
return sum(1 for req in self.active_requests
4415-
if not _is_to_complete(req))
4379+
return len(self.active_requests)
44164380

44174381
def _is_awaiting_kv_transfer(req) -> bool:
44184382
return (req.is_disagg_generation_init_state
44194383
or req.is_disagg_generation_transmission_in_progress)
44204384

4421-
return sum(
4422-
1 for req in self.active_requests
4423-
if not _is_awaiting_kv_transfer(req) and not _is_to_complete(req))
4385+
return sum(1 for req in self.active_requests
4386+
if not _is_awaiting_kv_transfer(req))
44244387

44254388
def _should_skip_dummy_for_benchmark_disagg(
44264389
self, num_schedulable_requests: int) -> bool:
@@ -4456,14 +4419,21 @@ def _should_skip_dummy_for_benchmark_disagg(
44564419
def _update_adp_dummy_role(self, candidates: List[LlmRequest]) -> None:
44574420
if not self.enable_attention_dp or self.kv_cache_transceiver is None:
44584421
return
4422+
has_ctx = False
4423+
has_gen = False
44594424
for req in candidates:
44604425
rt = getattr(req, "llm_request_type", None)
44614426
if rt == LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY:
4462-
self._adp_dummy_is_gen = False
4463-
return
4464-
if rt == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY:
4465-
self._adp_dummy_is_gen = True
4466-
return
4427+
has_ctx = True
4428+
elif rt == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY:
4429+
has_gen = True
4430+
# Prefer the CTX role when both types are present this iteration: a CTX
4431+
# dummy is padded to max_num_tokens so idle ranks keep MoE all-to-all
4432+
# token counts comparable with ranks doing real context work.
4433+
if has_ctx:
4434+
self._adp_dummy_is_gen = False
4435+
elif has_gen:
4436+
self._adp_dummy_is_gen = True
44674437

44684438
@nvtx_range("_pad_attention_dp_dummy_request")
44694439
def _pad_attention_dp_dummy_request(self):
@@ -4482,8 +4452,6 @@ def _pad_attention_dp_dummy_request(self):
44824452
# Other ranks have work but this rank is idle — insert a dummy so
44834453
# it can participate in collective operations during the forward pass.
44844454
if num_active_request == 0 and self.expected_num_active_requests > 0:
4485-
# Pad CTX-type dummies to max_num_tokens so the MoE all-to-all
4486-
# sees a comparable token count across ranks.
44874455
token_nums = None
44884456
if (not self._adp_dummy_is_gen
44894457
and self.kv_cache_transceiver is not None
@@ -5518,17 +5486,20 @@ def _handle_responses(self, emit_first_iter: bool = True):
55185486
self._enqueue_responses(new_responses)
55195487
for request in requests_to_terminate:
55205488
self._terminate_request(request)
5521-
if self.enable_attention_dp and self.dist.world_size != 1:
5489+
if (self.kv_cache_transceiver is not None and self.enable_attention_dp
5490+
and self.dist.world_size != 1):
55225491
any_timed_out = any(self.dist.tp_allgather(
55235492
bool(timed_out_requests)))
55245493
if any_timed_out:
55255494
self._handle_errors(error_msg="Request timed out (KV transfer)",
5526-
requests=timed_out_requests)
5495+
requests=timed_out_requests,
5496+
charge_budget=False)
55275497
else:
55285498
for req in timed_out_requests:
55295499
self._handle_errors(
55305500
error_msg=f"Request {req.py_request_id} timed out",
5531-
requests=[req])
5501+
requests=[req],
5502+
charge_budget=False)
55325503
return requests_to_terminate + requests_finished_by_transfer
55335504

55345505
def _await_any_response(self,

0 commit comments

Comments
 (0)