Skip to content

Commit cdef3d0

Browse files
committed
[None][fix] Post-rebase fixes + coverage tests for DeepSeek-V4 disaggregation
Squash of the post-cherry-pick work layered on top of the 8 DeepSeek-V4 disaggregation cherry-picks. Fixes: - ADP disagg error path: restore per-request hang signal (_event_loop_error), scan all candidates + prefer CTX role for mixed-batch dummy padding, and keep charge_budget=False on KV-transfer timeouts so they don't exhaust the global error budget and shut down the executor. - _count_schedulable_active_requests: gate the GENERATION_TO_COMPLETE exclusion on the V2 KV-cache manager. Only the V2 scheduler skips state >= GENERATION_TO_COMPLETE; the V1 scheduler still forwards those requests, so excluding them under V1 ADP undercounted and spuriously inserted an ADP dummy on top of a real request -- overflowing a small batch and tripping the mamba dummy-mask assert (n <= _dummy_request_mask_host.shape[0]) / "No free slots". Fixes test_ptp_quickstart_advanced_deepseek_v3_lite_4gpus_adp_balance. - transceiver: only short-circuit the tp_allgather skip when pp_size==1 (_ctx_need_pp_sync) -- the PP>1 path asymmetrically flips send/recv markers across pipeline stages and deadlocks the _ctx_consensus pp_allgather. - py_executor: restore main's immediate benchmark fail-fast guard. - resource_manager: do NOT narrow trim_to_history's except (resize() can raise non-ValueError under v2 SWA + uneven-PP; narrowing leaked KV blocks). Tests (added to existing files): - test_py_executor.py: disagg cache-error sync + ADP no-op paths; ADP dummy-role and _pad_attention_dp_dummy_request V1/V2 GENERATION_TO_COMPLETE behavior (adp_balance regression). - test_kv_cache_v2_scheduler.py: trim_to_history. - test_cache_reuse_adapter.py: trim-to-prompt-history + transceiver ctx mgr. - test_router.py: finish_request explicit-session forwarding. - test_agent.py: BindingsNixlTransferStatus + shutdown idempotency (#14137). - transferAgentTest.cpp: status-outlives-agent (weak_ptr UAF safety) + concurrent submitTransferRequests. Signed-off-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
1 parent 8b5e0d4 commit cdef3d0

9 files changed

Lines changed: 593 additions & 83 deletions

File tree

cpp/tests/unit_tests/executor/transferAgentTest.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
#include <gmock/gmock.h>
2222
#include <gtest/gtest.h>
2323

24+
#include <atomic>
2425
#include <filesystem>
26+
#include <thread>
2527
#include <vector>
2628

2729
namespace fs = std::filesystem;
@@ -376,6 +378,107 @@ TEST_P(TransferAgentTest, SyncMessage)
376378
xferAgent1->invalidateRemoteAgent(agent0);
377379
}
378380

