Skip to content

Commit a8cff5d

Browse files
committed
[None][fix] Post-rebase fixes + coverage tests for DeepSeek-V4 disaggregation
Squash of the post-cherry-pick work layered on top of the 8 DeepSeek-V4 disaggregation cherry-picks.

Fixes:
- ADP disagg error path: restore the per-request hang signal (_event_loop_error), scan all candidates and prefer the CTX role for mixed-batch dummy padding, and keep charge_budget=False on KV-transfer timeouts so they don't exhaust the global error budget and shut down the executor.
- transceiver: only take the early return that skips the consensus collectives (tp_allgather) when pp_size==1, i.e. when _ctx_need_pp_sync is false -- the PP>1 path asymmetrically flips send/recv markers across pipeline stages and deadlocks the _ctx_consensus pp_allgather.
- py_executor: restore main's immediate benchmark fail-fast guard.
- resource_manager: do NOT narrow trim_to_history's except (resize() can raise exceptions other than ValueError under v2 SWA + uneven-PP; narrowing leaked KV blocks).

Tests (added to existing files):
- test_py_executor.py: disagg cache-error sync + ADP no-op paths.
- test_kv_cache_v2_scheduler.py: trim_to_history.
- test_cache_reuse_adapter.py: trim-to-prompt-history + transceiver ctx mgr.
- test_router.py: finish_request explicit-session forwarding.
- test_agent.py: BindingsNixlTransferStatus + shutdown idempotency (#14137).
- transferAgentTest.cpp: status-outlives-agent (weak_ptr UAF safety) + concurrent submitTransferRequests.

Signed-off-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
1 parent e190e96 commit a8cff5d

9 files changed

Lines changed: 556 additions & 75 deletions

File tree

cpp/tests/unit_tests/executor/transferAgentTest.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
#include <gmock/gmock.h>
2222
#include <gtest/gtest.h>
2323

24+
#include <atomic>
2425
#include <filesystem>
26+
#include <thread>
2527
#include <vector>
2628

2729
namespace fs = std::filesystem;
@@ -376,6 +378,107 @@ TEST_P(TransferAgentTest, SyncMessage)
376378
xferAgent1->invalidateRemoteAgent(agent0);
377379
}
378380

