Skip to content

Commit 86bb5ce

Browse files
committed
[None][fix] Post-rebase fixes + coverage tests for DeepSeek-V4 disaggregation
Squash of the post-cherry-pick work layered on top of the 8 DeepSeek-V4 disaggregation cherry-picks. Fixes: - ADP disagg error path: restore per-request hang signal (_event_loop_error), scan all candidates + prefer CTX role for mixed-batch dummy padding, and keep charge_budget=False on KV-transfer timeouts so they don't exhaust the global error budget and shut down the executor. - transceiver: only short-circuit the tp_allgather skip when pp_size==1 (_ctx_need_pp_sync) -- the PP>1 path asymmetrically flips send/recv markers across pipeline stages and deadlocks the _ctx_consensus pp_allgather. - py_executor: restore main's immediate benchmark fail-fast guard. - resource_manager: do NOT narrow trim_to_history's except (resize() can raise non-ValueError under v2 SWA + uneven-PP; narrowing leaked KV blocks). Tests (added to existing files): - test_py_executor.py: disagg cache-error sync + ADP no-op paths. - test_kv_cache_v2_scheduler.py: trim_to_history. - test_cache_reuse_adapter.py: trim-to-prompt-history + transceiver ctx mgr. - test_router.py: finish_request explicit-session forwarding. - test_agent.py: BindingsNixlTransferStatus + shutdown idempotency (#14137). - transferAgentTest.cpp: status-outlives-agent (weak_ptr UAF safety) + concurrent submitTransferRequests. Signed-off-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
1 parent 326f280 commit 86bb5ce

9 files changed

Lines changed: 556 additions & 75 deletions

File tree

cpp/tests/unit_tests/executor/transferAgentTest.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
#include <gmock/gmock.h>
2222
#include <gtest/gtest.h>
2323

24+
#include <atomic>
2425
#include <filesystem>
26+
#include <thread>
2527
#include <vector>
2628

2729
namespace fs = std::filesystem;
@@ -376,6 +378,107 @@ TEST_P(TransferAgentTest, SyncMessage)
376378
xferAgent1->invalidateRemoteAgent(agent0);
377379
}
378380

