Skip to content

Commit 450122e

Browse files
authored
[None][fix] Fix Mamba cache correctness under MTP + CUDA-graph padding (#13151)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
1 parent 734a146 commit 450122e

9 files changed

Lines changed: 320 additions & 127 deletions

File tree

cpp/tensorrt_llm/batch_manager/rnnStateManager.cpp

Lines changed: 6 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,6 @@
2020
#include "tensorrt_llm/runtime/cudaStream.h"
2121
#include "tensorrt_llm/runtime/utils/runtimeUtils.h"
2222

23-
#include <unordered_set>
24-
2523
using namespace tensorrt_llm::runtime;
2624

2725
namespace tensorrt_llm::batch_manager::rnn_state_manager
@@ -258,40 +256,16 @@ std::vector<RnnStateManager::SizeType32> RnnStateManager::getStateIndices(
258256
std::vector<RequestIdType> const& requestIds, std::vector<bool> const& isPadding)
259257
{
260258
TLLM_CHECK_WITH_INFO(requestIds.size() == isPadding.size(), "requestIds and isPadding must have the same size");
261-
262-
std::unordered_set<SizeType32> availableSlots;
263-
availableSlots.reserve(mMaxNumSequences);
264-
for (SizeType32 i = 0; i < mMaxNumSequences; ++i)
265-
{
266-
availableSlots.insert(i);
267-
}
268-
269-
for (size_t i = 0; i < requestIds.size(); ++i)
270-
{
271-
if (!isPadding[i])
272-
{
273-
availableSlots.erase(getCacheIndex(requestIds[i]));
274-
}
275-
}
276-
259+
// Every id (real or CUDA-graph padding sentinel) has a permanent slot
260+
// allocated by allocateCacheBlocks; padding entries all share their
261+
// sentinel's slot, so they never alias a live request and never
262+
// consume free-pool slots.
277263
std::vector<SizeType32> result;
278264
result.reserve(requestIds.size());
279-
auto availableIt = availableSlots.begin();
280-
281-
for (size_t i = 0; i < requestIds.size(); ++i)
265+
for (auto const& rid : requestIds)
282266
{
283-
if (isPadding[i])
284-
{
285-
TLLM_CHECK_WITH_INFO(availableIt != availableSlots.end(), "Run out of available slots for padding");
286-
result.push_back(*availableIt);
287-
++availableIt;
288-
}
289-
else
290-
{
291-
result.push_back(getCacheIndex(requestIds[i]));
292-
}
267+
result.push_back(getCacheIndex(rid));
293268
}
294-
295269
return result;
296270
}
297271

cpp/tensorrt_llm/thop/mamba2MTPSSMCacheOp.cpp

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,12 @@ void mamba2_mtp_ssm_cache_update(th::Tensor ssm, th::Tensor x, th::Tensor dt, th
5555
int const head_dim = ssm.size(2);
5656
int const ssm_dim = ssm.size(3);
5757

58-
TORCH_CHECK(intermediate_states.dim() == 5 && intermediate_states.size(0) == ssm.size(0)
58+
// ssm.size(0) is the Mamba cache capacity — independent of the
59+
// current batch (may include parked requests or reserved dummy
60+
// slots). intermediate_states is per-step scratch indexed by
61+
// intermediate_states_indices in [0, bs), so it only needs to
62+
// fit the forward batch.
63+
TORCH_CHECK(intermediate_states.dim() == 5 && intermediate_states.size(0) >= bs
5964
&& intermediate_states.size(1) == cache_steps && intermediate_states.size(2) == nheads
6065
&& intermediate_states.size(3) == head_dim && intermediate_states.size(4) == ssm_dim,
6166
"intermediate_states shape check failed");

tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1025,10 +1025,12 @@ def _forward_with_kv_cache(self, csi: CachedSequenceInterface):
10251025
kv_cache_manager = csi.kv_cache_manager
10261026
if num_extend > 0 and isinstance(kv_cache_manager, MambaHybridCacheManager):
10271027
if kv_cache_manager.is_speculative():
1028+
state_indices = csi.get_arg("slot_idx", truncate=True)
10281029
_ctx = SimpleNamespace(num_seqs=num_sequences, num_contexts=num_prefill)
10291030
kv_cache_manager.update_mamba_states(
10301031
attn_metadata=_ctx,
10311032
num_accepted_tokens=new_tokens_lens,
1033+
state_indices=state_indices,
10321034
)
10331035

10341036
# compute the cache and position offset based on the number of new tokens compared to the

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 0 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -305,15 +305,6 @@ def _generate_dummy_request(
305305
dummy_request = kv_cache_manager.add_dummy_requests([request_id], **request_kwargs)[0]
306306
dummy_request.is_cuda_graph_dummy = True
307307

308-
# generate a dummy scheduled requests object
309-
dummy_scheduled_requests = ScheduledRequests()
310-
dummy_scheduled_requests.generation_requests.append(dummy_request)
311-
312-
# if it's a hybrid kv-cache manager, we need to manually call prepare_resources again (not done
313-
# in add_dummy_requests)
314-
if is_hybrid_cache:
315-
kv_cache_manager.prepare_resources(dummy_scheduled_requests)
316-
317308
# add to spec resource manager
318309
if spec_res_mgr:
319310
spec_res_mgr.add_dummy_requests([request_id])

tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py

Lines changed: 46 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -32,8 +32,7 @@
3232
from tensorrt_llm._torch.pyexecutor.resource_manager import (
3333
BaseResourceManager, CacheTypeCpp, DataType, KVCacheManager, get_pp_layers)
3434
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
35-
from tensorrt_llm._utils import (nvtx_range, prefer_pinned,
36-
torch_dtype_to_binding)
35+
from tensorrt_llm._utils import nvtx_range, torch_dtype_to_binding
3736
from tensorrt_llm.bindings.internal.batch_manager import (
3837
KvCacheConnectorManager, LinearAttentionMetadata, LinearCacheType)
3938
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
@@ -191,12 +190,10 @@ def free_resources(self, request: LlmRequest):
191190
self.mamba_impl.free_cache_block(request.py_request_id)
192191

193192
def add_dummy_requests(self, request_ids: List[int], **kwargs):
194-
# For CUDA graph dummy requests, the blocks will be allocated
195-
# when get_state_indices is called.
196-
from .cuda_graph_runner import CUDA_GRAPH_DUMMY_REQUEST_ID
197-
request_ids = [
198-
rid for rid in request_ids if rid != CUDA_GRAPH_DUMMY_REQUEST_ID
199-
]
193+
# Allocate a permanent slot for every id, including CUDA-graph
194+
# padding sentinels (matches PythonMambaCacheManager). Padding
195+
# entries in get_state_indices then resolve via mCacheIndex to
196+
# the sentinel's reserved slot and never alias a live request.
200197
if request_ids:
201198
self.mamba_impl.allocate_cache_blocks(request_ids)
202199

@@ -375,12 +372,6 @@ def __init__(
375372
# mamba cache index, maps request_id -> state indices
376373
self.mamba_cache_index: Dict[int, int] = {}
377374

378-
# mamba cache state indices
379-
self.state_indices: torch.Tensor = torch.arange(max_batch_size,
380-
device=device,
381-
dtype=torch.int32)
382-
# save mamba state indices for requests
383-
self.state_indices_list: List[int] = []
384375
# save intermediate state indices for requests
385376
self.intermediate_state_indices = torch.arange(max_batch_size,
386377
dtype=torch.int32,
@@ -399,23 +390,13 @@ def get_needed_resource_to_completion(self, request: LlmRequest) -> int:
399390

400391
@torch.inference_mode()
401392
def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
402-
self.state_indices_list.clear()
403393
for r in request_ids:
404-
# cache hit
405394
if r in self.mamba_cache_index:
406-
self.state_indices_list.append(self.mamba_cache_index[r])
407-
# cache miss
408-
else:
409-
if len(self.mamba_cache_free_blocks) == 0:
410-
raise RuntimeError("run out of mamba cache blocks")
411-
block = self.mamba_cache_free_blocks.pop()
412-
self.mamba_cache_index[r] = block
413-
self.state_indices_list.append(block)
414-
self.state_indices[:len(self.state_indices_list)].copy_(
415-
torch.tensor(self.state_indices_list,
416-
dtype=torch.int32,
417-
pin_memory=prefer_pinned()),
418-
non_blocking=True)
395+
continue
396+
if len(self.mamba_cache_free_blocks) == 0:
397+
raise RuntimeError("run out of mamba cache blocks")
398+
block = self.mamba_cache_free_blocks.pop()
399+
self.mamba_cache_index[r] = block
419400

420401
def prepare_resources(self, scheduled_batch: ScheduledRequests):
421402
context_ids = [
@@ -428,10 +409,16 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
428409
self._prepare_mamba_cache_blocks(request_ids)
429410

430411
def add_dummy_requests(self, request_ids: List[int], **kwargs):
431-
from .cuda_graph_runner import CUDA_GRAPH_DUMMY_REQUEST_ID
432-
request_ids = [
433-
rid for rid in request_ids if rid != CUDA_GRAPH_DUMMY_REQUEST_ID
434-
]
412+
# Allocate a permanent slot for every dummy request ID, including
413+
# the CUDA-graph padding sentinel. Padding entries in a batch all
414+
# reference the same dummy request ID, so they share one slot via
415+
# mamba_cache_index lookup in get_state_indices. This mirrors how
416+
# MTP's per-draft-len padding dummies already behave (they use
417+
# CUDA_GRAPH_DUMMY_REQUEST_ID - draft_len, which was never
418+
# filtered here) and keeps padding writes off every live
419+
# request's slot, even under the overlap scheduler where a prior
420+
# batch's completed requests linger in mamba_cache_index until
421+
# _process_previous_batch runs.
435422
if request_ids:
436423
for r in request_ids:
437424
if r not in self.mamba_cache_index:
@@ -448,29 +435,10 @@ def free_resources(self, request: LlmRequest):
448435

449436
def get_state_indices(self, request_ids: List[int],
450437
is_padding: List[bool]) -> List[int]:
451-
assert len(request_ids) == len(is_padding), (
452-
"request_ids and is_padding must have the same size")
453-
454-
used_slots = {
455-
self.mamba_cache_index[req_id]
456-
for req_id, pad in zip(request_ids, is_padding) if not pad
457-
}
458-
available_slots = iter(
459-
sorted(set(range(self.state_indices.numel())) - used_slots))
460-
461-
def slot_for(req_id: int, pad: bool):
462-
if pad:
463-
try:
464-
return next(available_slots)
465-
except StopIteration:
466-
raise RuntimeError(
467-
"Run out of available slots for padding") from None
468-
return self.mamba_cache_index[req_id]
469-
470-
result = [
471-
slot_for(rid, pad) for rid, pad in zip(request_ids, is_padding)
472-
]
473-
return result
438+
# Padding entries reuse the slot pre-allocated by their dummy
439+
# request in add_dummy_requests; see that method for the
440+
# overlap-scheduler rationale.
441+
return [self.mamba_cache_index[rid] for rid in request_ids]
474442

475443
def get_conv_states(self, layer_idx: int) -> torch.Tensor:
476444
layer_offset = self.mamba_layer_offsets[layer_idx]
@@ -509,9 +477,6 @@ def get_mamba_ssm_cache_dtype(self) -> torch.dtype:
509477

510478
def shutdown(self):
511479
"""Release tensor memory."""
512-
# Clear state indices
513-
self.state_indices = torch.tensor([])
514-
515480
# Clear mamba cache states
516481
if isinstance(self.mamba_cache, self.SpeculativeState):
517482
self.mamba_cache = self.SpeculativeState(
@@ -530,14 +495,14 @@ def shutdown(self):
530495

531496
@torch.compile(options={"max-autotune": True})
532497
def update_mamba_states(self, attn_metadata: "AttentionMetadata",
533-
num_accepted_tokens: torch.Tensor):
498+
num_accepted_tokens: torch.Tensor,
499+
state_indices: torch.Tensor):
534500
batch_size = attn_metadata.num_seqs
535501
num_contexts = attn_metadata.num_contexts
536502
num_gens = batch_size - num_contexts
537503
num_accepted_draft_tokens = num_accepted_tokens[
538504
num_contexts:num_contexts + num_gens] - 1
539-
state_indices_d = self.state_indices[num_contexts:num_contexts +
540-
num_gens]
505+
state_indices_d = state_indices[num_contexts:num_contexts + num_gens]
541506

542507
conv_states = self.mamba_cache.conv
543508
ssm_states = self.mamba_cache.temporal
@@ -684,9 +649,18 @@ def shutdown(self):
684649
self._impl.shutdown()
685650

686651
def update_mamba_states(self, attn_metadata: "AttentionMetadata",
687-
num_accepted_tokens: torch.Tensor):
652+
num_accepted_tokens: torch.Tensor,
653+
state_indices: torch.Tensor):
654+
# Non-speculative configs don't allocate intermediate state; the
655+
# promotion is a clean no-op.
656+
if not self._impl.is_speculative():
657+
return
658+
# Belt-and-suspenders: C++ is non-speculative today so this is
659+
# unreachable. Fires if C++ ever grows speculative support
660+
# without also implementing the scatter there.
688661
assert not self._use_cpp, "update_mamba_states is not supported in CppMambaCacheManager"
689-
self._impl.update_mamba_states(attn_metadata, num_accepted_tokens)
662+
self._impl.update_mamba_states(attn_metadata, num_accepted_tokens,
663+
state_indices)
690664

691665

692666
class MixedMambaHybridCacheManager(KVCacheManager, MambaCacheManager):
@@ -733,7 +707,13 @@ def __init__(
733707
# mamba hybrid cache requires block reuse to be disabled in KV cache config
734708
assert not kv_cache_config.enable_block_reuse, "mamba hybrid cache requires block reuse to be disabled in KV cache config"
735709

736-
# initialize mamba cache manager
710+
# Reserve one Mamba slot per possible CUDA-graph padding dummy
711+
# (one per runtime_draft_len in 0..max_draft_len) so a full
712+
# max_batch_size of real requests still leaves room for padding.
713+
max_draft_len = (spec_config.max_draft_len
714+
if spec_config is not None else 0)
715+
pool_size = max_batch_size + max_draft_len + 1
716+
737717
MambaCacheManager.__init__(
738718
self,
739719
mamba_d_state,
@@ -742,7 +722,7 @@ def __init__(
742722
mamba_n_groups,
743723
mamba_head_dim,
744724
mamba_num_layers,
745-
max_batch_size,
725+
pool_size,
746726
max_batch_size,
747727
mapping,
748728
mamba_cache_dtype,
@@ -796,11 +776,6 @@ def update_resources(self,
796776
KVCacheManager.update_resources(self, scheduled_batch, attn_metadata,
797777
kv_cache_dtype_byte_size)
798778

799-
def update_mamba_states(self, attn_metadata: "AttentionMetadata",
800-
num_accepted_tokens: torch.Tensor):
801-
MambaCacheManager.update_mamba_states(self, attn_metadata,
802-
num_accepted_tokens)
803-
804779

805780
def calc_context_stop_positions(prompt_len: int,
806781
tokens_per_block: int,

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -904,6 +904,16 @@ def drafting_loop_wrapper(model):
904904
py_executor.kv_cache_transceiver.shutdown()
905905
finally:
906906
kv_cache_creator.teardown_managers(resources)
907+
908+
# Release Phase-1 CUDA graph pools before final KV allocation to avoid overshoot.
909+
for eng in [model_engine, draft_model_engine]:
910+
if eng is None:
911+
continue
912+
if eng.attn_metadata is not None:
913+
if llm_args.cuda_graph_config is not None:
914+
eng._release_cuda_graphs()
915+
eng.attn_metadata = None
916+
907917
del py_executor # free before constructing new
908918
gc.collect()
909919

@@ -918,13 +928,6 @@ def drafting_loop_wrapper(model):
918928
max_seq_len = kv_cache_creator._max_seq_len
919929
update_sampler_max_seq_len(max_seq_len, sampler)
920930

921-
for eng in [model_engine, draft_model_engine]:
922-
if eng is None:
923-
continue
924-
if eng.attn_metadata is not None:
925-
if llm_args.cuda_graph_config is not None:
926-
eng._release_cuda_graphs()
927-
eng.attn_metadata = None
928931
with allocation_scope(ExecutorMemoryType.EXTRA_RESOURCES):
929932

930933
# run gc.collect() to free memory of the previous py_executor, avoid cudaFree overlap with cuda graph capture

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1169,7 +1169,8 @@ def forward(
11691169
if num_gens > 0 and self._is_mamba_hybrid_cache:
11701170
attn_metadata.kv_cache_manager.update_mamba_states(
11711171
attn_metadata=attn_metadata,
1172-
num_accepted_tokens=num_accepted_tokens)
1172+
num_accepted_tokens=num_accepted_tokens,
1173+
state_indices=attn_metadata.mamba_metadata.state_indices)
11731174

11741175
# Save the old attn_metadata and spec_metadata
11751176
self._prepare_attn_metadata_for_spec_dec(attn_metadata)

0 commit comments

Comments
 (0)