3232from tensorrt_llm ._torch .pyexecutor .resource_manager import (
3333 BaseResourceManager , CacheTypeCpp , DataType , KVCacheManager , get_pp_layers )
3434from tensorrt_llm ._torch .pyexecutor .scheduler import ScheduledRequests
35- from tensorrt_llm ._utils import (nvtx_range , prefer_pinned ,
36- torch_dtype_to_binding )
35+ from tensorrt_llm ._utils import nvtx_range , torch_dtype_to_binding
3736from tensorrt_llm .bindings .internal .batch_manager import (
3837 KvCacheConnectorManager , LinearAttentionMetadata , LinearCacheType )
3938from tensorrt_llm .llmapi .llm_args import KvCacheConfig
@@ -191,12 +190,10 @@ def free_resources(self, request: LlmRequest):
191190 self .mamba_impl .free_cache_block (request .py_request_id )
192191
def add_dummy_requests(self, request_ids: List[int], **kwargs):
    """Reserve a mamba cache block for each dummy request id.

    Args:
        request_ids: Ids to allocate blocks for. An empty list is a no-op
            and ``mamba_impl`` is not called.
        **kwargs: Accepted for signature compatibility; not read here.
    """
    # Allocate a permanent slot for every id, including CUDA-graph
    # padding sentinels (matches PythonMambaCacheManager). Padding
    # entries in get_state_indices then resolve via mCacheIndex to
    # the sentinel's reserved slot and never alias a live request.
    if request_ids:
        self.mamba_impl.allocate_cache_blocks(request_ids)