381+
// Status must survive destruction of its owning agent (#14137 UAF-safety): the
382+
// status holds a weak_ptr<nixlAgent>; once the agent is reset the weak_ptr expires
383+
// and orphaned queries must report failure rather than dereference a dangling agent.
384+
TEST_P(TransferAgentTest, StatusOutlivesAgent)
385+
{
386+
std::string const agent0{"agent0"}, agent1{"agent1"};
387+
BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
388+
auto xferAgent0 = makeTransferAgent(config0);
389+
auto xferAgent1 = makeTransferAgent(config1);
390+
TLLM_CHECK(xferAgent0);
391+
TLLM_CHECK(xferAgent1);
392+
393+
std::vector<char> memory0(100, 10);
394+
std::vector<char> memory1(100, 1);
395+
396+
// RegisteredHostMemory holds a raw agent pointer and deregisters in its
397+
// dtor, so it must NOT outlive its agent. Scope it (and the transfer) so it
398+
// deregisters while both agents are alive; only `status` is kept past here.
399+
std::unique_ptr<TransferStatus> status;
400+
{
401+
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, xferAgent0.get());
402+
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, xferAgent1.get());
403+
404+
auto connectionInfo = xferAgent1->getLocalConnectionInfo();
405+
xferAgent0->loadRemoteAgent(agent1, connectionInfo);
406+
while (!xferAgent0->checkRemoteDescs(agent1, regMem1.getDescs()))
407+
{
408+
}
409+
410+
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem1.getDescs(), agent1};
411+
status = xferAgent0->submitTransferRequests(writeReq);
412+
TLLM_CHECK(status->wait() == TransferState::kSUCCESS);
413+
}
414+
415+
// Destroy the owning agent BEFORE the status. shutdown() resets the
416+
// shared_ptr<nixlAgent>, expiring the status's weak_ptr.
417+
xferAgent0.reset();
418+
419+
// Orphaned queries are safe and report failure (no use-after-free):
420+
// wait()/isCompleted() see mWeakAgent.lock() == nullptr and return
421+
// kFAILURE/false instead of dereferencing the freed agent.
422+
EXPECT_FALSE(status->isCompleted());
423+
EXPECT_EQ(status->wait(0), TransferState::kFAILURE);
424+
// `status` destructor runs at scope exit: weak_ptr.lock() == nullptr ->
425+
// early return (no releaseXferReq on a dangling agent).
426+
}
427+
428+
// Concurrent submitTransferRequests (#14137): submit holds a std::shared_lock and
429+
// copies reqParams per-request, so many threads can submit at once without racing
430+
// a shared mExtraParams. All concurrently-submitted transfers must still succeed.
431+
TEST_P(TransferAgentTest, ConcurrentSubmit)
432+
{
433+
std::string const agent0{"agent0"}, agent1{"agent1"};
434+
BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
435+
auto xferAgent0 = makeTransferAgent(config0);
436+
auto xferAgent1 = makeTransferAgent(config1);
437+
TLLM_CHECK(xferAgent0);
438+
TLLM_CHECK(xferAgent1);
439+
440+
std::vector<char> memory0(100, 10);
441+
std::vector<char> memory1(100, 1);
442+
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, xferAgent0.get());
443+
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, xferAgent1.get());
444+
445+
auto connectionInfo = xferAgent1->getLocalConnectionInfo();
446+
xferAgent0->loadRemoteAgent(agent1, connectionInfo);
447+
while (!xferAgent0->checkRemoteDescs(agent1, regMem1.getDescs()))
448+
{
449+
}
450+
451+
constexpr int kNumThreads = 8;
452+
std::vector<std::thread> threads;
453+
std::vector<std::unique_ptr<TransferStatus>> statuses(kNumThreads);
454+
std::atomic<int> ready{0};
455+
for (int i = 0; i < kNumThreads; ++i)
456+
{
457+
threads.emplace_back(
458+
[&, i]()
459+
{
460+
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem1.getDescs(), agent1};
461+
ready.fetch_add(1);
462+
while (ready.load() < kNumThreads)
463+
{
464+
// Align thread starts to maximize submit contention.
465+
}
466+
statuses[i] = xferAgent0->submitTransferRequests(writeReq);
467+
});
468+
}
469+
for (auto& t : threads)
470+
{
471+
t.join();
472+
}
473+
for (auto& status : statuses)
474+
{
475+
TLLM_CHECK(status);
476+
EXPECT_EQ(status->wait(), TransferState::kSUCCESS);
477+
}
478+
TLLM_CHECK(memory0 == memory1);
479+
xferAgent0->invalidateRemoteAgent(agent1);
480+
}
481+
379482
INSTANTIATE_TEST_SUITE_P(AvailableBackends, TransferAgentTest, ::testing::ValuesIn(getAvailableBackends()),
380483
[](::testing::TestParamInfo<TransferAgentTest::ParamType> const& info) { return info.param; });
381484