381+
// Status must survive destruction of its owning agent (#14137 UAF-safety): the
382+
// status holds a weak_ptr<nixlAgent>; once the agent is reset the weak_ptr expires
383+
// and orphaned queries must report failure rather than dereference a dangling agent.
384+
TEST_P(TransferAgentTest, StatusOutlivesAgent)
385+
{
386+
std::string const agent0{"agent0"}, agent1{"agent1"};
387+
BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
388+
auto xferAgent0 = makeTransferAgent(config0);
389+
auto xferAgent1 = makeTransferAgent(config1);
390+
TLLM_CHECK(xferAgent0);
391+
TLLM_CHECK(xferAgent1);
392+
393+
std::vector<char> memory0(100, 10);
394+
std::vector<char> memory1(100, 1);
395+
396+
// RegisteredHostMemory holds a raw agent pointer and deregisters in its
397+
// dtor, so it must NOT outlive its agent. Scope it (and the transfer) so it
398+
// deregisters while both agents are alive; only `status` is kept past here.
399+
std::unique_ptr<TransferStatus> status;
400+
{
401+
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, xferAgent0.get());
402+
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, xferAgent1.get());
403+
404+
auto connectionInfo = xferAgent1->getLocalConnectionInfo();
405+
xferAgent0->loadRemoteAgent(agent1, connectionInfo);
406+
while (!xferAgent0->checkRemoteDescs(agent1, regMem1.getDescs()))
407+
{
408+
}
409+
410+
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem1.getDescs(), agent1};
411+
status = xferAgent0->submitTransferRequests(writeReq);
412+
TLLM_CHECK(status->wait() == TransferState::kSUCCESS);
413+
}
414+
415+
// Destroy the owning agent BEFORE the status. shutdown() resets the
416+
// shared_ptr<nixlAgent>, expiring the status's weak_ptr.
417+
xferAgent0.reset();
418+
419+
// Orphaned queries are safe and report failure (no use-after-free):
420+
// wait()/isCompleted() see mWeakAgent.lock() == nullptr and return
421+
// kFAILURE/false instead of dereferencing the freed agent.
422+
EXPECT_FALSE(status->isCompleted());
423+
EXPECT_EQ(status->wait(0), TransferState::kFAILURE);
424+
// `status` destructor runs at scope exit: weak_ptr.lock() == nullptr ->
425+
// early return (no releaseXferReq on a dangling agent).
426+
}
427+
428+
// Concurrent submitTransferRequests (#14137): submit holds a std::shared_lock and
429+
// copies reqParams per-request, so many threads can submit at once without racing
430+
// a shared mExtraParams. All concurrently-submitted transfers must still succeed.
431+
TEST_P(TransferAgentTest, ConcurrentSubmit)
432+
{
433+
std::string const agent0{"agent0"}, agent1{"agent1"};
434+
BaseAgentConfig config0{agent0, true, false, true}, config1{agent1, true, false, true};
435+
auto xferAgent0 = makeTransferAgent(config0);
436+
auto xferAgent1 = makeTransferAgent(config1);
437+
TLLM_CHECK(xferAgent0);
438+
TLLM_CHECK(xferAgent1);
439+
440+
std::vector<char> memory0(100, 10);
441+
std::vector<char> memory1(100, 1);
442+
RegisteredHostMemory regMem0(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory0}}}, xferAgent0.get());
443+
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, xferAgent1.get());
444+
445+
auto connectionInfo = xferAgent1->getLocalConnectionInfo();
446+
xferAgent0->loadRemoteAgent(agent1, connectionInfo);
447+
while (!xferAgent0->checkRemoteDescs(agent1, regMem1.getDescs()))
448+
{
449+
}
450+
451+
constexpr int kNumThreads = 8;
452+
std::vector<std::thread> threads;
453+
std::vector<std::unique_ptr<TransferStatus>> statuses(kNumThreads);
454+
std::atomic<int> ready{0};
455+
for (int i = 0; i < kNumThreads; ++i)
456+
{
457+
threads.emplace_back(
458+
[&, i]()
459+
{
460+
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem1.getDescs(), agent1};
461+
ready.fetch_add(1);
462+
while (ready.load() < kNumThreads)
463+
{
464+
// Align thread starts to maximize submit contention.
465+
}
466+
statuses[i] = xferAgent0->submitTransferRequests(writeReq);
467+
});
468+
}
469+
for (auto& t : threads)
470+
{
471+
t.join();
472+
}
473+
for (auto& status : statuses)
474+
{
475+
TLLM_CHECK(status);
476+
EXPECT_EQ(status->wait(), TransferState::kSUCCESS);
477+
}
478+
TLLM_CHECK(memory0 == memory1);
479+
xferAgent0->invalidateRemoteAgent(agent1);
480+
}
481+
379482
INSTANTIATE_TEST_SUITE_P(AvailableBackends, TransferAgentTest, ::testing::ValuesIn(getAvailableBackends()),
380483
[](::testing::TestParamInfo<TransferAgentTest::ParamType> const& info) { return info.param; });
381484

