@@ -180,7 +180,9 @@ def __init__(
180180 spec_config .tokens_per_gen_step -
181181 1 ) if spec_config is not None else 0
182182 # Saved before zeroing for draft models; used by update_spec_dec_param.
183- self ._spec_dec_max_total_draft_tokens = spec_config .max_total_draft_tokens if spec_config is not None else 0
183+ self ._spec_dec_max_total_draft_tokens = (
184+ spec_config .max_total_draft_tokens
185+ if spec_config is not None else 0 )
184186
185187 preserve_wrapped_eagle3_widths = (spec_config is not None
186188 and is_draft_model
@@ -334,11 +336,14 @@ def __init__(
334336 self .llm_args .attn_backend ,
335337 sparse_attn_config = self .sparse_attention_config )
336338
339+ self .get_runtime_tokens_per_gen_step = spec_config .get_runtime_tokens_per_gen_step if spec_config is not None else lambda runtime_draft_len : 1
340+
337341 if self .is_spec_decode :
338342 self .spec_metadata = None
339343 update_spec_config_from_model_config (self .spec_config ,
340344 self .model .config )
341- max_num_draft_tokens = self .original_max_total_draft_tokens * self .batch_size
345+ max_num_draft_tokens = (self .original_max_total_draft_tokens *
346+ self .batch_size )
342347 self .draft_tokens_cuda = torch .empty ((max_num_draft_tokens , ),
343348 dtype = torch .int ,
344349 device = 'cuda' )
@@ -869,7 +874,7 @@ def _get_graphs_to_capture(
869874 return graphs
870875
871876 # Case 3: Target model (two-model) or one-model without dynamic draft
872- draft_lengths = [self .max_total_draft_tokens ]
877+ draft_lengths = [self .max_draft_len ]
873878 should_capture_no_spec = (
874879 self .max_total_draft_tokens > 0
875880 and not self .spec_config .spec_dec_mode .use_one_engine ()
@@ -1194,12 +1199,15 @@ def _create_cuda_graph_warmup_request(
11941199
11951200 result = ScheduledRequests ()
11961201 num_extra_decoding_steps = self ._get_num_extra_decoding_steps ()
1202+ runtime_tokens_per_gen_step = self .get_runtime_tokens_per_gen_step (
1203+ draft_len )
1204+ runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1
11971205
11981206 # Add (batch_size - 1) dummy requests with seq_len=1.
11991207 requests = kv_cache_manager .add_dummy_requests (
12001208 list (range (batch_size - 1 )),
12011209 is_gen = True ,
1202- max_num_draft_tokens = draft_len ,
1210+ max_num_draft_tokens = runtime_draft_token_buffer_width ,
12031211 use_mrope = self .use_mrope ,
12041212 max_beam_width = self .max_beam_width ,
12051213 num_extra_decoding_steps = num_extra_decoding_steps ,
@@ -1216,26 +1224,29 @@ def _create_cuda_graph_warmup_request(
12161224 available_tokens = kv_cache_manager .get_num_available_tokens (
12171225 token_num_upper_bound = max_seq_len ,
12181226 batch_size = batch_size ,
1219- max_num_draft_tokens = draft_len )
1227+ max_num_draft_tokens = runtime_draft_token_buffer_width )
12201228
12211229 # Also consider draft KV cache capacity when it exists
12221230 if draft_kv_cache_manager is not None :
12231231 draft_available_tokens = draft_kv_cache_manager .get_num_available_tokens (
12241232 batch_size = batch_size ,
12251233 token_num_upper_bound = max_seq_len ,
1226- max_num_draft_tokens = draft_len )
1234+ max_num_draft_tokens = runtime_draft_token_buffer_width )
12271235 available_tokens = min (available_tokens , draft_available_tokens )
12281236
12291237 token_num = max (
12301238 1 ,
12311239 min (
1232- available_tokens , max_seq_len - 1 -
1233- get_num_extra_kv_tokens (self .spec_config ) - draft_len ))
1240+ available_tokens ,
1241+ max_seq_len - 1 - get_num_extra_kv_tokens (self .spec_config ) -
1242+ runtime_draft_token_buffer_width ))
12341243 model_config = self .model .model_config .pretrained_config
12351244 max_position_embeddings = getattr (model_config ,
12361245 'max_position_embeddings' , None )
12371246 if max_position_embeddings is not None :
1238- token_num = min (token_num , max_position_embeddings - draft_len )
1247+ token_num = min (
1248+ token_num ,
1249+ max_position_embeddings - runtime_draft_token_buffer_width )
12391250
12401251 assert token_num > num_extra_decoding_steps , (
12411252 "Cannot fuse drafting loop. Not enough KV cache space for all draft tokens."
@@ -1246,7 +1257,7 @@ def _create_cuda_graph_warmup_request(
12461257 request_ids = [batch_size - 1 ],
12471258 token_nums = [token_num ],
12481259 is_gen = True ,
1249- max_num_draft_tokens = draft_len ,
1260+ max_num_draft_tokens = runtime_draft_token_buffer_width ,
12501261 use_mrope = self .use_mrope ,
12511262 max_beam_width = self .max_beam_width ,
12521263 num_extra_decoding_steps = num_extra_decoding_steps ,
@@ -1968,8 +1979,10 @@ def _update_target_input_tensors(
19681979 non_blocking = True )
19691980
19701981 # Prepare draft tokens
1982+ num_draft_tokens_per_extend_request = num_tokens_per_extend_request - 1
19711983 self .draft_tokens_cuda [:previous_batch_draft_tokens ].copy_ (
1972- next_draft_tokens_device [previous_slots , :].flatten (),
1984+ next_draft_tokens_device [
1985+ previous_slots , :num_draft_tokens_per_extend_request ].flatten (),
19731986 non_blocking = True )
19741987
19751988 # Compute kv_len_offsets and update offset tensors
@@ -2005,8 +2018,10 @@ def _apply_incremental_update_target(
20052018 # Pre-compute constants
20062019 extend_requests = scheduled_requests .generation_requests
20072020 num_extend_requests = len (extend_requests )
2008- num_tokens_per_extend_request = self .runtime_draft_len + 1
20092021 spec_config = self .spec_config
2022+ num_tokens_per_extend_request = self .get_runtime_tokens_per_gen_step (
2023+ self .runtime_draft_len )
2024+ runtime_draft_token_buffer_width = num_tokens_per_extend_request - 1
20102025
20112026 prompt_lengths = torch .empty (num_extend_requests ,
20122027 dtype = torch .int ,
@@ -2068,7 +2083,8 @@ def _apply_incremental_update_target(
20682083 prompt_lengths = prompt_lengths .tolist ()
20692084 num_cached_tokens_per_seq = num_cached_tokens_per_seq .tolist ()
20702085
2071- previous_batch_draft_tokens = num_extend_reqeust_wo_dummy * self .runtime_draft_len
2086+ previous_batch_draft_tokens = (num_extend_reqeust_wo_dummy *
2087+ runtime_draft_token_buffer_width )
20722088
20732089 self ._update_target_input_tensors (
20742090 num_accepted_tokens_device = num_accepted_tokens_device ,
@@ -2347,6 +2363,9 @@ def _prepare_tp_inputs(
23472363 # will contain previous batch indices of generation requests
23482364 previous_batch_indices = []
23492365 previous_pos_indices = []
2366+ runtime_tokens_per_gen_step = self .get_runtime_tokens_per_gen_step (
2367+ self .runtime_draft_len )
2368+ runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1
23502369 for request in extend_requests :
23512370 request_ids .append (request .py_request_id )
23522371 request_accepted_path [
@@ -2405,16 +2424,16 @@ def _prepare_tp_inputs(
24052424 previous_batch_idx = request .py_batch_idx
24062425 request .py_batch_idx = request .py_seq_slot
24072426
2408- sequence_lengths .append (1 + self . runtime_draft_len )
2427+ sequence_lengths .append (runtime_tokens_per_gen_step )
24092428 num_accepted_draft_tokens .append (
24102429 request .py_num_accepted_draft_tokens )
24112430 past_seen_token_num = request .max_beam_num_tokens - 1
24122431
2413- draft_lens .append (self . runtime_draft_len )
2432+ draft_lens .append (runtime_draft_token_buffer_width )
24142433 gather_ids .extend (
24152434 list (
24162435 range (len (position_ids ),
2417- len (position_ids ) + 1 + self . runtime_draft_len )))
2436+ len (position_ids ) + runtime_tokens_per_gen_step )))
24182437 # For the target model + tree decoding
24192438 if not self .is_draft_model and not spec_config .is_linear_tree :
24202439 assert spec_tree_manager is not None
@@ -2427,19 +2446,19 @@ def _prepare_tp_inputs(
24272446 position_ids .extend (
24282447 list (
24292448 range (
2430- past_seen_token_num , past_seen_token_num + 1 +
2431- self . runtime_draft_len )))
2449+ past_seen_token_num , past_seen_token_num +
2450+ runtime_tokens_per_gen_step )))
24322451 # previous tensor
24332452 previous_batch_indices .append (previous_batch_idx )
24342453 previous_pos_indices .extend ([previous_batch_idx ] *
2435- ( 1 + self . runtime_draft_len ) )
2454+ runtime_tokens_per_gen_step )
24362455
24372456 num_cached_tokens_per_seq .append (past_seen_token_num +
2438- self . runtime_draft_len + 1 )
2457+ runtime_tokens_per_gen_step )
24392458 request .cached_tokens = num_cached_tokens_per_seq [- 1 ]
24402459 if self .enable_spec_decode and spec_config .spec_dec_mode .extend_ctx (
24412460 self .attn_backend ) and spec_config .is_linear_tree :
2442- prompt_lengths .append (1 + self . runtime_draft_len )
2461+ prompt_lengths .append (runtime_tokens_per_gen_step )
24432462 else :
24442463 prompt_lengths .append (request .py_prompt_len )
24452464
@@ -2740,30 +2759,36 @@ def previous_seq_slots_device():
27402759 # Initialize these two values to zeros
27412760 self .previous_pos_id_offsets_cuda *= 0
27422761 self .previous_kv_lens_offsets_cuda *= 0
2762+ runtime_tokens_per_gen_step = self .get_runtime_tokens_per_gen_step (
2763+ self .runtime_draft_len )
2764+ runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1
27432765
27442766 if previous_batch_len > 0 :
27452767 previous_slots = previous_seq_slots_device ()
27462768 # previous input ids
2747- previous_batch_tokens = previous_batch_len * (
2748- 1 + self . runtime_draft_len )
2769+ previous_batch_tokens = ( previous_batch_len *
2770+ runtime_tokens_per_gen_step )
27492771 new_tokens = new_tokens_device .transpose (
27502772 0 ,
2751- 1 )[previous_slots , :( 1 + self . runtime_draft_len ) ].flatten ()
2773+ 1 )[previous_slots , :runtime_tokens_per_gen_step ].flatten ()
27522774 self .input_ids_cuda [num_tokens :num_tokens +
27532775 previous_batch_tokens ].copy_ (
27542776 new_tokens , non_blocking = True )
27552777
27562778 # previous draft tokens
2757- previous_batch_draft_tokens = previous_batch_len * self .runtime_draft_len
2758- if self .runtime_draft_len > 0 :
2759- self .draft_tokens_cuda [num_draft_tokens :num_draft_tokens +
2760- previous_batch_draft_tokens ].copy_ (
2761- next_draft_tokens_device [
2762- previous_slots , :self .
2763- runtime_draft_len ].flatten (),
2764- non_blocking = True )
2779+ previous_batch_draft_tokens = (previous_batch_len *
2780+ runtime_draft_token_buffer_width )
2781+ if runtime_draft_token_buffer_width > 0 :
2782+ self .draft_tokens_cuda [
2783+ num_draft_tokens :num_draft_tokens +
2784+ previous_batch_draft_tokens ].copy_ (
2785+ next_draft_tokens_device [
2786+ previous_slots , :
2787+ runtime_draft_token_buffer_width ].flatten (),
2788+ non_blocking = True )
27652789 # prepare data for the preprocess inputs
2766- kv_len_offsets_device = new_tokens_lens_device - self .runtime_draft_len - 1
2790+ kv_len_offsets_device = (new_tokens_lens_device -
2791+ runtime_tokens_per_gen_step )
27672792 previous_pos_indices_host = torch .tensor (
27682793 previous_pos_indices ,
27692794 dtype = torch .int ,
@@ -2789,8 +2814,8 @@ def previous_seq_slots_device():
27892814 extend_dummy_requests )
27902815 self .previous_pos_id_offsets_cuda [
27912816 (num_extend_reqeust_wo_dummy - previous_batch_len ) *
2792- ( 1 + self . runtime_draft_len ) :num_extend_reqeust_wo_dummy *
2793- ( 1 + self . runtime_draft_len ) ].copy_ (
2817+ runtime_tokens_per_gen_step :num_extend_reqeust_wo_dummy *
2818+ runtime_tokens_per_gen_step ].copy_ (
27942819 new_tokens_lens_device [self .previous_pos_indices_cuda [
27952820 0 :previous_batch_tokens ]],
27962821 non_blocking = True )
@@ -3626,6 +3651,8 @@ def forward(self,
36263651 # Propagate runtime_draft_len (already set on self by py_executor)
36273652 # to spec_metadata so downstream code (eagle3, interface, trtllm) can read it.
36283653 spec_metadata .runtime_draft_len = self .runtime_draft_len
3654+ spec_metadata .runtime_tokens_per_gen_step = (
3655+ self .get_runtime_tokens_per_gen_step (self .runtime_draft_len ))
36293656
36303657 attn_metadata .update_spec_dec_param (
36313658 batch_size = scheduled_requests .batch_size ,
0 commit comments