tensorrt_llm/_torch/disaggregation/transceiver.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -517,8 +517,12 @@ def request_and_receive_async(self, req: LlmRequest):
517517
def check_context_transfer_status(
518518
self, at_least_request_num: Optional[int], mark_complete: bool = False
519519
):
520-
# Skip the tp_allgather in _ctx_consensus when this transceiver never sends (pure GEN role).
521-
if not self._ever_had_send_session:
520+
# Skip the consensus collectives when this transceiver never sends (pure GEN role).
521+
# Only taken when pp_size == 1 (`self._ctx_need_pp_sync` is False): under pipeline parallelism the
522+
# per-rank send marker flips asymmetrically across PP stages, so short-circuiting here
523+
# would let some ranks skip the pp_allgather barrier while peers enter it -> deadlock
524+
# (e.g. ADP+PP tp4_pp2_dp_both). With PP=1 there is no cross-stage consensus barrier.
525+
if not self._ever_had_send_session and not self._ctx_need_pp_sync:
522526
return [], []
523527
block_all = at_least_request_num is None
524528
wait_num = at_least_request_num if not block_all else 0
@@ -573,8 +577,11 @@ def check_context_transfer_status(
573577
return completed, failed
574578

575579
def check_gen_transfer_status(self, at_least_request_num: Optional[int]):
576-
# Skip the allgather in _gen_consensus when this transceiver never receives (pure CTX role).
577-
if not self._ever_had_recv_session:
580+
# Skip the consensus collectives when this transceiver never receives (pure CTX role).
581+
# Only taken when pp_size == 1 (`self._ctx_need_pp_sync` is False); see check_context_transfer_status --
582+
# under PP the per-rank recv marker flips asymmetrically across stages, so an early
583+
# return would desync the consensus barrier; only short-circuit when PP is absent.
584+
if not self._ever_had_recv_session and not self._ctx_need_pp_sync:
578585
return [], [], []
579586
block_all = at_least_request_num is None
580587
wait_num = at_least_request_num if not block_all else 0

tensorrt_llm/_torch/pyexecutor/kv_cache_manager_v2.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,6 +1019,12 @@ def trim_to_history(self, req: LlmRequest, history_length: int) -> bool:
10191019
try:
10201020
return kv_cache.resize(target_capacity, history_length=history_length)
10211021
except Exception as e:
1022+
# Best-effort SWA trim: resize() can raise more than ValueError
1023+
# under v2 KV-cache + uneven-PP disagg (e.g. internal state
1024+
# assertions). A failed trim MUST degrade gracefully (return
1025+
# False) -- letting the exception propagate aborts KV-block
1026+
# release, leaking storage slots and killing the run. Do not
1027+
# narrow this except.
10221028
logger.warning(
10231029
f"trim_to_history failed for req {req.py_request_id} "
10241030
f"(capacity={kv_cache.capacity}, target_history={history_length}): {e}"

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 48 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -505,8 +505,6 @@ def __init__(
505505
self.num_scheduled_requests: int = 0
506506
self.benchmark_req_queues_size = int(
507507
os.environ.get("TLLM_BENCHMARK_REQ_QUEUES_SIZE", 0))
508-
self.benchmark_fill_stall_timeout_s = float(
509-
os.environ.get("TLLM_BENCHMARK_FILL_STALL_TIMEOUT_S", 60.0))
510508

511509
# list of requests in each PP micro batch
512510
self.num_micro_batches = max(self.dist.pp_size,
@@ -576,6 +574,14 @@ def __init__(
576574
def on_detected():
577575
logger.error(
578576
f"Hang detected on rank {self.global_rank} in PyExecutor.")
577+
# Surface a concrete error to local waiters (e.g.
578+
# _await_single_response) the same way _event_loop_wrapper does,
579+
# without calling _handle_errors here: _handle_errors triggers
580+
# tp_gather/allgather collectives, which are unsafe to run from
581+
# the hang-detector thread while the worker thread is hung.
582+
if self._event_loop_error is None:
583+
self._event_loop_error = RuntimeError(
584+
f"Hang detected on rank {self.global_rank} in PyExecutor.")
579585
self.shutdown_event.set()
580586
self.is_shutdown = True
581587

@@ -2673,67 +2679,30 @@ def _prepare_and_schedule_batch(self):
26732679
# scheduler could not allocate KV for any of them, the benchmark
26742680
# will hang forever because in-progress generation requests won't
26752681
# release their KV cache.
2676-
#
2677-
# Only watch during the fill phase: once fill completes the count
2678-
# stays at its target value through the entire decode, which would
2679-
# otherwise look like a stall. With ADP, requests are sharded
2680-
# across TP ranks so the comparison must use the global count
2681-
# (allgather) against the global target.
2682-
if (self.is_benchmark_disagg and self._benchmark_fill_phase_active
2683-
and not self.is_warmup):
2684-
# NOTE: keep the gate condition free of any per-rank state
2685-
# (e.g. `fitting_disagg_gen_init_requests`). The
2686-
# `tp_allgather` below is a collective and every ADP rank
2687-
# must participate together; otherwise ranks desync and a
2688-
# later allgather mixes payload shapes (list[int] from
2689-
# gather_all_rank_states vs int from the gate's
2690-
# _is_benchmark_disagg_fill_complete), producing TypeErrors
2691-
# like "argument after * must be an iterable, not int" or
2692-
# "unsupported operand type(s) for +: 'int' and 'list'".
2693-
# The per-rank "still has fitting requests" hint is folded
2694-
# into the same allgather so we can suppress the stall
2695-
# check globally when any rank is still making progress.
2696-
local_ready_gen = sum(
2697-
1 for req in self.active_requests if req.state in (
2698-
LlmRequestState.DISAGG_GENERATION_TRANS_COMPLETE,
2699-
LlmRequestState.GENERATION_IN_PROGRESS,
2700-
))
2701-
local_has_fitting = 1 if fitting_disagg_gen_init_requests else 0
2702-
if self.enable_attention_dp:
2703-
responses = self.dist.tp_allgather(
2704-
[local_ready_gen, local_has_fitting])
2705-
total_ready_gen = sum(r[0] for r in responses)
2706-
any_rank_has_fitting = any(r[1] for r in responses)
2707-
else:
2708-
total_ready_gen = local_ready_gen
2709-
any_rank_has_fitting = bool(local_has_fitting)
2710-
2711-
if not any_rank_has_fitting:
2712-
now = time.time()
2713-
last_count = getattr(self, "_bench_disagg_last_gen_count",
2714-
None)
2715-
last_change_time = getattr(
2716-
self, "_bench_disagg_last_gen_count_time", None)
2717-
if (last_count != total_ready_gen
2718-
or last_change_time is None):
2719-
self._bench_disagg_last_gen_count = total_ready_gen
2720-
self._bench_disagg_last_gen_count_time = now
2721-
elif (now - last_change_time
2722-
> self.benchmark_fill_stall_timeout_s
2723-
and total_ready_gen < self.benchmark_req_queues_size):
2724-
error_msg = (
2725-
f"Benchmark gen request count stalled at "
2726-
f"{total_ready_gen} "
2727-
f"for {now - last_change_time:.0f}s "
2728-
f"(target {self.benchmark_req_queues_size}, "
2729-
f"fetched={self.num_fetch_requests}). "
2730-
f"Likely causes: KV transfer stuck, KV cache pool "
2731-
f"too small, or transceiver deadlock. Aborting all "
2732-
f"active requests.")
2733-
logger.error(error_msg)
2734-
self._handle_errors(error_msg,
2735-
requests=self.active_requests)
2736-
return None, None
2682+
if (self.benchmark_req_queues_size > 0 and not self.is_warmup
2683+
and not fitting_disagg_gen_init_requests):
2684+
stuck_init_requests = [
2685+
req for req in self.active_requests
2686+
if req.is_disagg_generation_init_state
2687+
]
2688+
# Only fail once all benchmark requests have been fetched
2689+
# so that _handle_errors covers every request and every
2690+
# client receives an error response.
2691+
if (stuck_init_requests and self.num_fetch_requests
2692+
>= self.benchmark_req_queues_size):
2693+
error_msg = (
2694+
f"Insufficient KV cache for gen-only benchmark mode: "
2695+
f"{len(stuck_init_requests)} request(s) are waiting for "
2696+
f"KV cache allocation but the scheduler could not fit "
2697+
f"any of them. Increase free_gpu_memory_fraction or "
2698+
f"reduce TLLM_BENCHMARK_REQ_QUEUES_SIZE (currently "
2699+
f"{self.benchmark_req_queues_size}).")
2700+
logger.error(error_msg)
2701+
# Fail all active and waiting requests so every
2702+
# client receives an error instead of hanging.
2703+
self._handle_errors(error_msg,
2704+
requests=self.active_requests)
2705+
return None, None
27372706

27382707
self.num_scheduled_requests = scheduled_batch.batch_size
27392708
logger.debug(
@@ -4233,14 +4202,21 @@ def _should_skip_dummy_for_benchmark_disagg(
42334202
def _update_adp_dummy_role(self, candidates: List[LlmRequest]) -> None:
42344203
if not self.enable_attention_dp or self.kv_cache_transceiver is None:
42354204
return
4205+
has_ctx = False
4206+
has_gen = False
42364207
for req in candidates:
42374208
rt = getattr(req, "llm_request_type", None)
42384209
if rt == LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY:
4239-
self._adp_dummy_is_gen = False
4240-
return
4241-
if rt == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY:
4242-
self._adp_dummy_is_gen = True
4243-
return
4210+
has_ctx = True
4211+
elif rt == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY:
4212+
has_gen = True
4213+
# Prefer the CTX role when both types are present this iteration: a CTX
4214+
# dummy is padded to max_num_tokens so idle ranks keep MoE all-to-all
4215+
# token counts comparable with ranks doing real context work.
4216+
if has_ctx:
4217+
self._adp_dummy_is_gen = False
4218+
elif has_gen:
4219+
self._adp_dummy_is_gen = True
42444220

42454221
@nvtx_range("_pad_attention_dp_dummy_request")
42464222
def _pad_attention_dp_dummy_request(self):
@@ -5231,12 +5207,14 @@ def _handle_responses(self, emit_first_iter: bool = True):
52315207
bool(timed_out_requests)))
52325208
if any_timed_out:
52335209
self._handle_errors(error_msg="Request timed out (KV transfer)",
5234-
requests=timed_out_requests)
5210+
requests=timed_out_requests,
5211+
charge_budget=False)
52355212
else:
52365213
for req in timed_out_requests:
52375214
self._handle_errors(
52385215
error_msg=f"Request {req.py_request_id} timed out",
5239-
requests=[req])
5216+
requests=[req],
5217+
charge_budget=False)
52405218
return requests_to_terminate + requests_finished_by_transfer
52415219

52425220
def _await_any_response(self,

tests/unittest/_torch/executor/test_kv_cache_v2_scheduler.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
No GPU required.
1919
"""
2020

21+
from types import SimpleNamespace
2122
from unittest.mock import Mock, patch
2223

2324
import pytest
2425

26+
from tensorrt_llm._torch.pyexecutor.kv_cache_manager_v2 import KVCacheManagerV2
2527
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
2628
from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy
2729

@@ -2310,3 +2312,49 @@ def track_resize(req, n):
23102312
out = sched.schedule_request([req], set())
23112313
assert ids(out.context_requests) == []
23122314
assert resize_calls == [] # SKIP path: no commit to KV cache
2315+
2316+
2317+
# ---------------------------------------------------------------------------
2318+
# KVCacheManagerV2.trim_to_history (#14258 disagg-gen SWA history pre-declaration):
2319+
# called unbound on a fake self to cover all 5 branches without the heavy __init__.
2320+
# ---------------------------------------------------------------------------
2321+
class TestTrimToHistory:
2322+
@staticmethod
2323+
def _call(kv_cache, history_length, req_id=1):
2324+
kv_cache_map = {} if kv_cache is None else {req_id: kv_cache}
2325+
fake = SimpleNamespace(kv_cache_map=kv_cache_map)
2326+
req = SimpleNamespace(py_request_id=req_id)
2327+
return KVCacheManagerV2.trim_to_history(fake, req, history_length)
2328+
2329+
def test_missing_cache_is_noop_true(self):
2330+
assert self._call(None, 50) is True
2331+
2332+
def test_inactive_cache_is_noop_true(self):
2333+
kv = Mock(is_active=False)
2334+
assert self._call(kv, 50) is True
2335+
kv.resize.assert_not_called()
2336+
2337+
def test_history_not_increasing_is_noop_true(self):
2338+
kv = Mock(is_active=True, history_length=50, capacity=100)
2339+
assert self._call(kv, 50) is True # 50 <= current 50
2340+
kv.resize.assert_not_called()
2341+
2342+
def test_resize_success_clamps_capacity_and_returns_true(self):
2343+
kv = Mock(is_active=True, history_length=10, capacity=8)
2344+
kv.resize.return_value = True
2345+
assert self._call(kv, 64) is True
2346+
# target_capacity = max(capacity=8, history=64) = 64
2347+
kv.resize.assert_called_once_with(64, history_length=64)
2348+
2349+
def test_resize_rejection_returns_false(self):
2350+
kv = Mock(is_active=True, history_length=10, capacity=100)
2351+
kv.resize.return_value = False
2352+
assert self._call(kv, 64) is False
2353+
kv.resize.assert_called_once_with(100, history_length=64)
2354+
2355+
def test_resize_exception_degrades_to_false(self):
2356+
# Broad except: a non-ValueError (e.g. internal state assert) must
2357+
# degrade to False rather than propagate (do-not-narrow contract).
2358+
kv = Mock(is_active=True, history_length=10, capacity=100)
2359+
kv.resize.side_effect = RuntimeError("internal state assert")
2360+
assert self._call(kv, 64) is False

0 commit comments

Comments
 (0)