tensorrt_llm/_torch/disaggregation/transceiver.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -517,8 +517,12 @@ def request_and_receive_async(self, req: LlmRequest):
517517
def check_context_transfer_status(
518518
self, at_least_request_num: Optional[int], mark_complete: bool = False
519519
):
520-
# Skip the tp_allgather in _ctx_consensus when this transceiver never sends (pure GEN role).
521-
if not self._ever_had_send_session:
520+
# Skip the consensus collectives when this transceiver never sends (pure GEN role).
521+
# Guarded with pp_size==1 (not _ctx_need_pp_sync): under pipeline parallelism the
522+
# per-rank send marker flips asymmetrically across PP stages, so short-circuiting here
523+
# would let some ranks skip the pp_allgather barrier while peers enter it -> deadlock
524+
# (e.g. ADP+PP tp4_pp2_dp_both). With PP=1 there is no cross-stage consensus barrier.
525+
if not self._ever_had_send_session and not self._ctx_need_pp_sync:
522526
return [], []
523527
block_all = at_least_request_num is None
524528
wait_num = at_least_request_num if not block_all else 0
@@ -573,8 +577,11 @@ def check_context_transfer_status(
573577
return completed, failed
574578

575579
def check_gen_transfer_status(self, at_least_request_num: Optional[int]):
576-
# Skip the allgather in _gen_consensus when this transceiver never receives (pure CTX role).
577-
if not self._ever_had_recv_session:
580+
# Skip the consensus collectives when this transceiver never receives (pure CTX role).
581+
# Guarded with pp_size==1 (not _ctx_need_pp_sync): see check_context_transfer_status --
582+
# under PP the per-rank recv marker flips asymmetrically across stages, so an early
583+
# return would desync the consensus barrier; only short-circuit when PP is absent.
584+
if not self._ever_had_recv_session and not self._ctx_need_pp_sync:
578585
return [], [], []
579586
block_all = at_least_request_num is None
580587
wait_num = at_least_request_num if not block_all else 0

tensorrt_llm/_torch/pyexecutor/kv_cache_manager_v2.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,6 +1020,12 @@ def trim_to_history(self, req: LlmRequest, history_length: int) -> bool:
10201020
try:
10211021
return kv_cache.resize(target_capacity, history_length=history_length)
10221022
except Exception as e:
1023+
# Best-effort SWA trim: resize() can raise more than ValueError
1024+
# under v2 KV-cache + uneven-PP disagg (e.g. internal state
1025+
# assertions). A failed trim MUST degrade gracefully (return
1026+
# False) -- letting the exception propagate aborts KV-block
1027+
# release, leaking storage slots and killing the run. Do not
1028+
# narrow this except.
10231029
logger.warning(
10241030
f"trim_to_history failed for req {req.py_request_id} "
10251031
f"(capacity={kv_cache.capacity}, target_history={history_length}): {e}"

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 48 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -515,8 +515,6 @@ def __init__(
515515
self.num_scheduled_requests: int = 0
516516
self.benchmark_req_queues_size = int(
517517
os.environ.get("TLLM_BENCHMARK_REQ_QUEUES_SIZE", 0))
518-
self.benchmark_fill_stall_timeout_s = float(
519-
os.environ.get("TLLM_BENCHMARK_FILL_STALL_TIMEOUT_S", 60.0))
520518

521519
# list of requests in each PP micro batch
522520
self.num_micro_batches = max(self.dist.pp_size,
@@ -586,6 +584,14 @@ def __init__(
586584
def on_detected():
587585
logger.error(
588586
f"Hang detected on rank {self.global_rank} in PyExecutor.")
587+
# Surface a concrete error to local waiters (e.g.
588+
# _await_single_response) the same way _event_loop_wrapper does,
589+
# without calling _handle_errors here: _handle_errors triggers
590+
# tp_gather/allgather collectives, which are unsafe to run from
591+
# the hang-detector thread while the worker thread is hung.
592+
if self._event_loop_error is None:
593+
self._event_loop_error = RuntimeError(
594+
f"Hang detected on rank {self.global_rank} in PyExecutor.")
589595
self.shutdown_event.set()
590596
self.is_shutdown = True
591597

@@ -2738,67 +2744,30 @@ def _prepare_and_schedule_batch(self):
27382744
# scheduler could not allocate KV for any of them, the benchmark
27392745
# will hang forever because in-progress generation requests won't
27402746
# release their KV cache.
2741-
#
2742-
# Only watch during the fill phase: once fill completes the count
2743-
# stays at its target value through the entire decode, which would
2744-
# otherwise look like a stall. With ADP, requests are sharded
2745-
# across TP ranks so the comparison must use the global count
2746-
# (allgather) against the global target.
2747-
if (self.is_benchmark_disagg and self._benchmark_fill_phase_active
2748-
and not self.is_warmup):
2749-
# NOTE: keep the gate condition free of any per-rank state
2750-
# (e.g. `fitting_disagg_gen_init_requests`). The
2751-
# `tp_allgather` below is a collective and every ADP rank
2752-
# must participate together; otherwise ranks desync and a
2753-
# later allgather mixes payload shapes (list[int] from
2754-
# gather_all_rank_states vs int from the gate's
2755-
# _is_benchmark_disagg_fill_complete), producing TypeErrors
2756-
# like "argument after * must be an iterable, not int" or
2757-
# "unsupported operand type(s) for +: 'int' and 'list'".
2758-
# The per-rank "still has fitting requests" hint is folded
2759-
# into the same allgather so we can suppress the stall
2760-
# check globally when any rank is still making progress.
2761-
local_ready_gen = sum(
2762-
1 for req in self.active_requests if req.state in (
2763-
LlmRequestState.DISAGG_GENERATION_TRANS_COMPLETE,
2764-
LlmRequestState.GENERATION_IN_PROGRESS,
2765-
))
2766-
local_has_fitting = 1 if fitting_disagg_gen_init_requests else 0
2767-
if self.enable_attention_dp:
2768-
responses = self.dist.tp_allgather(
2769-
[local_ready_gen, local_has_fitting])
2770-
total_ready_gen = sum(r[0] for r in responses)
2771-
any_rank_has_fitting = any(r[1] for r in responses)
2772-
else:
2773-
total_ready_gen = local_ready_gen
2774-
any_rank_has_fitting = bool(local_has_fitting)
2775-
2776-
if not any_rank_has_fitting:
2777-
now = time.time()
2778-
last_count = getattr(self, "_bench_disagg_last_gen_count",
2779-
None)
2780-
last_change_time = getattr(
2781-
self, "_bench_disagg_last_gen_count_time", None)
2782-
if (last_count != total_ready_gen
2783-
or last_change_time is None):
2784-
self._bench_disagg_last_gen_count = total_ready_gen
2785-
self._bench_disagg_last_gen_count_time = now
2786-
elif (now - last_change_time
2787-
> self.benchmark_fill_stall_timeout_s
2788-
and total_ready_gen < self.benchmark_req_queues_size):
2789-
error_msg = (
2790-
f"Benchmark gen request count stalled at "
2791-
f"{total_ready_gen} "
2792-
f"for {now - last_change_time:.0f}s "
2793-
f"(target {self.benchmark_req_queues_size}, "
2794-
f"fetched={self.num_fetch_requests}). "
2795-
f"Likely causes: KV transfer stuck, KV cache pool "
2796-
f"too small, or transceiver deadlock. Aborting all "
2797-
f"active requests.")
2798-
logger.error(error_msg)
2799-
self._handle_errors(error_msg,
2800-
requests=self.active_requests)
2801-
return None, None
2747+
if (self.benchmark_req_queues_size > 0 and not self.is_warmup
2748+
and not fitting_disagg_gen_init_requests):
2749+
stuck_init_requests = [
2750+
req for req in self.active_requests
2751+
if req.is_disagg_generation_init_state
2752+
]
2753+
# Only fail once all benchmark requests have been fetched
2754+
# so that _handle_errors covers every request and every
2755+
# client receives an error response.
2756+
if (stuck_init_requests and self.num_fetch_requests
2757+
>= self.benchmark_req_queues_size):
2758+
error_msg = (
2759+
f"Insufficient KV cache for gen-only benchmark mode: "
2760+
f"{len(stuck_init_requests)} request(s) are waiting for "
2761+
f"KV cache allocation but the scheduler could not fit "
2762+
f"any of them. Increase free_gpu_memory_fraction or "
2763+
f"reduce TLLM_BENCHMARK_REQ_QUEUES_SIZE (currently "
2764+
f"{self.benchmark_req_queues_size}).")
2765+
logger.error(error_msg)
2766+
# Fail all active and waiting requests so every
2767+
# client receives an error instead of hanging.
2768+
self._handle_errors(error_msg,
2769+
requests=self.active_requests)
2770+
return None, None
28022771

28032772
self.num_scheduled_requests = scheduled_batch.batch_size
28042773
logger.debug(
@@ -4450,14 +4419,21 @@ def _should_skip_dummy_for_benchmark_disagg(
44504419
def _update_adp_dummy_role(self, candidates: List[LlmRequest]) -> None:
44514420
if not self.enable_attention_dp or self.kv_cache_transceiver is None:
44524421
return
4422+
has_ctx = False
4423+
has_gen = False
44534424
for req in candidates:
44544425
rt = getattr(req, "llm_request_type", None)
44554426
if rt == LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY:
4456-
self._adp_dummy_is_gen = False
4457-
return
4458-
if rt == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY:
4459-
self._adp_dummy_is_gen = True
4460-
return
4427+
has_ctx = True
4428+
elif rt == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY:
4429+
has_gen = True
4430+
# Prefer the CTX role when both types are present this iteration: a CTX
4431+
# dummy is padded to max_num_tokens so idle ranks keep MoE all-to-all
4432+
# token counts comparable with ranks doing real context work.
4433+
if has_ctx:
4434+
self._adp_dummy_is_gen = False
4435+
elif has_gen:
4436+
self._adp_dummy_is_gen = True
44614437

44624438
@nvtx_range("_pad_attention_dp_dummy_request")
44634439
def _pad_attention_dp_dummy_request(self):
@@ -5462,12 +5438,14 @@ def _handle_responses(self, emit_first_iter: bool = True):
54625438
bool(timed_out_requests)))
54635439
if any_timed_out:
54645440
self._handle_errors(error_msg="Request timed out (KV transfer)",
5465-
requests=timed_out_requests)
5441+
requests=timed_out_requests,
5442+
charge_budget=False)
54665443
else:
54675444
for req in timed_out_requests:
54685445
self._handle_errors(
54695446
error_msg=f"Request {req.py_request_id} timed out",
5470-
requests=[req])
5447+
requests=[req],
5448+
charge_budget=False)
54715449
return requests_to_terminate + requests_finished_by_transfer
54725450

54735451
def _await_any_response(self,

tests/unittest/_torch/executor/test_kv_cache_v2_scheduler.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
No GPU required.
1919
"""
2020

21+
from types import SimpleNamespace
2122
from unittest.mock import Mock, patch
2223

2324
import pytest
2425

26+
from tensorrt_llm._torch.pyexecutor.kv_cache_manager_v2 import KVCacheManagerV2
2527
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
2628
from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy
2729

@@ -2419,3 +2421,49 @@ def track_resize(req, n):
24192421
out = sched.schedule_request([req], set())
24202422
assert ids(out.context_requests) == []
24212423
assert resize_calls == [] # SKIP path: no commit to KV cache
2424+
2425+
2426+
# ---------------------------------------------------------------------------
2427+
# KVCacheManagerV2.trim_to_history (#14258 disagg-gen SWA history pre-declaration):
2428+
# called unbound on a fake self to cover all 5 branches without the heavy __init__.
2429+
# ---------------------------------------------------------------------------
2430+
class TestTrimToHistory:
2431+
@staticmethod
2432+
def _call(kv_cache, history_length, req_id=1):
2433+
kv_cache_map = {} if kv_cache is None else {req_id: kv_cache}
2434+
fake = SimpleNamespace(kv_cache_map=kv_cache_map)
2435+
req = SimpleNamespace(py_request_id=req_id)
2436+
return KVCacheManagerV2.trim_to_history(fake, req, history_length)
2437+
2438+
def test_missing_cache_is_noop_true(self):
2439+
assert self._call(None, 50) is True
2440+
2441+
def test_inactive_cache_is_noop_true(self):
2442+
kv = Mock(is_active=False)
2443+
assert self._call(kv, 50) is True
2444+
kv.resize.assert_not_called()
2445+
2446+
def test_history_not_increasing_is_noop_true(self):
2447+
kv = Mock(is_active=True, history_length=50, capacity=100)
2448+
assert self._call(kv, 50) is True # 50 <= current 50
2449+
kv.resize.assert_not_called()
2450+
2451+
def test_resize_success_clamps_capacity_and_returns_true(self):
2452+
kv = Mock(is_active=True, history_length=10, capacity=8)
2453+
kv.resize.return_value = True
2454+
assert self._call(kv, 64) is True
2455+
# target_capacity = max(capacity=8, history=64) = 64
2456+
kv.resize.assert_called_once_with(64, history_length=64)
2457+
2458+
def test_resize_rejection_returns_false(self):
2459+
kv = Mock(is_active=True, history_length=10, capacity=100)
2460+
kv.resize.return_value = False
2461+
assert self._call(kv, 64) is False
2462+
kv.resize.assert_called_once_with(100, history_length=64)
2463+
2464+
def test_resize_exception_degrades_to_false(self):
2465+
# Broad except: a non-ValueError (e.g. internal state assert) must
2466+
# degrade to False rather than propagate (do-not-narrow contract).
2467+
kv = Mock(is_active=True, history_length=10, capacity=100)
2468+
kv.resize.side_effect = RuntimeError("internal state assert")
2469+
assert self._call(kv, 64) is False

0 commit comments

Comments
 (0)