@@ -182,7 +182,9 @@ def __init__(
182182 spec_config .tokens_per_gen_step -
183183 1 ) if spec_config is not None else 0
184184 # Saved before zeroing for draft models; used by update_spec_dec_param.
185- self ._spec_dec_max_total_draft_tokens = spec_config .max_total_draft_tokens if spec_config is not None else 0
185+ self ._spec_dec_max_total_draft_tokens = (
186+ spec_config .max_total_draft_tokens
187+ if spec_config is not None else 0 )
186188
187189 preserve_wrapped_eagle3_widths = (spec_config is not None
188190 and is_draft_model
@@ -341,11 +343,14 @@ def __init__(
341343 self .llm_args .attn_backend ,
342344 sparse_attn_config = self .sparse_attention_config )
343345
346+ self .get_runtime_tokens_per_gen_step = spec_config .get_runtime_tokens_per_gen_step if spec_config is not None else lambda runtime_draft_len : 1
347+
344348 if self .is_spec_decode :
345349 self .spec_metadata = None
346350 update_spec_config_from_model_config (self .spec_config ,
347351 self .model .config )
348- max_num_draft_tokens = self .original_max_total_draft_tokens * self .batch_size
352+ max_num_draft_tokens = (self .original_max_total_draft_tokens *
353+ self .batch_size )
349354 self .draft_tokens_cuda = torch .empty ((max_num_draft_tokens , ),
350355 dtype = torch .int ,
351356 device = 'cuda' )
@@ -876,12 +881,31 @@ def _get_graphs_to_capture(
876881 ):
877882 graphs = [(graph_bs , draft_len ) for graph_bs , draft_len in
878883 self ._dynamic_draft_len_mapping .items ()]
884+ # Workaround for dynamic draft length:
885+ # capture the maximum speculative graph shape up front. Dynamic draft length
886+ # breaks the previous assumption that attention workspace demand can be safely
887+ # ordered by batch size alone; a later graph shape may require a larger shared
888+ # graph workspace, and resizing that workspace can change its data_ptr and
889+ # invalidate pointers captured by earlier graphs, causing illegal memory access
890+ # on replay.
891+ #
892+ # This adds the overhead of one extra captured graph, and that graph is not
893+ # expected to be used by the normal schedule-driven dynamic draft-length path.
894+ #
895+ # Follow-up first-principles fix:
896+ # query or precompute the exact attention workspace requirement for all
897+ # reachable graph shapes, pre-size the shared graph workspace once without
898+ # capturing an extra graph, and avoid resizing it in graph mode afterward.
899+ max_spec_graph = (max (cuda_graph_batch_sizes ),
900+ self .original_max_draft_len )
901+ if max_spec_graph not in graphs :
902+ graphs .append (max_spec_graph )
879903 logger .info (f"Dynamic draft length enabled for one-model path. "
880904 f"Capturing { len (graphs )} graphs: { graphs } " )
881905 return graphs
882906
883907 # Case 3: Target model (two-model) or one-model without dynamic draft
884- draft_lengths = [self .max_total_draft_tokens ]
908+ draft_lengths = [self .max_draft_len ]
885909 should_capture_no_spec = (
886910 self .max_total_draft_tokens > 0
887911 and not self .spec_config .spec_dec_mode .use_one_engine ()
@@ -1206,12 +1230,15 @@ def _create_cuda_graph_warmup_request(
12061230
12071231 result = ScheduledRequests ()
12081232 num_extra_decoding_steps = self ._get_num_extra_decoding_steps ()
1233+ runtime_tokens_per_gen_step = self .get_runtime_tokens_per_gen_step (
1234+ draft_len )
1235+ runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1
12091236
12101237 # Add (batch_size - 1) dummy requests with seq_len=1.
12111238 requests = kv_cache_manager .add_dummy_requests (
12121239 list (range (batch_size - 1 )),
12131240 is_gen = True ,
1214- max_num_draft_tokens = draft_len ,
1241+ max_num_draft_tokens = runtime_draft_token_buffer_width ,
12151242 use_mrope = self .use_mrope ,
12161243 max_beam_width = self .max_beam_width ,
12171244 num_extra_decoding_steps = num_extra_decoding_steps ,
@@ -1228,26 +1255,29 @@ def _create_cuda_graph_warmup_request(
12281255 available_tokens = kv_cache_manager .get_num_available_tokens (
12291256 token_num_upper_bound = max_seq_len ,
12301257 batch_size = batch_size ,
1231- max_num_draft_tokens = draft_len )
1258+ max_num_draft_tokens = runtime_draft_token_buffer_width )
12321259
12331260 # Also consider draft KV cache capacity when it exists
12341261 if draft_kv_cache_manager is not None :
12351262 draft_available_tokens = draft_kv_cache_manager .get_num_available_tokens (
12361263 batch_size = batch_size ,
12371264 token_num_upper_bound = max_seq_len ,
1238- max_num_draft_tokens = draft_len )
1265+ max_num_draft_tokens = runtime_draft_token_buffer_width )
12391266 available_tokens = min (available_tokens , draft_available_tokens )
12401267
12411268 token_num = max (
12421269 1 ,
12431270 min (
1244- available_tokens , max_seq_len - 1 -
1245- get_num_extra_kv_tokens (self .spec_config ) - draft_len ))
1271+ available_tokens ,
1272+ max_seq_len - 1 - get_num_extra_kv_tokens (self .spec_config ) -
1273+ runtime_draft_token_buffer_width ))
12461274 model_config = self .model .model_config .pretrained_config
12471275 max_position_embeddings = getattr (model_config ,
12481276 'max_position_embeddings' , None )
12491277 if max_position_embeddings is not None :
1250- token_num = min (token_num , max_position_embeddings - draft_len )
1278+ token_num = min (
1279+ token_num ,
1280+ max_position_embeddings - runtime_draft_token_buffer_width )
12511281
12521282 assert token_num > num_extra_decoding_steps , (
12531283 "Cannot fuse drafting loop. Not enough KV cache space for all draft tokens."
@@ -1258,7 +1288,7 @@ def _create_cuda_graph_warmup_request(
12581288 request_ids = [batch_size - 1 ],
12591289 token_nums = [token_num ],
12601290 is_gen = True ,
1261- max_num_draft_tokens = draft_len ,
1291+ max_num_draft_tokens = runtime_draft_token_buffer_width ,
12621292 use_mrope = self .use_mrope ,
12631293 max_beam_width = self .max_beam_width ,
12641294 num_extra_decoding_steps = num_extra_decoding_steps ,
@@ -1985,8 +2015,10 @@ def _update_target_input_tensors(
19852015 non_blocking = True )
19862016
19872017 # Prepare draft tokens
2018+ num_draft_tokens_per_extend_request = num_tokens_per_extend_request - 1
19882019 self .draft_tokens_cuda [:previous_batch_draft_tokens ].copy_ (
1989- next_draft_tokens_device [previous_slots , :].flatten (),
2020+ next_draft_tokens_device [
2021+ previous_slots , :num_draft_tokens_per_extend_request ].flatten (),
19902022 non_blocking = True )
19912023
19922024 # Compute kv_len_offsets and update offset tensors
@@ -2022,8 +2054,10 @@ def _apply_incremental_update_target(
20222054 # Pre-compute constants
20232055 extend_requests = scheduled_requests .generation_requests
20242056 num_extend_requests = len (extend_requests )
2025- num_tokens_per_extend_request = self .runtime_draft_len + 1
20262057 spec_config = self .spec_config
2058+ num_tokens_per_extend_request = self .get_runtime_tokens_per_gen_step (
2059+ self .runtime_draft_len )
2060+ runtime_draft_token_buffer_width = num_tokens_per_extend_request - 1
20272061
20282062 prompt_lengths = torch .empty (num_extend_requests ,
20292063 dtype = torch .int ,
@@ -2085,7 +2119,8 @@ def _apply_incremental_update_target(
20852119 prompt_lengths = prompt_lengths .tolist ()
20862120 num_cached_tokens_per_seq = num_cached_tokens_per_seq .tolist ()
20872121
2088- previous_batch_draft_tokens = num_extend_reqeust_wo_dummy * self .runtime_draft_len
2122+ previous_batch_draft_tokens = (num_extend_reqeust_wo_dummy *
2123+ runtime_draft_token_buffer_width )
20892124
20902125 self ._update_target_input_tensors (
20912126 num_accepted_tokens_device = num_accepted_tokens_device ,
@@ -2368,6 +2403,9 @@ def _prepare_tp_inputs(
23682403 # will contain previous batch indices of generation requests
23692404 previous_batch_indices = []
23702405 previous_pos_indices = []
2406+ runtime_tokens_per_gen_step = self .get_runtime_tokens_per_gen_step (
2407+ self .runtime_draft_len )
2408+ runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1
23712409 for request in extend_requests :
23722410 request_ids .append (request .py_request_id )
23732411 request_accepted_path [
@@ -2426,16 +2464,16 @@ def _prepare_tp_inputs(
24262464 previous_batch_idx = request .py_batch_idx
24272465 request .py_batch_idx = request .py_seq_slot
24282466
2429- sequence_lengths .append (1 + self . runtime_draft_len )
2467+ sequence_lengths .append (runtime_tokens_per_gen_step )
24302468 num_accepted_draft_tokens .append (
24312469 request .py_num_accepted_draft_tokens )
24322470 past_seen_token_num = request .max_beam_num_tokens - 1
24332471
2434- draft_lens .append (self . runtime_draft_len )
2472+ draft_lens .append (runtime_draft_token_buffer_width )
24352473 gather_ids .extend (
24362474 list (
24372475 range (len (position_ids ),
2438- len (position_ids ) + 1 + self . runtime_draft_len )))
2476+ len (position_ids ) + runtime_tokens_per_gen_step )))
24392477 # For the target model + tree decoding
24402478 if not self .is_draft_model and not spec_config .is_linear_tree :
24412479 assert spec_tree_manager is not None
@@ -2448,19 +2486,19 @@ def _prepare_tp_inputs(
24482486 position_ids .extend (
24492487 list (
24502488 range (
2451- past_seen_token_num , past_seen_token_num + 1 +
2452- self . runtime_draft_len )))
2489+ past_seen_token_num , past_seen_token_num +
2490+ runtime_tokens_per_gen_step )))
24532491 # previous tensor
24542492 previous_batch_indices .append (previous_batch_idx )
24552493 previous_pos_indices .extend ([previous_batch_idx ] *
2456- ( 1 + self . runtime_draft_len ) )
2494+ runtime_tokens_per_gen_step )
24572495
24582496 num_cached_tokens_per_seq .append (past_seen_token_num +
2459- self . runtime_draft_len + 1 )
2497+ runtime_tokens_per_gen_step )
24602498 request .cached_tokens = num_cached_tokens_per_seq [- 1 ]
24612499 if self .enable_spec_decode and spec_config .spec_dec_mode .extend_ctx (
24622500 self .attn_backend ) and spec_config .is_linear_tree :
2463- prompt_lengths .append (1 + self . runtime_draft_len )
2501+ prompt_lengths .append (runtime_tokens_per_gen_step )
24642502 else :
24652503 prompt_lengths .append (request .py_prompt_len )
24662504
@@ -2765,30 +2803,36 @@ def previous_seq_slots_device():
27652803 # Initialize these two values to zeros
27662804 self .previous_pos_id_offsets_cuda *= 0
27672805 self .previous_kv_lens_offsets_cuda *= 0
2806+ runtime_tokens_per_gen_step = self .get_runtime_tokens_per_gen_step (
2807+ self .runtime_draft_len )
2808+ runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1
27682809
27692810 if previous_batch_len > 0 :
27702811 previous_slots = previous_seq_slots_device ()
27712812 # previous input ids
2772- previous_batch_tokens = previous_batch_len * (
2773- 1 + self . runtime_draft_len )
2813+ previous_batch_tokens = ( previous_batch_len *
2814+ runtime_tokens_per_gen_step )
27742815 new_tokens = new_tokens_device .transpose (
27752816 0 ,
2776- 1 )[previous_slots , :( 1 + self . runtime_draft_len ) ].flatten ()
2817+ 1 )[previous_slots , :runtime_tokens_per_gen_step ].flatten ()
27772818 self .input_ids_cuda [num_tokens :num_tokens +
27782819 previous_batch_tokens ].copy_ (
27792820 new_tokens , non_blocking = True )
27802821
27812822 # previous draft tokens
2782- previous_batch_draft_tokens = previous_batch_len * self .runtime_draft_len
2783- if self .runtime_draft_len > 0 :
2784- self .draft_tokens_cuda [num_draft_tokens :num_draft_tokens +
2785- previous_batch_draft_tokens ].copy_ (
2786- next_draft_tokens_device [
2787- previous_slots , :self .
2788- runtime_draft_len ].flatten (),
2789- non_blocking = True )
2823+ previous_batch_draft_tokens = (previous_batch_len *
2824+ runtime_draft_token_buffer_width )
2825+ if runtime_draft_token_buffer_width > 0 :
2826+ self .draft_tokens_cuda [
2827+ num_draft_tokens :num_draft_tokens +
2828+ previous_batch_draft_tokens ].copy_ (
2829+ next_draft_tokens_device [
2830+ previous_slots , :
2831+ runtime_draft_token_buffer_width ].flatten (),
2832+ non_blocking = True )
27902833 # prepare data for the preprocess inputs
2791- kv_len_offsets_device = new_tokens_lens_device - self .runtime_draft_len - 1
2834+ kv_len_offsets_device = (new_tokens_lens_device -
2835+ runtime_tokens_per_gen_step )
27922836 previous_pos_indices_host = torch .tensor (
27932837 previous_pos_indices ,
27942838 dtype = torch .int ,
@@ -2814,8 +2858,8 @@ def previous_seq_slots_device():
28142858 extend_dummy_requests )
28152859 self .previous_pos_id_offsets_cuda [
28162860 (num_extend_reqeust_wo_dummy - previous_batch_len ) *
2817- ( 1 + self . runtime_draft_len ) :num_extend_reqeust_wo_dummy *
2818- ( 1 + self . runtime_draft_len ) ].copy_ (
2861+ runtime_tokens_per_gen_step :num_extend_reqeust_wo_dummy *
2862+ runtime_tokens_per_gen_step ].copy_ (
28192863 new_tokens_lens_device [self .previous_pos_indices_cuda [
28202864 0 :previous_batch_tokens ]],
28212865 non_blocking = True )
@@ -3679,6 +3723,8 @@ def forward(self,
36793723 # Propagate runtime_draft_len (already set on self by py_executor)
36803724 # to spec_metadata so downstream code (eagle3, interface, trtllm) can read it.
36813725 spec_metadata .runtime_draft_len = self .runtime_draft_len
3726+ spec_metadata .runtime_tokens_per_gen_step = (
3727+ self .get_runtime_tokens_per_gen_step (self .runtime_draft_len ))
36823728
36833729 # PARD has 2K tokens per gen request, not K+1. Pass 2K-1
36843730 # so generation_lengths = 2K and the XQA kernel computes
0 commit comments