381+
// Status must survive destruction of its owning agent (#14137 UAF-safety): the
382+
// status holds a weak_ptr<nixlAgent>; once the agent is reset the weak_ptr expires
383+
// and orphaned queries must report failure rather than dereference a dangling agent.
384+
TEST_P(TransferAgentTest, StatusOutlivesAgent)
385+
{
386+
std::string const agent0{"agent0"}, agent1{"agent1"};
387+
BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
388+
auto xferAgent0 = makeTransferAgent(config0);
389+
auto xferAgent1 = makeTransferAgent(config1);
390+
TLLM_CHECK(xferAgent0);
391+
TLLM_CHECK(xferAgent1);
392+
393+
std::vector<char> memory0(100, 10);
394+
std::vector<char> memory1(100, 1);
395+
396+
// RegisteredHostMemory holds a raw agent pointer and deregisters in its
397+
// dtor, so it must NOT outlive its agent. Scope it (and the transfer) so it
398+
// deregisters while both agents are alive; only `status` is kept past here.
399+
std::unique_ptr<TransferStatus> status;
400+
{
401+
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, xferAgent0.get());
402+
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, xferAgent1.get());
403+
404+
auto connectionInfo = xferAgent1->getLocalConnectionInfo();
405+
xferAgent0->loadRemoteAgent(agent1, connectionInfo);
406+
while (!xferAgent0->checkRemoteDescs(agent1, regMem1.getDescs()))
407+
{
408+
}
409+
410+
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem1.getDescs(), agent1};
411+
status = xferAgent0->submitTransferRequests(writeReq);
412+
TLLM_CHECK(status->wait() == TransferState::kSUCCESS);
413+
}
414+
415+
// Destroy the owning agent BEFORE the status. shutdown() resets the
416+
// shared_ptr<nixlAgent>, expiring the status's weak_ptr.
417+
xferAgent0.reset();
418+
419+
// Orphaned queries are safe and report failure (no use-after-free):
420+
// wait()/isCompleted() see mWeakAgent.lock() == nullptr and return
421+
// kFAILURE/false instead of dereferencing the freed agent.
422+
EXPECT_FALSE(status->isCompleted());
423+
EXPECT_EQ(status->wait(0), TransferState::kFAILURE);
424+
// `status` destructor runs at scope exit: weak_ptr.lock() == nullptr ->
425+
// early return (no releaseXferReq on a dangling agent).
426+
}
427+
428+
// Concurrent submitTransferRequests (#14137): submit holds a std::shared_lock and
429+
// copies reqParams per-request, so many threads can submit at once without racing
430+
// a shared mExtraParams. All concurrently-submitted transfers must still succeed.
431+
TEST_P(TransferAgentTest, ConcurrentSubmit)
432+
{
433+
std::string const agent0{"agent0"}, agent1{"agent1"};
434+
BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
435+
auto xferAgent0 = makeTransferAgent(config0);
436+
auto xferAgent1 = makeTransferAgent(config1);
437+
TLLM_CHECK(xferAgent0);
438+
TLLM_CHECK(xferAgent1);
439+
440+
std::vector<char> memory0(100, 10);
441+
std::vector<char> memory1(100, 1);
442+
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, xferAgent0.get());
443+
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, xferAgent1.get());
444+
445+
auto connectionInfo = xferAgent1->getLocalConnectionInfo();
446+
xferAgent0->loadRemoteAgent(agent1, connectionInfo);
447+
while (!xferAgent0->checkRemoteDescs(agent1, regMem1.getDescs()))
448+
{
449+
}
450+
451+
constexpr int kNumThreads = 8;
452+
std::vector<std::thread> threads;
453+
std::vector<std::unique_ptr<TransferStatus>> statuses(kNumThreads);
454+
std::atomic<int> ready{0};
455+
for (int i = 0; i < kNumThreads; ++i)
456+
{
457+
threads.emplace_back(
458+
[&, i]()
459+
{
460+
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem1.getDescs(), agent1};
461+
ready.fetch_add(1);
462+
while (ready.load() < kNumThreads)
463+
{
464+
// Align thread starts to maximize submit contention.
465+
}
466+
statuses[i] = xferAgent0->submitTransferRequests(writeReq);
467+
});
468+
}
469+
for (auto& t : threads)
470+
{
471+
t.join();
472+
}
473+
for (auto& status : statuses)
474+
{
475+
TLLM_CHECK(status);
476+
EXPECT_EQ(status->wait(), TransferState::kSUCCESS);
477+
}
478+
TLLM_CHECK(memory0 == memory1);
479+
xferAgent0->invalidateRemoteAgent(agent1);
480+
}
481+
379482
INSTANTIATE_TEST_SUITE_P(AvailableBackends, TransferAgentTest, ::testing::ValuesIn(getAvailableBackends()),
380483
[](::testing::TestParamInfo<TransferAgentTest::ParamType> const& info) { return info.param; });
381484

