@@ -498,7 +498,6 @@ def __init__(
498498 sparse_attention_config = self .sparse_attention_config )
499499
500500 if self .is_spec_decode :
501- self .spec_metadata = None
502501 update_spec_config_from_model_config (self .spec_config ,
503502 self .model .config )
504503 max_num_draft_tokens = self .max_draft_loop_tokens * self .batch_size
@@ -552,6 +551,7 @@ def __init__(
552551 # the model engine.
553552 self .attn_metadata = None
554553 self .encoder_attn_metadata = None
554+ self .spec_metadata = None
555555 self .iter_states = {}
556556 self ._cuda_graph_mem_pool = self ._torch_compile_backend ._graph_pool_handle if self ._torch_compile_enabled else None
557557
@@ -1350,33 +1350,70 @@ def _capture_generation_cuda_graphs(self,
13501350 else :
13511351 max_seq_len_list = [effective_max_seq_len ]
13521352
1353- for bs , draft_len in graphs_to_capture :
1354- if bs > self .batch_size :
1355- continue
1356-
1357- for max_seq_len in max_seq_len_list :
1358- warmup_request = self ._create_cuda_graph_warmup_request (
1359- resource_manager , bs , draft_len , max_seq_len )
1360- with self ._release_batch_context (warmup_request ,
1361- resource_manager ) as batch :
1362- if batch is None :
1363- # No KV cache space, cannot continue capturing graphs
1353+ def _run_capture_pass (force_non_greedy : bool , label : str ) -> None :
1354+ spec_metadata = self .spec_metadata
1355+ if force_non_greedy and spec_metadata is not None :
1356+ spec_metadata ._force_non_greedy_for_capture = True
1357+ # maybe_get_cuda_graph reads spec_metadata.is_all_greedy_sample
1358+ # to build the graph cache key BEFORE populate runs inside
1359+ # _prepare_inputs. Pre-flip it here so the very first capture
1360+ # in this pass uses the non-greedy key; populate's override
1361+ # below will keep it False on every subsequent iteration.
1362+ spec_metadata .is_all_greedy_sample = False
1363+ try :
1364+ for bs , draft_len in graphs_to_capture :
1365+ if bs > self .batch_size :
13641366 continue
1365- logger .info (
1366- f"Run generation-only CUDA graph warmup for batch size={ bs } , draft_len={ draft_len } , max_seq_len={ max_seq_len } "
1367- )
1368- self .enable_spec_decode = draft_len > 0 or self .is_draft_model or (
1369- self .spec_config is not None
1370- and self .spec_config .spec_dec_mode .use_one_engine ())
1371- self ._update_draft_inference_state_for_warmup (
1372- batch , draft_len > 0 , resource_manager )
1373- self .runtime_draft_len = draft_len
1374- self .forward (batch ,
1375- new_tensors_device = None ,
1376- resource_manager = resource_manager )
1377- torch .cuda .synchronize ()
1367+
1368+ for max_seq_len in max_seq_len_list :
1369+ warmup_request = self ._create_cuda_graph_warmup_request (
1370+ resource_manager , bs , draft_len , max_seq_len )
1371+ with self ._release_batch_context (
1372+ warmup_request , resource_manager ) as batch :
1373+ if batch is None :
1374+ # No KV cache space, cannot continue capturing graphs
1375+ continue
1376+ logger .info (
1377+ f"Run generation-only CUDA graph warmup ({ label } ) "
1378+ f"for batch size={ bs } , draft_len={ draft_len } , "
1379+ f"max_seq_len={ max_seq_len } " )
1380+ self .enable_spec_decode = draft_len > 0 or self .is_draft_model or (
1381+ self .spec_config is not None and
1382+ self .spec_config .spec_dec_mode .use_one_engine ())
1383+ self ._update_draft_inference_state_for_warmup (
1384+ batch , draft_len > 0 , resource_manager )
1385+ self .runtime_draft_len = draft_len
1386+ self .forward (batch ,
1387+ new_tensors_device = None ,
1388+ resource_manager = resource_manager )
1389+ torch .cuda .synchronize ()
1390+ finally :
1391+ if force_non_greedy and spec_metadata is not None :
1392+ spec_metadata ._force_non_greedy_for_capture = False
1393+
1394+ # Pass 1: greedy fast-path (dummy requests carry no sampling params,
1395+ # so is_all_greedy_sample is naturally True).
1396+ _run_capture_pass (force_non_greedy = False , label = "greedy" )
1397+ # Pass 2: advanced sampling variant. Required because on-the-fly capture
1398+ # is disabled outside warmup, so any inference batch that contains a
1399+ # non-greedy request would otherwise fall back to eager. Only meaningful
1400+ # for one-engine spec dec (where is_all_greedy_sample participates in
1401+ # the graph key); other paths default to True and would never key into
1402+ # this variant.
1403+ needs_non_greedy_capture = (
1404+ self .spec_config is not None
1405+ and self .spec_config .spec_dec_mode .use_one_engine ())
1406+ if needs_non_greedy_capture :
1407+ _run_capture_pass (force_non_greedy = True , label = "advanced sampling" )
13781408 # Set the value back to the original value after cuda graph warmups are complete
13791409 self .enable_spec_decode = self .is_spec_decode
1410+ # The advanced-sampling capture pass above leaves is_all_greedy_sample
1411+ # set to False on spec_metadata. Reset it to the default so the first
1412+ # real iteration's graph-key selection is not seeded with this
1413+ # capture-only value. (update_is_all_greedy_sample refreshes it every
1414+ # iteration; this is a defensive guard.)
1415+ if self .spec_metadata is not None :
1416+ self .spec_metadata .is_all_greedy_sample = True
13801417
13811418 def _capture_piecewise_cuda_graphs (self , resource_manager : ResourceManager ):
13821419 """Captures piecewise CUDA graphs for context/prefill steps via torch.compile."""
@@ -4887,6 +4924,17 @@ def forward(self,
48874924 self .runtime_draft_len ) as padded_requests :
48884925 self ._pad_batch_seed_mrope_delta_cache (padded_requests )
48894926
4927+ # Refresh is_all_greedy_sample for the *current* batch BEFORE the
4928+ # CUDA graph key is built below. The key includes this flag to pick
4929+ # the argmax vs advanced-sampling graph variant; populate (inside
4930+ # _prepare_inputs) runs later and fills the matching GPU buffers.
4931+ # Without this pre-scan the key would use the previous iteration's
4932+ # stale value and could replay the advanced graph against
4933+ # unpopulated (greedy) buffers, hanging the run (e.g. MTP nextn>=2).
4934+ if spec_metadata is not None :
4935+ spec_metadata .update_is_all_greedy_sample (
4936+ padded_requests .all_requests ())
4937+
48904938 maybe_attn_metadata , maybe_spec_metadata , key = self .cuda_graph_runner .maybe_get_cuda_graph (
48914939 padded_requests ,
48924940 enable_spec_decode = self .enable_spec_decode ,
0 commit comments