202199
@@ -375,12 +372,6 @@ def __init__(
375372 # mamba cache index, maps request_id -> state indices
376373 self .mamba_cache_index : Dict [int , int ] = {}
377374
378- # mamba cache state indices
379- self .state_indices : torch .Tensor = torch .arange (max_batch_size ,
380- device = device ,
381- dtype = torch .int32 )
382- # save mamba state indices for requests
383- self .state_indices_list : List [int ] = []
384375 # save intermediate state indices for requests
385376 self .intermediate_state_indices = torch .arange (max_batch_size ,
386377 dtype = torch .int32 ,
@@ -399,23 +390,13 @@ def get_needed_resource_to_completion(self, request: LlmRequest) -> int:
399390
@torch.inference_mode()
def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
    """Give every request id in ``request_ids`` a mamba cache block.

    Ids already present in ``self.mamba_cache_index`` keep their block.
    Every other id takes a block popped from
    ``self.mamba_cache_free_blocks`` and is recorded in the index.

    Args:
        request_ids: Request ids that need a block for the coming step.

    Raises:
        RuntimeError: If a new id needs a block and the free pool is empty.
    """
    for r in request_ids:
        # Already mapped (including duplicates earlier in this list).
        if r in self.mamba_cache_index:
            continue
        if not self.mamba_cache_free_blocks:
            raise RuntimeError("run out of mamba cache blocks")
        self.mamba_cache_index[r] = self.mamba_cache_free_blocks.pop()
419400
420401 def prepare_resources (self , scheduled_batch : ScheduledRequests ):
421402 context_ids = [
@@ -428,10 +409,16 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
428409 self ._prepare_mamba_cache_blocks (request_ids )
429410
430411 def add_dummy_requests (self , request_ids : List [int ], ** kwargs ):
431- from .cuda_graph_runner import CUDA_GRAPH_DUMMY_REQUEST_ID
432- request_ids = [
433- rid for rid in request_ids if rid != CUDA_GRAPH_DUMMY_REQUEST_ID
434- ]
412+ # Allocate a permanent slot for every dummy request ID, including
413+ # the CUDA-graph padding sentinel. Padding entries in a batch all
414+ # reference the same dummy request ID, so they share one slot via
415+ # mamba_cache_index lookup in get_state_indices. This mirrors how
416+ # MTP's per-draft-len padding dummies already behave (they use
417+ # CUDA_GRAPH_DUMMY_REQUEST_ID - draft_len, which was never
418+ # filtered here) and keeps padding writes off every live
419+ # request's slot, even under the overlap scheduler where a prior
420+ # batch's completed requests linger in mamba_cache_index until
421+ # _process_previous_batch runs.
435422 if request_ids :
436423 for r in request_ids :
437424 if r not in self .mamba_cache_index :
@@ -448,29 +435,10 @@ def free_resources(self, request: LlmRequest):
448435
def get_state_indices(self, request_ids: List[int],
                      is_padding: List[bool]) -> List[int]:
    """Return the mamba cache slot for each request id, in order.

    Args:
        request_ids: Request ids to resolve; repeated ids yield the same
            slot each time.
        is_padding: Not read by this implementation; kept so the call
            signature stays the same for callers.

    Returns:
        ``self.mamba_cache_index[rid]`` for every ``rid`` in
        ``request_ids``.

    Raises:
        KeyError: If an id has no entry in ``self.mamba_cache_index``.
    """
    # Padding entries reuse the slot pre-allocated by their dummy
    # request in add_dummy_requests; see that method for the
    # overlap-scheduler rationale.
    return [self.mamba_cache_index[rid] for rid in request_ids]
474442
475443 def get_conv_states (self , layer_idx : int ) -> torch .Tensor :
476444 layer_offset = self .mamba_layer_offsets [layer_idx ]
@@ -509,9 +477,6 @@ def get_mamba_ssm_cache_dtype(self) -> torch.dtype:
509477
510478 def shutdown (self ):
511479 """Release tensor memory."""
512- # Clear state indices
513- self .state_indices = torch .tensor ([])
514-
515480 # Clear mamba cache states
516481 if isinstance (self .mamba_cache , self .SpeculativeState ):
517482 self .mamba_cache = self .SpeculativeState (
@@ -530,14 +495,14 @@ def shutdown(self):
530495
531496 @torch .compile (options = {"max-autotune" : True })
532497 def update_mamba_states (self , attn_metadata : "AttentionMetadata" ,
533- num_accepted_tokens : torch .Tensor ):
498+ num_accepted_tokens : torch .Tensor ,
499+ state_indices : torch .Tensor ):
534500 batch_size = attn_metadata .num_seqs
535501 num_contexts = attn_metadata .num_contexts
536502 num_gens = batch_size - num_contexts
537503 num_accepted_draft_tokens = num_accepted_tokens [
538504 num_contexts :num_contexts + num_gens ] - 1
539- state_indices_d = self .state_indices [num_contexts :num_contexts +
540- num_gens ]
505+ state_indices_d = state_indices [num_contexts :num_contexts + num_gens ]
541506
542507 conv_states = self .mamba_cache .conv
543508 ssm_states = self .mamba_cache .temporal
@@ -684,9 +649,18 @@ def shutdown(self):
684649 self ._impl .shutdown ()
685650
def update_mamba_states(self, attn_metadata: "AttentionMetadata",
                        num_accepted_tokens: torch.Tensor,
                        state_indices: torch.Tensor):
    """Forward a mamba state update to ``self._impl`` when speculative.

    Args:
        attn_metadata: Passed through unchanged to the implementation.
        num_accepted_tokens: Passed through unchanged to the implementation.
        state_indices: Passed through unchanged to the implementation.

    Returns:
        None. Returns early, without calling the implementation, when
        ``self._impl.is_speculative()`` is false.
    """
    # Non-speculative configs don't allocate intermediate state; the
    # promotion is a clean no-op.
    if not self._impl.is_speculative():
        return
    # Belt-and-suspenders: C++ is non-speculative today so this is
    # unreachable. Fires if C++ ever grows speculative support
    # without also implementing the scatter there.
    assert not self._use_cpp, "update_mamba_states is not supported in CppMambaCacheManager"
    self._impl.update_mamba_states(attn_metadata, num_accepted_tokens,
                                   state_indices)
690664
691665
692666class MixedMambaHybridCacheManager (KVCacheManager , MambaCacheManager ):
@@ -733,7 +707,13 @@ def __init__(
733707 # mamba hybrid cache requires block reuse to be disabled in KV cache config
734708 assert not kv_cache_config .enable_block_reuse , "mamba hybrid cache requires block reuse to be disabled in KV cache config"
735709
736- # initialize mamba cache manager
710+ # Reserve one Mamba slot per possible CUDA-graph padding dummy
711+ # (one per runtime_draft_len in 0..max_draft_len) so a full
712+ # max_batch_size of real requests still leaves room for padding.
713+ max_draft_len = (spec_config .max_draft_len
714+ if spec_config is not None else 0 )
715+ pool_size = max_batch_size + max_draft_len + 1
716+
737717 MambaCacheManager .__init__ (
738718 self ,
739719 mamba_d_state ,
@@ -742,7 +722,7 @@ def __init__(
742722 mamba_n_groups ,
743723 mamba_head_dim ,
744724 mamba_num_layers ,
745- max_batch_size ,
725+ pool_size ,
746726 max_batch_size ,
747727 mapping ,
748728 mamba_cache_dtype ,
@@ -796,11 +776,6 @@ def update_resources(self,
796776 KVCacheManager .update_resources (self , scheduled_batch , attn_metadata ,
797777 kv_cache_dtype_byte_size )
798778
799- def update_mamba_states (self , attn_metadata : "AttentionMetadata" ,
800- num_accepted_tokens : torch .Tensor ):
801- MambaCacheManager .update_mamba_states (self , attn_metadata ,
802- num_accepted_tokens )
803-
804779
805780def calc_context_stop_positions (prompt_len : int ,
806781 tokens_per_block : int ,
0 commit comments