tensorrt_llm/_torch/disaggregation/transceiver.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -517,8 +517,12 @@ def request_and_receive_async(self, req: LlmRequest):
517517
def check_context_transfer_status(
518518
self, at_least_request_num: Optional[int], mark_complete: bool = False
519519
):
520-
# Skip the tp_allgather in _ctx_consensus when this transceiver never sends (pure GEN role).
521-
if not self._ever_had_send_session:
520+
# Skip the consensus collectives when this transceiver never sends (pure GEN role).
521+
# Guarded with pp_size==1 (not _ctx_need_pp_sync): under pipeline parallelism the
522+
# per-rank send marker flips asymmetrically across PP stages, so short-circuiting here
523+
# would let some ranks skip the pp_allgather barrier while peers enter it -> deadlock
524+
# (e.g. ADP+PP tp4_pp2_dp_both). With PP=1 there is no cross-stage consensus barrier.
525+
if not self._ever_had_send_session and not self._ctx_need_pp_sync:
522526
return [], []
523527
block_all = at_least_request_num is None
524528
wait_num = at_least_request_num if not block_all else 0
@@ -573,8 +577,11 @@ def check_context_transfer_status(
573577
return completed, failed
574578

575579
def check_gen_transfer_status(self, at_least_request_num: Optional[int]):
576-
# Skip the allgather in _gen_consensus when this transceiver never receives (pure CTX role).
577-
if not self._ever_had_recv_session:
580+
# Skip the consensus collectives when this transceiver never receives (pure CTX role).
581+
# Guarded with pp_size==1 (not _ctx_need_pp_sync): see check_context_transfer_status --
582+
# under PP the per-rank recv marker flips asymmetrically across stages, so an early
583+
# return would desync the consensus barrier; only short-circuit when PP is absent.
584+
if not self._ever_had_recv_session and not self._ctx_need_pp_sync:
578585
return [], [], []
579586
block_all = at_least_request_num is None
580587
wait_num = at_least_request_num if not block_all else 0

tensorrt_llm/_torch/pyexecutor/kv_cache_manager_v2.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,6 +1171,12 @@ def trim_to_history(self, req: LlmRequest, history_length: int) -> bool:
11711171
try:
11721172
return kv_cache.resize(target_capacity, history_length=history_length)
11731173
except Exception as e:
1174+
# Best-effort SWA trim: resize() can raise more than ValueError
1175+
# under v2 KV-cache + uneven-PP disagg (e.g. internal state
1176+
# assertions). A failed trim MUST degrade gracefully (return
1177+
# False) -- letting the exception propagate aborts KV-block
1178+
# release, leaking storage slots and killing the run. Do not
1179+
# narrow this except.
11741180
logger.warning(
11751181
f"trim_to_history failed for req {req.py_request_id} "
11761182
f"(capacity={kv_cache.capacity}, target_history={history_length}): {e}"

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 56 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -517,8 +517,6 @@ def __init__(
517517
self.num_scheduled_requests: int = 0
518518
self.benchmark_req_queues_size = int(
519519
os.environ.get("TLLM_BENCHMARK_REQ_QUEUES_SIZE", 0))
520-
self.benchmark_fill_stall_timeout_s = float(
521-
os.environ.get("TLLM_BENCHMARK_FILL_STALL_TIMEOUT_S", 60.0))
522520

523521
# list of requests in each PP micro batch
524522
self.num_micro_batches = max(self.dist.pp_size,
@@ -588,6 +586,14 @@ def __init__(
588586
def on_detected():
589587
logger.error(
590588
f"Hang detected on rank {self.global_rank} in PyExecutor.")
589+
# Surface a concrete error to local waiters (e.g.
590+
# _await_single_response) the same way _event_loop_wrapper does,
591+
# without calling _handle_errors here: _handle_errors triggers
592+
# tp_gather/allgather collectives, which are unsafe to run from
593+
# the hang-detector thread while the worker thread is hung.
594+
if self._event_loop_error is None:
595+
self._event_loop_error = RuntimeError(
596+
f"Hang detected on rank {self.global_rank} in PyExecutor.")
591597
self.shutdown_event.set()
592598
self.is_shutdown = True
593599

@@ -2740,67 +2746,30 @@ def _prepare_and_schedule_batch(self):
27402746
# scheduler could not allocate KV for any of them, the benchmark
27412747
# will hang forever because in-progress generation requests won't
27422748
# release their KV cache.
2743-
#
2744-
# Only watch during the fill phase: once fill completes the count
2745-
# stays at its target value through the entire decode, which would
2746-
# otherwise look like a stall. With ADP, requests are sharded
2747-
# across TP ranks so the comparison must use the global count
2748-
# (allgather) against the global target.
2749-
if (self.is_benchmark_disagg and self._benchmark_fill_phase_active
2750-
and not self.is_warmup):
2751-
# NOTE: keep the gate condition free of any per-rank state
2752-
# (e.g. `fitting_disagg_gen_init_requests`). The
2753-
# `tp_allgather` below is a collective and every ADP rank
2754-
# must participate together; otherwise ranks desync and a
2755-
# later allgather mixes payload shapes (list[int] from
2756-
# gather_all_rank_states vs int from the gate's
2757-
# _is_benchmark_disagg_fill_complete), producing TypeErrors
2758-
# like "argument after * must be an iterable, not int" or
2759-
# "unsupported operand type(s) for +: 'int' and 'list'".
2760-
# The per-rank "still has fitting requests" hint is folded
2761-
# into the same allgather so we can suppress the stall
2762-
# check globally when any rank is still making progress.
2763-
local_ready_gen = sum(
2764-
1 for req in self.active_requests if req.state in (
2765-
LlmRequestState.DISAGG_GENERATION_TRANS_COMPLETE,
2766-
LlmRequestState.GENERATION_IN_PROGRESS,
2767-
))
2768-
local_has_fitting = 1 if fitting_disagg_gen_init_requests else 0
2769-
if self.enable_attention_dp:
2770-
responses = self.dist.tp_allgather(
2771-
[local_ready_gen, local_has_fitting])
2772-
total_ready_gen = sum(r[0] for r in responses)
2773-
any_rank_has_fitting = any(r[1] for r in responses)
2774-
else:
2775-
total_ready_gen = local_ready_gen
2776-
any_rank_has_fitting = bool(local_has_fitting)
2777-
2778-
if not any_rank_has_fitting:
2779-
now = time.time()
2780-
last_count = getattr(self, "_bench_disagg_last_gen_count",
2781-
None)
2782-
last_change_time = getattr(
2783-
self, "_bench_disagg_last_gen_count_time", None)
2784-
if (last_count != total_ready_gen
2785-
or last_change_time is None):
2786-
self._bench_disagg_last_gen_count = total_ready_gen
2787-
self._bench_disagg_last_gen_count_time = now
2788-
elif (now - last_change_time
2789-
> self.benchmark_fill_stall_timeout_s
2790-
and total_ready_gen < self.benchmark_req_queues_size):
2791-
error_msg = (
2792-
f"Benchmark gen request count stalled at "
2793-
f"{total_ready_gen} "
2794-
f"for {now - last_change_time:.0f}s "
2795-
f"(target {self.benchmark_req_queues_size}, "
2796-
f"fetched={self.num_fetch_requests}). "
2797-
f"Likely causes: KV transfer stuck, KV cache pool "
2798-
f"too small, or transceiver deadlock. Aborting all "
2799-
f"active requests.")
2800-
logger.error(error_msg)
2801-
self._handle_errors(error_msg,
2802-
requests=self.active_requests)
2803-
return None, None
2749+
if (self.benchmark_req_queues_size > 0 and not self.is_warmup
2750+
and not fitting_disagg_gen_init_requests):
2751+
stuck_init_requests = [
2752+
req for req in self.active_requests
2753+
if req.is_disagg_generation_init_state
2754+
]
2755+
# Only fail once all benchmark requests have been fetched
2756+
# so that _handle_errors covers every request and every
2757+
# client receives an error response.
2758+
if (stuck_init_requests and self.num_fetch_requests
2759+
>= self.benchmark_req_queues_size):
2760+
error_msg = (
2761+
f"Insufficient KV cache for gen-only benchmark mode: "
2762+
f"{len(stuck_init_requests)} request(s) are waiting for "
2763+
f"KV cache allocation but the scheduler could not fit "
2764+
f"any of them. Increase free_gpu_memory_fraction or "
2765+
f"reduce TLLM_BENCHMARK_REQ_QUEUES_SIZE (currently "
2766+
f"{self.benchmark_req_queues_size}).")
2767+
logger.error(error_msg)
2768+
# Fail all active and waiting requests so every
2769+
# client receives an error instead of hanging.
2770+
self._handle_errors(error_msg,
2771+
requests=self.active_requests)
2772+
return None, None
28042773

28052774
self.num_scheduled_requests = scheduled_batch.batch_size
28062775
logger.debug(
@@ -4402,13 +4371,17 @@ def _check_disagg_ctx_schedulable_status(self,
44024371
def _count_schedulable_active_requests(self) -> int:
44034372
"""Count active requests eligible for scheduling.
44044373
4405-
Excludes GENERATION_TO_COMPLETE (V2 scheduler skips state
4406-
>= GENERATION_TO_COMPLETE) and, in disaggregated mode, requests
4407-
still awaiting KV cache transfer.
4374+
Excludes GENERATION_TO_COMPLETE only under the V2 KV-cache manager,
4375+
whose scheduler skips state >= GENERATION_TO_COMPLETE. The V1
4376+
scheduler still forwards those requests, so excluding them there
4377+
would undercount and spuriously insert an ADP dummy on top of a real
4378+
request -- overflowing a small batch (e.g. max_batch_size=1). In
4379+
disaggregated mode, also exclude requests still awaiting KV transfer.
44084380
"""
44094381

44104382
def _is_to_complete(req) -> bool:
4411-
return req.state == LlmRequestState.GENERATION_TO_COMPLETE
4383+
return (self._is_kv_manager_v2
4384+
and req.state == LlmRequestState.GENERATION_TO_COMPLETE)
44124385

44134386
if self.kv_cache_transceiver is None:
44144387
return sum(1 for req in self.active_requests
@@ -4456,14 +4429,21 @@ def _should_skip_dummy_for_benchmark_disagg(
44564429
def _update_adp_dummy_role(self, candidates: List[LlmRequest]) -> None:
44574430
if not self.enable_attention_dp or self.kv_cache_transceiver is None:
44584431
return
4432+
has_ctx = False
4433+
has_gen = False
44594434
for req in candidates:
44604435
rt = getattr(req, "llm_request_type", None)
44614436
if rt == LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY:
4462-
self._adp_dummy_is_gen = False
4463-
return
4464-
if rt == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY:
4465-
self._adp_dummy_is_gen = True
4466-
return
4437+
has_ctx = True
4438+
elif rt == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY:
4439+
has_gen = True
4440+
# Prefer the CTX role when both types are present this iteration: a CTX
4441+
# dummy is padded to max_num_tokens so idle ranks keep MoE all-to-all
4442+
# token counts comparable with ranks doing real context work.
4443+
if has_ctx:
4444+
self._adp_dummy_is_gen = False
4445+
elif has_gen:
4446+
self._adp_dummy_is_gen = True
44674447

44684448
@nvtx_range("_pad_attention_dp_dummy_request")
44694449
def _pad_attention_dp_dummy_request(self):
@@ -5523,12 +5503,14 @@ def _handle_responses(self, emit_first_iter: bool = True):
55235503
bool(timed_out_requests)))
55245504
if any_timed_out:
55255505
self._handle_errors(error_msg="Request timed out (KV transfer)",
5526-
requests=timed_out_requests)
5506+
requests=timed_out_requests,
5507+
charge_budget=False)
55275508
else:
55285509
for req in timed_out_requests:
55295510
self._handle_errors(
55305511
error_msg=f"Request {req.py_request_id} timed out",
5531-
requests=[req])
5512+
requests=[req],
5513+
charge_budget=False)
55325514
return requests_to_terminate + requests_finished_by_transfer
55335515

55345516
def _await_any_response(self,

0 commit comments

Comments
 (0)