From 7353cbc81ed213935c72de63910c42c04c315235 Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Mon, 16 Mar 2026 23:36:16 +0000 Subject: [PATCH 01/17] xSupport MTP, MTP-Eagle, PARD. Signed-off-by: Zheyu Fu Clear naming. Signed-off-by: Zheyu Fu Fix CI Signed-off-by: Zheyu Fu Add draft_target support Signed-off-by: Zheyu Fu Add SA and SA hybrid support Signed-off-by: Zheyu Fu --- .../_torch/attention_backend/trtllm.py | 11 +- .../_torch/pyexecutor/cuda_graph_runner.py | 13 +- .../_torch/pyexecutor/model_engine.py | 116 +++++-- tensorrt_llm/_torch/pyexecutor/py_executor.py | 40 ++- .../_torch/speculative/draft_target.py | 37 +- tensorrt_llm/_torch/speculative/eagle3.py | 4 +- tensorrt_llm/_torch/speculative/interface.py | 12 +- tensorrt_llm/_torch/speculative/mtp.py | 127 ++++--- tensorrt_llm/_torch/speculative/pard.py | 44 +-- tensorrt_llm/_torch/speculative/sa_worker.py | 33 +- tensorrt_llm/llmapi/llm_args.py | 8 + .../defs/accuracy/references/gsm8k.yaml | 2 + .../defs/accuracy/test_llm_api_pytorch.py | 321 +++++++++++++++++- .../test_lists/qa/llm_function_core.txt | 16 + tests/unittest/_torch/speculative/test_mtp.py | 3 + 15 files changed, 636 insertions(+), 151 deletions(-) diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index ac8bf3c8444d..847101b74f8d 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -1657,15 +1657,14 @@ def update_spec_dec_param( # Dynamic draft length needs position offsets and packed mask to be shaped for each runtime draft length. # So we create cache for position offsets and packed mask for each draft length to avoid reallocation. assert max_draft_len == max_total_draft_tokens, "max_draft_len should be equal to max_total_draft_tokens for linear tree" - runtime_draft_len = (spec_metadata.runtime_draft_len - if spec_metadata is not None else - max_draft_len) + runtime_draft_token_buffer_width = ( + spec_metadata.runtime_tokens_per_gen_step - 1) self.generate_spec_decoding_generation_length( - runtime_draft_len=runtime_draft_len) + runtime_draft_len=runtime_draft_token_buffer_width) self.spec_decoding_position_offsets = generate_spec_decoding_position_offsets( - self.max_num_requests, runtime_draft_len) + self.max_num_requests, runtime_draft_token_buffer_width) self.spec_decoding_packed_mask = generate_spec_decoding_packed_mask( - self.max_num_requests, runtime_draft_len) + self.max_num_requests, runtime_draft_token_buffer_width) def generate_spec_decoding_generation_length(self, runtime_draft_len): self.spec_decoding_generation_lengths[:self.max_num_requests].fill_( diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py index 69d7a9af59a2..88eb81397bb4 100644 --- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py +++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py @@ -118,8 +118,10 @@ def __init__(self, config: CUDAGraphRunnerConfig): def _create_shared_static_tensors(self): """Allocates static tensors sized for the largest possible batch.""" - max_draft_len = self.config.original_max_total_draft_tokens if self.config.spec_config is not None else 0 - token_per_request = max_draft_len + 1 + runtime_draft_token_buffer_width = ( + self.config.original_max_total_draft_tokens + if self.config.spec_config is not None else 0) + token_per_request = runtime_draft_token_buffer_width + 1 max_total_tokens = (self.max_supported_batch_size * self.max_beam_width * token_per_request) max_total_tokens = min(max_total_tokens, self.config.max_num_tokens) @@ -443,6 +445,11 @@ def _get_padded_batch(self, batch: ScheduledRequests, if padding_size + batch.batch_size > self.config.batch_size: return 0 + runtime_tokens_per_gen_step = ( + self.spec_config.get_runtime_tokens_per_gen_step(runtime_draft_len) + if self.spec_config is not None else 1 + runtime_draft_len) + runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1 + # No padding if it would create too many concurrent requests. # This is not strictly required, but we should probably # respect the requirement just in case that changes in the future. @@ -460,7 +467,7 @@ def _get_padded_batch(self, batch: ScheduledRequests, dummy_request = kv_cache_manager.add_dummy_requests( [dummy_request_id], is_gen=True, - max_num_draft_tokens=runtime_draft_len, + max_num_draft_tokens=runtime_draft_token_buffer_width, use_mrope=self.config.use_mrope, max_beam_width=self.config.max_beam_width, draft_kv_cache_manager=draft_kv_cache_manager) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index a32734bd599e..e5dc6b3f9b77 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -182,7 +182,9 @@ def __init__( spec_config.tokens_per_gen_step - 1) if spec_config is not None else 0 # Saved before zeroing for draft models; used by update_spec_dec_param. - self._spec_dec_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0 + self._spec_dec_max_total_draft_tokens = ( + spec_config.max_total_draft_tokens + if spec_config is not None else 0) preserve_wrapped_eagle3_widths = (spec_config is not None and is_draft_model @@ -341,11 +343,14 @@ def __init__( self.llm_args.attn_backend, sparse_attn_config=self.sparse_attention_config) + self.get_runtime_tokens_per_gen_step = spec_config.get_runtime_tokens_per_gen_step if spec_config is not None else lambda runtime_draft_len: 1 + if self.is_spec_decode: self.spec_metadata = None update_spec_config_from_model_config(self.spec_config, self.model.config) - max_num_draft_tokens = self.original_max_total_draft_tokens * self.batch_size + max_num_draft_tokens = (self.original_max_total_draft_tokens * + self.batch_size) self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ), dtype=torch.int, device='cuda') @@ -915,12 +920,31 @@ def _get_graphs_to_capture( ): graphs = [(graph_bs, draft_len) for graph_bs, draft_len in self._dynamic_draft_len_mapping.items()] + # Workaround for dynamic draft length: + # capture the maximum speculative graph shape up front. Dynamic draft length + # breaks the previous assumption that attention workspace demand can be safely + # ordered by batch size alone; a later graph shape may require a larger shared + # graph workspace, and resizing that workspace can change its data_ptr and + # invalidate pointers captured by earlier graphs, causing illegal memory access + # on replay. + # + # This adds the overhead of one extra captured graph, and that graph is not + # expected to be used by the normal schedule-driven dynamic draft-length path. + # + # Follow-up first-principles fix: + # query or precompute the exact attention workspace requirement for all + # reachable graph shapes, pre-size the shared graph workspace once without + # capturing an extra graph, and avoid resizing it in graph mode afterward. + max_spec_graph = (max(cuda_graph_batch_sizes), + self.original_max_draft_len) + if max_spec_graph not in graphs: + graphs.append(max_spec_graph) logger.info(f"Dynamic draft length enabled for one-model path. " f"Capturing {len(graphs)} graphs: {graphs}") return graphs # Case 3: Target model (two-model) or one-model without dynamic draft - draft_lengths = [self.max_total_draft_tokens] + draft_lengths = [self.max_draft_len] should_capture_no_spec = ( self.max_total_draft_tokens > 0 and not self.spec_config.spec_dec_mode.use_one_engine() @@ -1245,12 +1269,15 @@ def _create_cuda_graph_warmup_request( result = ScheduledRequests() num_extra_decoding_steps = self._get_num_extra_decoding_steps() + runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step( + draft_len) + runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1 # Add (batch_size - 1) dummy requests with seq_len=1. requests = kv_cache_manager.add_dummy_requests( list(range(batch_size - 1)), is_gen=True, - max_num_draft_tokens=draft_len, + max_num_draft_tokens=runtime_draft_token_buffer_width, use_mrope=self.use_mrope, max_beam_width=self.max_beam_width, num_extra_decoding_steps=num_extra_decoding_steps, @@ -1267,26 +1294,29 @@ def _create_cuda_graph_warmup_request( available_tokens = kv_cache_manager.get_num_available_tokens( token_num_upper_bound=max_seq_len, batch_size=batch_size, - max_num_draft_tokens=draft_len) + max_num_draft_tokens=runtime_draft_token_buffer_width) # Also consider draft KV cache capacity when it exists if draft_kv_cache_manager is not None: draft_available_tokens = draft_kv_cache_manager.get_num_available_tokens( batch_size=batch_size, token_num_upper_bound=max_seq_len, - max_num_draft_tokens=draft_len) + max_num_draft_tokens=runtime_draft_token_buffer_width) available_tokens = min(available_tokens, draft_available_tokens) token_num = max( 1, min( - available_tokens, max_seq_len - 1 - - get_num_extra_kv_tokens(self.spec_config) - draft_len)) + available_tokens, + max_seq_len - 1 - get_num_extra_kv_tokens(self.spec_config) - + runtime_draft_token_buffer_width)) model_config = self.model.model_config.pretrained_config max_position_embeddings = getattr(model_config, 'max_position_embeddings', None) if max_position_embeddings is not None: - token_num = min(token_num, max_position_embeddings - draft_len) + token_num = min( + token_num, + max_position_embeddings - runtime_draft_token_buffer_width) assert token_num > num_extra_decoding_steps, ( "Cannot fuse drafting loop. Not enough KV cache space for all draft tokens." @@ -1297,7 +1327,7 @@ def _create_cuda_graph_warmup_request( request_ids=[batch_size - 1], token_nums=[token_num], is_gen=True, - max_num_draft_tokens=draft_len, + max_num_draft_tokens=runtime_draft_token_buffer_width, use_mrope=self.use_mrope, max_beam_width=self.max_beam_width, num_extra_decoding_steps=num_extra_decoding_steps, @@ -2024,8 +2054,10 @@ def _update_target_input_tensors( non_blocking=True) # Prepare draft tokens + num_draft_tokens_per_extend_request = num_tokens_per_extend_request - 1 self.draft_tokens_cuda[:previous_batch_draft_tokens].copy_( - next_draft_tokens_device[previous_slots, :].flatten(), + next_draft_tokens_device[ + previous_slots, :num_draft_tokens_per_extend_request].flatten(), non_blocking=True) # Compute kv_len_offsets and update offset tensors @@ -2061,8 +2093,10 @@ def _apply_incremental_update_target( # Pre-compute constants extend_requests = scheduled_requests.generation_requests num_extend_requests = len(extend_requests) - num_tokens_per_extend_request = self.runtime_draft_len + 1 spec_config = self.spec_config + num_tokens_per_extend_request = self.get_runtime_tokens_per_gen_step( + self.runtime_draft_len) + runtime_draft_token_buffer_width = num_tokens_per_extend_request - 1 prompt_lengths = torch.empty(num_extend_requests, dtype=torch.int, @@ -2124,7 +2158,8 @@ def _apply_incremental_update_target( prompt_lengths = prompt_lengths.tolist() num_cached_tokens_per_seq = num_cached_tokens_per_seq.tolist() - previous_batch_draft_tokens = num_extend_reqeust_wo_dummy * self.runtime_draft_len + previous_batch_draft_tokens = (num_extend_reqeust_wo_dummy * + runtime_draft_token_buffer_width) self._update_target_input_tensors( num_accepted_tokens_device=num_accepted_tokens_device, @@ -2407,6 +2442,9 @@ def _prepare_tp_inputs( # will contain previous batch indices of generation requests previous_batch_indices = [] previous_pos_indices = [] + runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step( + self.runtime_draft_len) + runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1 for request in extend_requests: request_ids.append(request.py_request_id) request_accepted_path[ @@ -2465,16 +2503,16 @@ def _prepare_tp_inputs( previous_batch_idx = request.py_batch_idx request.py_batch_idx = request.py_seq_slot - sequence_lengths.append(1 + self.runtime_draft_len) + sequence_lengths.append(runtime_tokens_per_gen_step) num_accepted_draft_tokens.append( request.py_num_accepted_draft_tokens) past_seen_token_num = request.max_beam_num_tokens - 1 - draft_lens.append(self.runtime_draft_len) + draft_lens.append(runtime_draft_token_buffer_width) gather_ids.extend( list( range(len(position_ids), - len(position_ids) + 1 + self.runtime_draft_len))) + len(position_ids) + runtime_tokens_per_gen_step))) # For the target model + tree decoding if not self.is_draft_model and not spec_config.is_linear_tree: assert spec_tree_manager is not None @@ -2487,19 +2525,19 @@ def _prepare_tp_inputs( position_ids.extend( list( range( - past_seen_token_num, past_seen_token_num + 1 + - self.runtime_draft_len))) + past_seen_token_num, past_seen_token_num + + runtime_tokens_per_gen_step))) # previous tensor previous_batch_indices.append(previous_batch_idx) previous_pos_indices.extend([previous_batch_idx] * - (1 + self.runtime_draft_len)) + runtime_tokens_per_gen_step) num_cached_tokens_per_seq.append(past_seen_token_num + - self.runtime_draft_len + 1) + runtime_tokens_per_gen_step) request.cached_tokens = num_cached_tokens_per_seq[-1] if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx( self.attn_backend) and spec_config.is_linear_tree: - prompt_lengths.append(1 + self.runtime_draft_len) + prompt_lengths.append(runtime_tokens_per_gen_step) else: prompt_lengths.append(request.py_prompt_len) @@ -2804,30 +2842,36 @@ def previous_seq_slots_device(): # Initialize these two values to zeros self.previous_pos_id_offsets_cuda *= 0 self.previous_kv_lens_offsets_cuda *= 0 + runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step( + self.runtime_draft_len) + runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1 if previous_batch_len > 0: previous_slots = previous_seq_slots_device() # previous input ids - previous_batch_tokens = previous_batch_len * ( - 1 + self.runtime_draft_len) + previous_batch_tokens = (previous_batch_len * + runtime_tokens_per_gen_step) new_tokens = new_tokens_device.transpose( 0, - 1)[previous_slots, :(1 + self.runtime_draft_len)].flatten() + 1)[previous_slots, :runtime_tokens_per_gen_step].flatten() self.input_ids_cuda[num_tokens:num_tokens + previous_batch_tokens].copy_( new_tokens, non_blocking=True) # previous draft tokens - previous_batch_draft_tokens = previous_batch_len * self.runtime_draft_len - if self.runtime_draft_len > 0: - self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens + - previous_batch_draft_tokens].copy_( - next_draft_tokens_device[ - previous_slots, :self. - runtime_draft_len].flatten(), - non_blocking=True) + previous_batch_draft_tokens = (previous_batch_len * + runtime_draft_token_buffer_width) + if runtime_draft_token_buffer_width > 0: + self.draft_tokens_cuda[ + num_draft_tokens:num_draft_tokens + + previous_batch_draft_tokens].copy_( + next_draft_tokens_device[ + previous_slots, : + runtime_draft_token_buffer_width].flatten(), + non_blocking=True) # prepare data for the preprocess inputs - kv_len_offsets_device = new_tokens_lens_device - self.runtime_draft_len - 1 + kv_len_offsets_device = (new_tokens_lens_device - + runtime_tokens_per_gen_step) previous_pos_indices_host = torch.tensor( previous_pos_indices, dtype=torch.int, @@ -2853,8 +2897,8 @@ def previous_seq_slots_device(): extend_dummy_requests) self.previous_pos_id_offsets_cuda[ (num_extend_reqeust_wo_dummy - previous_batch_len) * - (1 + self.runtime_draft_len):num_extend_reqeust_wo_dummy * - (1 + self.runtime_draft_len)].copy_( + runtime_tokens_per_gen_step:num_extend_reqeust_wo_dummy * + runtime_tokens_per_gen_step].copy_( new_tokens_lens_device[self.previous_pos_indices_cuda[ 0:previous_batch_tokens]], non_blocking=True) @@ -3749,6 +3793,8 @@ def forward(self, # Propagate runtime_draft_len (already set on self by py_executor) # to spec_metadata so downstream code (eagle3, interface, trtllm) can read it. spec_metadata.runtime_draft_len = self.runtime_draft_len + spec_metadata.runtime_tokens_per_gen_step = ( + self.get_runtime_tokens_per_gen_step(self.runtime_draft_len)) # PARD has 2K tokens per gen request, not K+1. Pass 2K-1 # so generation_lengths = 2K and the XQA kernel computes diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 5a99bd20189e..060eafd1e326 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1752,22 +1752,42 @@ def _handle_dynamic_draft_len(self, from tensorrt_llm._torch.speculative.utils import \ get_draft_len_for_batch_size + spec_dec_mode = self.model_engine.spec_config.spec_dec_mode + # 1. Resolve runtime draft length from schedule runtime_draft_len = get_draft_len_for_batch_size( self.model_engine.spec_config.draft_len_schedule, scheduled_batch.batch_size, self.model_engine.max_draft_len) - # 2. Pad or truncate draft tokens to the resolved length - PADDING_TOKEN = 0 + DRAFT_BUFFER_PAD = 0 # Buffer sentinel, not PARD mask_token_id. for request in scheduled_batch.generation_requests: - current_draft_len = len(request.py_draft_tokens) - if current_draft_len < runtime_draft_len: - padding_needed = runtime_draft_len - current_draft_len - request.py_draft_tokens.extend([PADDING_TOKEN] * - padding_needed) - elif current_draft_len > runtime_draft_len: - request.py_draft_tokens = request.py_draft_tokens[: - runtime_draft_len] + current_num_draft_tokens = len(request.py_draft_tokens) + if spec_dec_mode.is_pard(): + # special case as PARD carries 2K-1 draft tokens per request + runtime_draft_token_buffer_width = ( + self.model_engine.spec_config. + get_runtime_tokens_per_gen_step(runtime_draft_len) - 1) + current_runtime_draft_len = ( + current_num_draft_tokens + + 1) // 2 if current_num_draft_tokens > 0 else 0 + real_draft_tokens = request.py_draft_tokens[:min( + current_runtime_draft_len, runtime_draft_len)] + real_draft_tokens.extend( + [DRAFT_BUFFER_PAD] * + (runtime_draft_len - len(real_draft_tokens))) + request.py_draft_tokens = real_draft_tokens + [ + DRAFT_BUFFER_PAD + ] * (runtime_draft_token_buffer_width - + len(real_draft_tokens)) + else: + if current_num_draft_tokens < runtime_draft_len: + padding_needed = (runtime_draft_len - + current_num_draft_tokens) + request.py_draft_tokens.extend([DRAFT_BUFFER_PAD] * + padding_needed) + elif current_num_draft_tokens > runtime_draft_len: + request.py_draft_tokens = request.py_draft_tokens[: + runtime_draft_len] self.model_engine.runtime_draft_len = runtime_draft_len else: diff --git a/tensorrt_llm/_torch/speculative/draft_target.py b/tensorrt_llm/_torch/speculative/draft_target.py index c026b6d5b290..fe6dc7d1e4c7 100644 --- a/tensorrt_llm/_torch/speculative/draft_target.py +++ b/tensorrt_llm/_torch/speculative/draft_target.py @@ -70,7 +70,7 @@ def prepare(self): num_seqs, dtype=torch.int, device="cpu", pin_memory=prefer_pinned() ) self.batch_indices_cuda[:num_seqs].copy_(batch_indices, non_blocking=True) - self.num_tokens -= self.num_generations * self.max_draft_len + self.num_tokens -= self.num_generations * self.runtime_draft_len self.is_spec_dec_tree = False self.is_spec_dec_dynamic_tree = False @@ -131,10 +131,11 @@ def _update_kv_after_first_draft_step( num_accepted_tokens: torch.Tensor, num_contexts: int, batch_size: int, + runtime_draft_len: int, ): if hasattr(attn_metadata, "kv_lens_cuda"): attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= ( - self.max_draft_len - num_accepted_tokens[num_contexts:batch_size] + runtime_draft_len - num_accepted_tokens[num_contexts:batch_size] ) attn_metadata.kv_lens_cuda[:num_contexts] += 1 @@ -175,6 +176,18 @@ def forward( batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts num_gens = batch_size - num_contexts + runtime_draft_len = spec_metadata.runtime_draft_len + + if runtime_draft_len == 0: + return self.skip_drafting( + input_ids, + position_ids, + hidden_states, + logits, + attn_metadata, + spec_metadata, + draft_model, + ) raw_logits = logits @@ -204,10 +217,10 @@ def forward( draft_kv_cache_manager = self.get_draft_kv_cache_manager(resource_manager) with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager): - for i in range(self.max_draft_len): + for i in range(runtime_draft_len): if i == 0: start_ids_gen = ( - spec_metadata.batch_indices_cuda[:num_gens] * (self.max_draft_len + 1) + spec_metadata.batch_indices_cuda[:num_gens] * (runtime_draft_len + 1) ).long() gather_ids_gen = ( start_ids_gen @@ -260,7 +273,11 @@ def forward( attn_metadata.host_request_types[: attn_metadata.num_contexts].fill_(1) attn_metadata.num_contexts = 0 self._update_kv_after_first_draft_step( - attn_metadata, num_accepted_tokens, num_contexts, batch_size + attn_metadata, + num_accepted_tokens, + num_contexts, + batch_size, + runtime_draft_len, ) else: self._update_kv_for_chained_draft_step(attn_metadata, batch_size) @@ -306,13 +323,14 @@ def sample_and_accept_draft_tokens( batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts num_gens = batch_size - num_contexts + runtime_draft_len = spec_metadata.runtime_draft_len if spec_metadata.draft_tokens is None: draft_tokens = torch.zeros( - (num_gens, self.max_draft_len), dtype=torch.int, device=logits.device + (num_gens, runtime_draft_len), dtype=torch.int, device=logits.device ) else: - draft_tokens = spec_metadata.draft_tokens.reshape(num_gens, self.max_draft_len) + draft_tokens = spec_metadata.draft_tokens.reshape(num_gens, runtime_draft_len) return self._sample_and_accept_draft_tokens_base( logits, draft_tokens, num_contexts, batch_size, spec_metadata @@ -337,6 +355,7 @@ def prepare_1st_drafter_inputs( num_contexts = attn_metadata.num_contexts batch_size = attn_metadata.num_seqs num_gens = batch_size - num_contexts + runtime_draft_len = spec_metadata.runtime_draft_len if num_contexts > 0: input_ids_ctx = self._prepare_context_input_ids( @@ -350,7 +369,9 @@ def prepare_1st_drafter_inputs( input_ids_ctx = torch.empty(0, dtype=torch.int32, device="cuda") if num_gens > 0: - input_ids_gen = accepted_tokens[num_contexts:, :].flatten().to(torch.int32) + input_ids_gen = ( + accepted_tokens[num_contexts:, : runtime_draft_len + 1].flatten().to(torch.int32) + ) else: input_ids_gen = torch.empty(0, dtype=torch.int32, device="cuda") diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 0211f3e7c989..2fe02dca7eb2 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -364,7 +364,7 @@ def prepare(self): if sa_manager is not None: gen_request_ids = self.request_ids[num_seqs - self.num_generations:] if gen_request_ids: - sa_manager.prepare(gen_request_ids, self.max_draft_len) + sa_manager.prepare(gen_request_ids, self.runtime_draft_len) def maybe_capture_hidden_states( self, @@ -467,7 +467,7 @@ def forward(self, num_accepted_tokens=num_accepted_tokens, num_gens=num_gens, num_contexts=num_contexts, - max_draft_len=self.max_draft_len, + max_draft_len=runtime_draft_len, ) # Save the old attn_metadata and spec_metadata diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 081030f482cb..959dd9ec6711 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -225,8 +225,8 @@ def support_capturable_guided_decoder(self): ) or self.is_external_drafter() or self.is_sa() def support_dynamic_draft_len(self): - # TODO: expand to all one-model algorithms - return self.is_eagle3_one_model() + return self.is_mtp_one_model() or self.is_eagle3_one_model( + ) or self.is_pard() or self.is_draft_target_one_model() or self.is_sa() def has_draft_model(self): return self.is_eagle3() or self.is_draft_target() or self.is_mtp_eagle() @@ -365,6 +365,9 @@ class SpecMetadata: # draft_len_schedule. Otherwise it equals max_draft_len (the static max). # Always set by model_engine.forward() before any downstream code reads it. runtime_draft_len: int = 0 + # Total runtime tokens per generation request for the current iteration, + # Normally, it equals 1 + runtime_draft_len. But for PARD, it equals 2 * runtime_draft_len. + runtime_tokens_per_gen_step: int = 1 # For non-greedy sampling on 1-model. allow_advanced_sampling: bool = False @@ -658,9 +661,8 @@ def _sample_and_accept_draft_tokens_base( num_accepted_tokens: [batch_size] - Number of accepted tokens per request """ # Derive draft length from the actual draft_tokens shape rather than - # spec_metadata.runtime_draft_len, because they can differ: PARD sets - # runtime_draft_len = 2K-1 for input sizing but only passes K draft - # tokens for acceptance; + # spec_metadata.runtime_draft_len, because callers may slice a wider + # runtime token layout down to the K draft tokens used for acceptance. runtime_draft_len = draft_tokens.shape[-1] num_gens = batch_size - num_contexts diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 8baf2da76615..36348ad3ebf7 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -227,7 +227,7 @@ def prepare(self): num_contexts = num_seqs - self.num_generations gen_request_ids = self.request_ids[num_contexts:] if gen_request_ids: - sa_manager.prepare(gen_request_ids, self.max_draft_len) + sa_manager.prepare(gen_request_ids, self.runtime_draft_len) class MTPSampler(SpecSamplerBase): @@ -341,7 +341,7 @@ def forward( - hidden states: H_E, H_F, H_G, H_H (H_H is invalid) Draft model: MTP1: - # For generation request, `mtp_num_modules` of tokens will be used as input. + # For generation request, `runtime_draft_len` tokens are used as input. - input tokens: FGX - input hidden states: H_E, H_F, H_G - KV cache: (BCDE) + FGX @@ -392,6 +392,13 @@ def forward( - new generated draft tokens: UVQ ''' + runtime_draft_len = spec_metadata.runtime_draft_len + # skip the draft forward if the runtime draft length is 0 + if runtime_draft_len == 0: + return self.skip_drafting(input_ids, position_ids, hidden_states, + logits, attn_metadata, spec_metadata, + draft_model) + batch_size = attn_metadata.num_seqs raw_logits = logits @@ -422,7 +429,8 @@ def forward( # update attn metadata if attn_metadata is not None: - self.change_attn_metadata(num_accepted_tokens, attn_metadata) + self.change_attn_metadata(num_accepted_tokens, attn_metadata, + spec_metadata) # Run MTP layers to predict draft tokens next_draft_tokens = [] @@ -433,7 +441,8 @@ def forward( resource_manager) with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager): - for i, mtp_layer in enumerate(draft_model.mtp_layers): + for i, mtp_layer in enumerate( + draft_model.mtp_layers[:runtime_draft_len]): if self.guided_decoder is not None: new_tokens = draft_inputs['input_ids'][last_tokens_idx] self.guided_decoder.add_draft_batch(new_tokens, @@ -506,17 +515,17 @@ def skip_forward( resource_manager=None, ): batch_size = attn_metadata.num_seqs - mtp_num_modules = self.spec_config.num_nextn_predict_layers - accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)), + runtime_draft_len = spec_metadata.runtime_draft_len + accepted_tokens = torch.empty((batch_size, (runtime_draft_len + 1)), dtype=torch.int, device=logits.device) num_accepted_tokens = torch.ones(batch_size, dtype=torch.int, device=logits.device) - next_draft_tokens = torch.empty((batch_size, mtp_num_modules), + next_draft_tokens = torch.empty((batch_size, runtime_draft_len), dtype=torch.int, device=logits.device) - next_new_tokens = torch.empty((batch_size, (mtp_num_modules + 1)), + next_new_tokens = torch.empty((batch_size, (runtime_draft_len + 1)), dtype=torch.int, device=logits.device) return { @@ -589,14 +598,15 @@ def unpack_sequence(packed_seq_cuda, seq_lens_cuda, seq_lens_cpu): seq_lens = attn_metadata.seq_lens_cuda seq_lens_cpu = attn_metadata.seq_lens hidden_size = hidden_states.shape[-1] - mtp_num_modules = self.spec_config.num_nextn_predict_layers + runtime_draft_len = spec_metadata.runtime_draft_len + max_draft_len = self.spec_config.num_nextn_predict_layers if self.is_thop: _, _ = torch.ops.trtllm.mtp_update_hidden_states_op( input_ids, seq_lens, hidden_states, spec_metadata.mtp_hidden_states_ptrs, spec_metadata.mtp_past_tokens_ptrs, num_accepted_tokens, - mtp_num_modules, batch_size, num_contexts, hidden_size) + runtime_draft_len, batch_size, num_contexts, hidden_size) else: assert len(spec_metadata.request_ids) == batch_size mtp_past_hidden_states_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool @@ -625,7 +635,7 @@ def unpack_sequence(packed_seq_cuda, seq_lens_cuda, seq_lens_cpu): dim=1) ctx_batch_idx = spec_metadata.batch_indices_cuda[:num_contexts] row_indices_ctx = ctx_batch_idx.unsqueeze(1).expand( - -1, mtp_num_modules) + -1, max_draft_len) col_indices_ctx = (seq_lens_ctx.unsqueeze(1) + spec_metadata.draft_token_indices_cuda) new_mtp_past_tokens.append(cat_tokens_ctx[row_indices_ctx, @@ -636,10 +646,10 @@ def unpack_sequence(packed_seq_cuda, seq_lens_cuda, seq_lens_cpu): # generation if num_gens > 0: unpacked_input_ids_gen = input_ids[num_ctx_tokens:].reshape( - num_gens, mtp_num_modules + 1).int() + num_gens, runtime_draft_len + 1).int() hidden_states_gen = hidden_states[num_ctx_tokens:, :] unpacked_hidden_states_gen = hidden_states_gen.reshape( - num_gens, mtp_num_modules + 1, hidden_size) + num_gens, runtime_draft_len + 1, hidden_size) cat_tokens_gen = torch.cat( (mtp_tokens[num_contexts:], unpacked_input_ids_gen), dim=1) cat_hidden_states_gen = torch.cat( @@ -648,10 +658,10 @@ def unpack_sequence(packed_seq_cuda, seq_lens_cuda, seq_lens_cpu): dim=1) gen_batch_idx = spec_metadata.batch_indices_cuda[:num_gens] row_indices_gen = gen_batch_idx.unsqueeze(1).expand( - -1, mtp_num_modules) + -1, max_draft_len) col_indices_gen = ( num_accepted_tokens[num_contexts:].unsqueeze(1) + - spec_metadata.draft_token_indices_cuda) + spec_metadata.draft_token_indices_cuda[:max_draft_len]) new_mtp_past_tokens.append(cat_tokens_gen[row_indices_gen, col_indices_gen]) new_mtp_past_hidden_states.append( @@ -666,17 +676,17 @@ def unpack_sequence(packed_seq_cuda, seq_lens_cuda, seq_lens_cpu): new_mtp_past_hidden_states) @torch.compile(options={"max-autotune": True}) - def topk_kernel(self, gen_logprobs, num_gens, mtp_num_modules, + def topk_kernel(self, gen_logprobs, num_gens, runtime_draft_len, spec_metadata): topk_value, topk_indices = torch.topk(gen_logprobs, k=self.spec_config.relaxed_topk, dim=-1) - topk_indices = topk_indices.reshape(num_gens, mtp_num_modules + 1, + topk_indices = topk_indices.reshape(num_gens, runtime_draft_len + 1, self.spec_config.relaxed_topk) - topk_value = topk_value.reshape(num_gens, mtp_num_modules + 1, + topk_value = topk_value.reshape(num_gens, runtime_draft_len + 1, self.spec_config.relaxed_topk) draft_tokens = spec_metadata.draft_tokens.reshape( - num_gens, mtp_num_modules) + num_gens, runtime_draft_len) return topk_value, topk_indices, draft_tokens @torch.compile(options={"max-autotune": True}) @@ -761,14 +771,14 @@ def sample_and_accept_draft_tokens( batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts num_gens = batch_size - num_contexts - mtp_num_modules = self.spec_config.num_nextn_predict_layers + runtime_draft_len = spec_metadata.runtime_draft_len if logits.dim() == 1: logits = logits.unsqueeze(0) # The return buffer if self.spec_config.use_relaxed_acceptance_for_thinking or not self.is_thop: - accepted_tokens = torch.ones((batch_size, (mtp_num_modules + 1)), + accepted_tokens = torch.ones((batch_size, (runtime_draft_len + 1)), dtype=torch.int, device=logits.device) num_accepted_tokens = torch.ones(batch_size, @@ -804,41 +814,40 @@ def sample_and_accept_draft_tokens( # generation gen_logprobs = self.process_generation_logits(logits, num_contexts) topk_value, topk_indices, draft_tokens = self.topk_kernel( - gen_logprobs, num_gens, mtp_num_modules, spec_metadata) + gen_logprobs, num_gens, runtime_draft_len, spec_metadata) accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_relaxed_acceptance_op( spec_metadata.slot_ids, topk_value, topk_indices, draft_tokens, mtp_relaxed_delta_pool, num_accepted_tokens, accepted_tokens, - mtp_num_modules, batch_size, num_contexts, + runtime_draft_len, batch_size, num_contexts, self.spec_config.relaxed_topk, self.spec_config.relaxed_delta, self.spec_config.begin_thinking_phase_token, self.spec_config.end_thinking_phase_token) # Apply force override for relaxed acceptance path num_accepted_tokens = self._apply_force_accepted_tokens( - num_accepted_tokens, num_contexts, - spec_metadata.runtime_draft_len) + num_accepted_tokens, num_contexts, runtime_draft_len) # Strict acceptance else: if self.is_thop: # Temporary buffer target_tokens_cache = torch.zeros(batch_size * - (mtp_num_modules + 1), + (runtime_draft_len + 1), dtype=torch.int, device=logits.device) accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_sampling_and_accepted_draft_tokens_op( logits, spec_metadata.draft_tokens, target_tokens_cache, - mtp_num_modules, batch_size, num_contexts, logits.shape[-1]) + runtime_draft_len, batch_size, num_contexts, + logits.shape[-1]) # Apply force override for THOP path num_accepted_tokens = self._apply_force_accepted_tokens( - num_accepted_tokens, num_contexts, - spec_metadata.runtime_draft_len) + num_accepted_tokens, num_contexts, runtime_draft_len) else: # Reshape draft tokens for base implementation draft_tokens = spec_metadata.draft_tokens.reshape( - num_gens, mtp_num_modules) + num_gens, runtime_draft_len) # Use base implementation for strict acceptance accepted_tokens, num_accepted_tokens = self._sample_and_accept_draft_tokens_base( @@ -855,16 +864,17 @@ def sample_and_accept_draft_tokens( num_accepted_tokens=num_accepted_tokens, num_gens=num_gens, num_contexts=num_contexts, - max_draft_len=mtp_num_modules, + max_draft_len=runtime_draft_len, ) return accepted_tokens, num_accepted_tokens def change_attn_metadata(self, num_accepted_tokens: torch.Tensor, - attn_metadata: AttentionMetadata): + attn_metadata: AttentionMetadata, + spec_metadata: MTPSpecMetadata): self._prepare_attn_metadata_for_spec_dec(attn_metadata) batch_size = attn_metadata.num_seqs - mtp_num_modules = self.spec_config.num_nextn_predict_layers + runtime_draft_len = spec_metadata.runtime_draft_len num_contexts = attn_metadata.num_contexts attn_metadata._seq_lens[num_contexts:batch_size] -= 1 @@ -876,7 +886,7 @@ def change_attn_metadata(self, num_accepted_tokens: torch.Tensor, # buffer once the graph has been captured also - this will invalidate # the graph and force an expensive recapture. attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= ( - mtp_num_modules + 1 - + runtime_draft_len + 1 - num_accepted_tokens[num_contexts:batch_size]) attn_metadata.on_update_kv_lens() @@ -884,7 +894,7 @@ def change_attn_metadata(self, num_accepted_tokens: torch.Tensor, for i in range(num_contexts, batch_size): # used for vanilla MLA, list on cpu attn_metadata.kv_cache_params.num_cached_tokens_per_seq[ - i] -= mtp_num_modules + 1 - num_accepted_tokens[i].item() + i] -= runtime_draft_len + 1 - num_accepted_tokens[i].item() def prepare_drafter_inputs( self, @@ -902,8 +912,8 @@ def prepare_drafter_inputs( Args: input_ids: torch.IntTensor [num_tokens] - The input ids of all requests. Flattened. - num_tokens = sum(all prompts) + num_generation * (mtp_num_modules + 1) + The input ids of all requests. Flatten. + num_tokens = sum(all prompts) + num_generation * (runtime_draft_len + 1) position_ids: torch.IntTensor [1][num_tokens] @@ -930,8 +940,8 @@ def prepare_drafter_inputs( Returns: draft_inputs input_ids: torch.Tensor [num_tokens] - The new input ids of all requests. Flattened. - num_tokens = sum(all prompts) + num_generation * (mtp_num_modules) + The new input ids of all requests. Flatten. + num_tokens = sum(all prompts) + num_generation * (runtime_draft_len) position_ids: torch.Tensor [1, num_tokens] @@ -955,7 +965,7 @@ def prepare_drafter_inputs( num_gens = batch_size - num_contexts mtp_past_hidden_states_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool mtp_past_tokens_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool - mtp_num_modules = self.spec_config.num_nextn_predict_layers + runtime_draft_len = spec_metadata.runtime_draft_len if self.is_thop: # Temporary buffer @@ -977,7 +987,7 @@ def prepare_drafter_inputs( spec_metadata.mtp_hidden_states_ptrs, spec_metadata.mtp_past_tokens_ptrs, hidden_states, accepted_tokens, num_accepted_tokens, return_input_ids, - return_hidden_states, mtp_num_modules, batch_size, + return_hidden_states, runtime_draft_len, batch_size, num_contexts, hidden_size) else: @@ -1002,11 +1012,18 @@ def prepare_drafter_inputs( accepted_tokens_gen = accepted_tokens[num_contexts:, :] input_ids_gen = accepted_tokens_gen[gen_batch_idx, gen_token_idx].unsqueeze(1) - input_ids_gen = torch.concat( - [mtp_past_tokens_pool[slot_ids][:, 1:], input_ids_gen], - dim=1) + + if runtime_draft_len > 1: + history_tokens = mtp_past_tokens_pool[slot_ids][:, -( + runtime_draft_len - 1):] + else: + history_tokens = torch.empty((num_gens, 0), + dtype=torch.int, + device=input_ids.device) + input_ids_gen = torch.concat([history_tokens, input_ids_gen], + dim=1) hidden_states_gen = mtp_past_hidden_states_pool[ - slot_ids].flatten(0, 1) + slot_ids][:, -runtime_draft_len:, :].flatten(0, 1) return_input_ids_list.append(input_ids_gen.flatten(0, 1)) return_hidden_states_list.append(hidden_states_gen) # Concatenate into continuous buffers @@ -1020,9 +1037,9 @@ def prepare_drafter_inputs( position_ids_list.append(position_ids[:num_ctx_tokens]) if num_gens > 0: position_ids_gen = position_ids[num_ctx_tokens:].reshape( - num_gens, mtp_num_modules + 1)[:, -mtp_num_modules:] + num_gens, runtime_draft_len + 1)[:, -runtime_draft_len:] position_ids_gen = position_ids_gen - ( - 1 + mtp_num_modules - + 1 + runtime_draft_len - num_accepted_tokens[num_contexts:].unsqueeze(1)) position_ids_list.append(position_ids_gen.flatten()) return_position_ids = torch.concat(position_ids_list, dim=-1) @@ -1150,6 +1167,12 @@ def forward( draft_model, resource_manager=None, ): + runtime_draft_len = spec_metadata.runtime_draft_len + # skip the draft forward if the runtime draft length is 0 + if runtime_draft_len == 0: + return self.skip_drafting(input_ids, position_ids, hidden_states, + logits, attn_metadata, spec_metadata, + draft_model) batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts @@ -1191,7 +1214,7 @@ def forward( # Predict draft tokens next_draft_tokens = [] with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager): - for i in range(self.mtp_num_modules): + for i in range(runtime_draft_len): if i == 0: hidden_states = draft_model.mtp_layers[0]( embed_tokens=draft_model.embed_tokens, @@ -1200,7 +1223,7 @@ def forward( start_ids_gen = ( spec_metadata.batch_indices_cuda[:num_gens] * - (self.mtp_num_modules + 1)).long() + (runtime_draft_len + 1)).long() gather_ids_gen = (start_ids_gen + num_accepted_tokens[num_contexts:] - 1 + attn_metadata.num_ctx_tokens) @@ -1281,7 +1304,7 @@ def forward( # update kv_lens_cuda if hasattr(attn_metadata, 'kv_lens_cuda'): attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= ( - self.mtp_num_modules - + runtime_draft_len - num_accepted_tokens[num_contexts:]) attn_metadata.kv_lens_cuda[:num_contexts] += 1 # update metadata for flash mla @@ -1364,6 +1387,7 @@ def prepare_drafter_inputs( spec_metadata: MTPSpecMetadata, ): num_contexts = attn_metadata.num_contexts + runtime_draft_len = spec_metadata.runtime_draft_len # context input_ids_ctx = self._prepare_context_input_ids( @@ -1371,7 +1395,8 @@ def prepare_drafter_inputs( accepted_tokens, num_contexts) # generation - input_ids_gen = accepted_tokens[num_contexts:, :].flatten() + input_ids_gen = accepted_tokens[num_contexts:, :runtime_draft_len + + 1].flatten() # get draft inputs input_ids = torch.concat([input_ids_ctx, input_ids_gen], dim=0) diff --git a/tensorrt_llm/_torch/speculative/pard.py b/tensorrt_llm/_torch/speculative/pard.py index 628de48022d7..34a4509314cd 100644 --- a/tensorrt_llm/_torch/speculative/pard.py +++ b/tensorrt_llm/_torch/speculative/pard.py @@ -48,7 +48,7 @@ def prepare(self): if sa_manager is not None: gen_request_ids = self.request_ids[num_seqs - self.num_generations :] if gen_request_ids: - sa_manager.prepare(gen_request_ids, self.max_draft_len) + sa_manager.prepare(gen_request_ids, self.runtime_draft_len) def _get_sa_manager(self): """Get SA manager from spec_resource_manager. @@ -99,15 +99,6 @@ def __init__( def max_draft_len(self) -> int: return self.spec_config.max_draft_len - @property - def _draft_tokens_per_req(self) -> int: - """Total tokens per gen request in the draft forward. - - Uses 2K to fit all accepted tokens (up to K+1) plus K-1 mask tokens, - ensuring K unique predictions regardless of how many tokens were accepted. - """ - return 2 * self.max_draft_len - def _prepare_attn_metadata_for_pard(self, attn_metadata, spec_metadata): """ Save attn_metadata fields that PARD modifies during forward. @@ -190,13 +181,25 @@ def forward( num_gens = batch_size - num_contexts raw_logits = logits - K = self.max_draft_len + K = spec_metadata.runtime_draft_len + + if K == 0: + return self.skip_drafting( + input_ids, + position_ids, + hidden_states, + logits, + attn_metadata, + spec_metadata, + draft_model, + ) self._execute_guided_decoder_if_present(logits) # draft_tokens buffer has (2K-1) entries per gen request; extract the K real drafts if num_gens > 0: - draft_tokens = spec_metadata.draft_tokens.reshape(num_gens, 2 * K - 1)[:, :K] + draft_tokens = spec_metadata.draft_tokens[: num_gens * (2 * K - 1)] + draft_tokens = draft_tokens.reshape(num_gens, 2 * K - 1)[:, :K] else: draft_tokens = spec_metadata.draft_tokens.reshape(0, K) @@ -262,14 +265,13 @@ def forward( gen_start_idx = attn_metadata.num_ctx_tokens request_bases = ( - torch.arange(num_gens, dtype=torch.long, device="cuda") - * self._draft_tokens_per_req + torch.arange(num_gens, dtype=torch.long, device="cuda") * (2 * K) + gen_start_idx ) gen_num_accepted = num_accepted_tokens[num_contexts:batch_size].long() base_offsets = gen_num_accepted - 1 # M = bonus position - offsets = torch.arange(self.max_draft_len, dtype=torch.long, device="cuda") + offsets = torch.arange(K, dtype=torch.long, device="cuda") gen_gather_ids = ( request_bases.unsqueeze(1) + base_offsets.unsqueeze(1) + offsets.unsqueeze(0) @@ -281,7 +283,7 @@ def forward( ) vocab_size = gen_logits.shape[-1] - gen_logits = gen_logits.reshape(num_gens, self.max_draft_len, vocab_size) + gen_logits = gen_logits.reshape(num_gens, K, vocab_size) # Use torch.argmax directly to avoid cute_argmax stride issues d2t = getattr(draft_model.model, "d2t", None) @@ -384,6 +386,8 @@ def prepare_1st_drafter_inputs( num_contexts = attn_metadata.num_contexts batch_size = attn_metadata.num_seqs num_gens = batch_size - num_contexts + runtime_draft_len = spec_metadata.runtime_draft_len + total_tokens_per_req = 2 * runtime_draft_len if ( hasattr(self.spec_config, "mask_token_id") @@ -412,8 +416,6 @@ def prepare_1st_drafter_inputs( gen_num_accepted = num_accepted_tokens[num_contexts : num_contexts + num_gens] gen_accepted_tokens = accepted_tokens[num_contexts : num_contexts + num_gens, :] - total_tokens_per_req = self._draft_tokens_per_req # 2K - # Start with all mask tokens request_ids_2d = torch.full( (num_gens, total_tokens_per_req), @@ -452,9 +454,9 @@ def prepare_1st_drafter_inputs( - total_tokens_per_req ) else: - gen_pos_starts = position_ids[ - attn_metadata.num_ctx_tokens :: self._draft_tokens_per_req - ][:num_gens] + gen_pos_starts = position_ids[attn_metadata.num_ctx_tokens :: total_tokens_per_req][ + :num_gens + ] offsets = torch.arange(total_tokens_per_req, dtype=torch.int32, device="cuda") position_ids_gen = (gen_pos_starts.unsqueeze(1) + offsets.unsqueeze(0)).flatten() diff --git a/tensorrt_llm/_torch/speculative/sa_worker.py b/tensorrt_llm/_torch/speculative/sa_worker.py index f014423686f5..90cd617cef2a 100644 --- a/tensorrt_llm/_torch/speculative/sa_worker.py +++ b/tensorrt_llm/_torch/speculative/sa_worker.py @@ -82,7 +82,7 @@ def prepare(self) -> None: self.batch_indices_cuda[:num_seqs].copy_(batch_indices, non_blocking=True) if self.sa_manager is not None: - self.sa_manager.prepare(self.request_ids, self.max_draft_len) + self.sa_manager.prepare(self.request_ids, self.runtime_draft_len) else: raise ValueError("SA manager is not set") @@ -160,6 +160,18 @@ def forward( batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts raw_logits = logits + runtime_draft_len = spec_metadata.runtime_draft_len + + if runtime_draft_len == 0: + return self.skip_drafting( + input_ids, + position_ids, + hidden_states, + logits, + attn_metadata, + spec_metadata, + draft_model, + ) self._execute_guided_decoder_if_present(logits) @@ -206,6 +218,7 @@ def _sample_and_accept_draft_tokens( batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts num_gens = batch_size - num_contexts + runtime_draft_len = spec_metadata.runtime_draft_len # Get draft tokens from spec_metadata (set during prepare) draft_tokens = spec_metadata.draft_tokens @@ -219,15 +232,15 @@ def _sample_and_accept_draft_tokens( if draft_tokens is None or draft_tokens.numel() == 0: use_zeros = True elif draft_tokens.dim() == 1: - # 1D tensor - try to reshape to [num_gens, max_draft_len] - expected_size = num_gens * self.max_draft_len + # 1D tensor - try to reshape to [num_gens, runtime_draft_len] + expected_size = num_gens * runtime_draft_len if draft_tokens.numel() == expected_size and num_gens > 0: - draft_tokens = draft_tokens.reshape(num_gens, self.max_draft_len) + draft_tokens = draft_tokens.reshape(num_gens, runtime_draft_len) else: use_zeros = True elif draft_tokens.dim() == 2: # 2D tensor - check shape - if draft_tokens.shape[-1] != self.max_draft_len: + if draft_tokens.shape[-1] != runtime_draft_len: use_zeros = True else: # Slice to get only generation requests' draft tokens @@ -238,7 +251,9 @@ def _sample_and_accept_draft_tokens( if use_zeros: # No valid draft tokens - create zeros for generation requests draft_tokens = torch.zeros( - (num_gens, self.max_draft_len), dtype=torch.int32, device=logits.device + (num_gens, runtime_draft_len), + dtype=torch.int32, + device=logits.device, ) # Use base implementation for sampling and acceptance @@ -275,7 +290,7 @@ def _generate_draft_tokens( """ sa_manager = spec_metadata.sa_manager request_ids = spec_metadata.request_ids - max_draft_len = self._max_draft_len + runtime_draft_len = spec_metadata.runtime_draft_len if sa_manager is None or request_ids is None: # No SA manager available, throw error @@ -286,7 +301,7 @@ def _generate_draft_tokens( request_ids, accepted_tokens, num_accepted_tokens, - max_draft_len, + runtime_draft_len, max_ngram_size=self._max_matching_ngram_size, ) else: @@ -294,7 +309,7 @@ def _generate_draft_tokens( request_ids, accepted_tokens, num_accepted_tokens, - max_draft_len, + runtime_draft_len, max_ngram_size=self._max_matching_ngram_size, ) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index f6e7639d9644..4629dab1e441 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -949,6 +949,10 @@ def tokens_per_gen_step(self) -> int: """Total tokens per gen request in one spec dec iteration (including golden token).""" return 1 + self.max_total_draft_tokens + def get_runtime_tokens_per_gen_step(self, runtime_draft_len: int) -> int: + """Total tokens per gen request for the current runtime draft length.""" + return 1 + runtime_draft_len + def num_capture_layers(self) -> int: return 0 @@ -1618,6 +1622,10 @@ def tokens_per_gen_step(self) -> int: """PARD needs 2K tokens per gen request: K+1 accepted + K-1 masks.""" return 2 * self.max_draft_len + def get_runtime_tokens_per_gen_step(self, runtime_draft_len: int) -> int: + """PARD needs 2K runtime tokens per gen request for logical draft length K.""" + return 1 if runtime_draft_len == 0 else 2 * runtime_draft_len + def supports_backend(self, backend: str) -> bool: return backend == "pytorch" diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml index cd66176e332f..5a91f2648515 100644 --- a/tests/integration/defs/accuracy/references/gsm8k.yaml +++ b/tests/integration/defs/accuracy/references/gsm8k.yaml @@ -60,6 +60,8 @@ meta-llama/Llama-4-Scout-17B-16E-Instruct: accuracy: 89.45 deepseek-ai/DeepSeek-V3-Lite: - accuracy: 64.74 + - spec_dec_algo: Draft_Target + accuracy: 64.026 - quant_algo: NVFP4 accuracy: 63.71 - quant_algo: NVFP4 diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 7aa99300e0f7..6a9b91709674 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -59,7 +59,7 @@ def patched_start_mpi_pool(self): Eagle3DecodingConfig, KvCacheConfig, MoeConfig, MTPDecodingConfig, NGramDecodingConfig, PARDDecodingConfig, RocketSparseAttentionConfig, SADecodingConfig, SamplingParams, SchedulerConfig, - SkipSoftmaxAttentionConfig, SAEnhancerConfig, TorchCompileConfig) + SkipSoftmaxAttentionConfig, SAEnhancerConfig, TorchCompileConfig, DraftTargetDecodingConfig) # isort: on from tensorrt_llm.quantization import QuantAlgo @@ -396,6 +396,51 @@ def test_eagle3_sa_global_pool(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm, extra_acc_spec="use_sa_spec") + @skip_pre_hopper + @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ + (False, True), + (True, False), + ]) + def test_eagle3_sa_dynamic_draft_len(self, enable_max_concurrency, + enable_draft_len_schedule): + max_concurrency = 100 if enable_max_concurrency else None + draft_len_schedule = { + 50: 4, + 200: 3, + 350: 2 + } if enable_draft_len_schedule else None + max_draft_len = 4 + cuda_graph_config = (CudaGraphConfig(max_batch_size=500) if + draft_len_schedule or max_concurrency is not None + else CudaGraphConfig()) + pytorch_config = dict( + max_batch_size=500, + disable_overlap_scheduler=False, + cuda_graph_config=cuda_graph_config, + ) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8) + + eagle_model_dir = f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B" + target_model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct" + + spec_config = Eagle3DecodingConfig( + max_draft_len=max_draft_len, + speculative_model=eagle_model_dir, + eagle3_one_model=True, + use_sa_spec=True, + max_concurrency=max_concurrency, + draft_len_schedule=draft_len_schedule, + ) + + with LLM(model=target_model_dir, + **pytorch_config, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=False, + max_num_tokens=8192, + speculative_config=spec_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm, extra_acc_spec="use_sa_spec") + @skip_pre_hopper @parametrize_with_ids("overlap_scheduler", [True, False]) def test_pard(self, overlap_scheduler): @@ -479,6 +524,83 @@ def test_pard_sa_global_pool(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm, extra_acc_spec="use_sa_spec") + @pytest.mark.skip_less_device_memory(60000) + @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ + (False, True), + (True, False), + ]) + def test_pard_dynamic_draft_len(self, enable_max_concurrency, + enable_draft_len_schedule): + max_concurrency = 100 if enable_max_concurrency else None + draft_len_schedule = { + 50: 4, + 200: 3, + 350: 2 + } if enable_draft_len_schedule else None + max_draft_len = 4 + cuda_graph_config = (CudaGraphConfig(max_batch_size=500) if + draft_len_schedule or max_concurrency is not None + else CudaGraphConfig()) + pytorch_config = dict( + disable_overlap_scheduler=False, + cuda_graph_config=cuda_graph_config, + ) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) + pard_model_dir = f"{llm_models_root()}/PARD-Llama-3.2-1B" + pard_config = PARDDecodingConfig( + max_draft_len=max_draft_len, + speculative_model=pard_model_dir, + max_concurrency=max_concurrency, + draft_len_schedule=draft_len_schedule, + ) + with LLM(self.MODEL_PATH, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=False, + max_num_tokens=8192, + **pytorch_config, + speculative_config=pard_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + @pytest.mark.skip_less_device_memory(60000) + @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ + (False, True), + (True, False), + ]) + def test_pard_sa_dynamic_draft_len(self, enable_max_concurrency, + enable_draft_len_schedule): + max_concurrency = 100 if enable_max_concurrency else None + draft_len_schedule = { + 50: 4, + 200: 3, + 350: 2 + } if enable_draft_len_schedule else None + max_draft_len = 4 + cuda_graph_config = (CudaGraphConfig(max_batch_size=500) if + draft_len_schedule or max_concurrency is not None + else CudaGraphConfig()) + pytorch_config = dict( + disable_overlap_scheduler=False, + cuda_graph_config=cuda_graph_config, + ) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) + pard_model_dir = f"{llm_models_root()}/PARD-Llama-3.2-1B" + pard_config = PARDDecodingConfig( + max_draft_len=max_draft_len, + speculative_model=pard_model_dir, + use_sa_spec=True, + max_concurrency=max_concurrency, + draft_len_schedule=draft_len_schedule, + ) + with LLM(self.MODEL_PATH, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=False, + max_num_tokens=8192, + **pytorch_config, + speculative_config=pard_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm, extra_acc_spec="use_sa_spec") + @skip_pre_hopper def test_ngram(self): max_bs = 16 @@ -536,6 +658,46 @@ def test_suffix_automaton(self, enable_global_pool): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + @skip_pre_hopper + @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ + (False, True), + (True, False), + ]) + def test_suffix_automaton_dynamic_draft_len(self, enable_max_concurrency, + enable_draft_len_schedule): + max_concurrency = 100 if enable_max_concurrency else None + draft_len_schedule = { + 50: 4, + 200: 3, + 350: 2 + } if enable_draft_len_schedule else None + max_draft_len = 4 + cuda_graph_config = CudaGraphConfig(max_batch_size=500) + + pytorch_config = dict( + disable_overlap_scheduler=True, + cuda_graph_config=cuda_graph_config, + ) + + kv_cache_config = KvCacheConfig(enable_block_reuse=False, + free_gpu_memory_fraction=0.8) + spec_config = SADecodingConfig( + max_draft_len=max_draft_len, + max_matching_ngram_size=-1, + max_concurrency=max_concurrency, + draft_len_schedule=draft_len_schedule, + ) + + with LLM(model=self.MODEL_PATH, + **pytorch_config, + kv_cache_config=kv_cache_config, + speculative_config=spec_config, + enable_chunked_prefill=False, + max_num_tokens=8192, + max_batch_size=500) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + @skip_pre_blackwell @parametrize_with_ids("torch_compile", [False, True]) @parametrize_with_ids("attn_backend", ["TRTLLM"]) @@ -1767,6 +1929,163 @@ def test_bfloat16_mtp_sa_global_pool(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm, extra_acc_spec="use_sa_spec") + @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ + (False, True), + (True, False), + ]) + def test_mtp_dynamic_draft_len(self, enable_max_concurrency, + enable_draft_len_schedule): + max_concurrency = 100 if enable_max_concurrency else None + draft_len_schedule = { + 50: 4, + 200: 3, + 350: 2 + } if enable_draft_len_schedule else None + max_draft_len = 4 + cuda_graph_config = (CudaGraphConfig(max_batch_size=500) if + draft_len_schedule or max_concurrency is not None + else CudaGraphConfig()) + pytorch_config = dict( + disable_overlap_scheduler=False, + cuda_graph_config=cuda_graph_config, + ) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) + mtp_config = MTPDecodingConfig( + num_nextn_predict_layers=max_draft_len, + max_draft_len=max_draft_len, + max_concurrency=max_concurrency, + draft_len_schedule=draft_len_schedule, + ) + with LLM(self.MODEL_PATH, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=False, + max_num_tokens=8192, + **pytorch_config, + speculative_config=mtp_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + @pytest.mark.skip_less_device_memory(60000) + @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ + (False, True), + (True, False), + ]) + def test_mtp_sa_dynamic_draft_len(self, enable_max_concurrency, + enable_draft_len_schedule): + """Accuracy test for MTP + SA with dynamic draft length or max_concurrency.""" + max_concurrency = 100 if enable_max_concurrency else None + draft_len_schedule = { + 50: 4, + 200: 3, + 350: 2 + } if enable_draft_len_schedule else None + max_draft_len = 4 + cuda_graph_config = (CudaGraphConfig(max_batch_size=500) if + draft_len_schedule or max_concurrency is not None + else CudaGraphConfig()) + pytorch_config = dict( + disable_overlap_scheduler=False, + cuda_graph_config=cuda_graph_config, + ) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) + mtp_config = MTPDecodingConfig( + num_nextn_predict_layers=max_draft_len, + max_draft_len=max_draft_len, + use_sa_spec=True, + max_concurrency=max_concurrency, + draft_len_schedule=draft_len_schedule, + ) + with LLM(self.MODEL_PATH, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=False, + max_num_tokens=8192, + **pytorch_config, + speculative_config=mtp_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm, extra_acc_spec="use_sa_spec") + + @pytest.mark.skip_less_device_memory(60000) + @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ + (False, True), + (True, False), + ]) + def test_mtp_eagle_dynamic_draft_len(self, enable_max_concurrency, + enable_draft_len_schedule): + max_concurrency = 100 if enable_max_concurrency else None + draft_len_schedule = { + 50: 4, + 200: 3, + 350: 2 + } if enable_draft_len_schedule else None + max_draft_len = 4 + cuda_graph_config = (CudaGraphConfig(max_batch_size=500) if + draft_len_schedule or max_concurrency is not None + else CudaGraphConfig()) + pytorch_config = dict( + disable_overlap_scheduler=False, + cuda_graph_config=cuda_graph_config, + ) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) + # Force MTP-Eagle one-model path. + mtp_config = MTPDecodingConfig( + num_nextn_predict_layers=max_draft_len, + max_draft_len=max_draft_len, + use_mtp_vanilla=False, + mtp_eagle_one_model=True, + max_concurrency=max_concurrency, + draft_len_schedule=draft_len_schedule, + ) + with LLM(self.MODEL_PATH, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=False, + max_num_tokens=8192, + **pytorch_config, + speculative_config=mtp_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + @pytest.mark.skip_less_device_memory(60000) + @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ + (False, True), + (True, False), + ]) + def test_draft_target_dynamic_draft_len(self, enable_max_concurrency, + enable_draft_len_schedule): + max_concurrency = 100 if enable_max_concurrency else None + draft_len_schedule = { + 50: 4, + 200: 3, + 350: 2 + } if enable_draft_len_schedule else None + max_draft_len = 4 + cuda_graph_config = (CudaGraphConfig(max_batch_size=500) if + draft_len_schedule or max_concurrency is not None + else CudaGraphConfig()) + pytorch_config = dict( + disable_overlap_scheduler=True, + cuda_graph_config=cuda_graph_config, + ) + kv_cache_config = KvCacheConfig( + enable_block_reuse=False, + free_gpu_memory_fraction=0.6, + ) + + spec_config = DraftTargetDecodingConfig( + max_draft_len=max_draft_len, + speculative_model=self.MODEL_PATH, + max_concurrency=max_concurrency, + draft_len_schedule=draft_len_schedule, + ) + + with LLM(model=self.MODEL_PATH, + **pytorch_config, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=False, + max_num_tokens=8192, + speculative_config=spec_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + @pytest.mark.skip_less_device(4) @parametrize_with_ids("mtp_nextn", [0, 2]) def test_bfloat16_4gpus_kv_cache_aware_routing(self, mtp_nextn): diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index e626279f8897..0d15e7c284d8 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -10,6 +10,8 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_a accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=True-eagle3_one_model=True-overlap_scheduler=True] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_sa accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_sa_global_pool +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_sa_dynamic_draft_len[enable_max_concurrency=False-enable_draft_len_schedule=True] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_sa_dynamic_draft_len[enable_max_concurrency=True-enable_draft_len_schedule=False] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance] @@ -39,6 +41,12 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard[overlap_sch accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard[overlap_scheduler=False] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_sa accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_sa_global_pool +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_dynamic_draft_len[enable_max_concurrency=False-enable_draft_len_schedule=True] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_dynamic_draft_len[enable_max_concurrency=True-enable_draft_len_schedule=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_sa_dynamic_draft_len[enable_max_concurrency=False-enable_draft_len_schedule=True] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_sa_dynamic_draft_len[enable_max_concurrency=True-enable_draft_len_schedule=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_suffix_automaton_dynamic_draft_len[enable_max_concurrency=False-enable_draft_len_schedule=True] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_suffix_automaton_dynamic_draft_len[enable_max_concurrency=True-enable_draft_len_schedule=False] accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_auto_dtype accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_fp8_prequantized accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_auto_dtype @@ -101,6 +109,14 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_python_sched accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_python_scheduler[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-enable_chunked_prefill=True] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_mtp_sa_global_pool accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_2_model_mtp +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_mtp_sa_dynamic_draft_len[enable_max_concurrency=False-enable_draft_len_schedule=True] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_mtp_sa_dynamic_draft_len[enable_max_concurrency=True-enable_draft_len_schedule=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_mtp_dynamic_draft_len[enable_max_concurrency=False-enable_draft_len_schedule=True] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_mtp_dynamic_draft_len[enable_max_concurrency=True-enable_draft_len_schedule=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_mtp_eagle_dynamic_draft_len[enable_max_concurrency=False-enable_draft_len_schedule=True] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_mtp_eagle_dynamic_draft_len[enable_max_concurrency=True-enable_draft_len_schedule=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_draft_target_dynamic_draft_len[enable_max_concurrency=False-enable_draft_len_schedule=True] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_draft_target_dynamic_draft_len[enable_max_concurrency=True-enable_draft_len_schedule=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] diff --git a/tests/unittest/_torch/speculative/test_mtp.py b/tests/unittest/_torch/speculative/test_mtp.py index d3965f477c00..b15a832595b0 100644 --- a/tests/unittest/_torch/speculative/test_mtp.py +++ b/tests/unittest/_torch/speculative/test_mtp.py @@ -312,6 +312,7 @@ def test_sample_and_accept_draft_tokens(self, test_case_name, max_total_draft_tokens=mtp_num_modules, mtp_num_modules=mtp_num_modules) spec_metadata.draft_tokens = draft_tokens + spec_metadata.runtime_draft_len = mtp_num_modules # mtp worker mtpworker = MTPWorker(spec_config) @@ -901,6 +902,7 @@ def test_mtp_update_mtp_hidden_states( spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool = mtp_hidden_states_tensor_pool spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool = mtp_tokens_tensor_pool + spec_metadata.runtime_draft_len = num_nextn_predict_layers spec_metadata.prepare() mtpworker = MTPWorker(spec_config) @@ -1397,6 +1399,7 @@ def test_prepare_drafter_inputs( spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool = mtp_hidden_states_tensor_pool spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool = mtp_tokens_tensor_pool + spec_metadata.runtime_draft_len = num_nextn_predict_layers spec_metadata.prepare() mtpworker = MTPWorker(spec_config) From d05dc840d775eeb9267977ae67ac5320091476aa Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Sun, 5 Apr 2026 11:57:52 +0000 Subject: [PATCH 02/17] Add SM constraint Signed-off-by: Zheyu Fu --- tests/integration/defs/accuracy/test_llm_api_pytorch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 6a9b91709674..bd1c705282ae 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -524,7 +524,7 @@ def test_pard_sa_global_pool(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm, extra_acc_spec="use_sa_spec") - @pytest.mark.skip_less_device_memory(60000) + @skip_pre_hopper @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ (False, True), (True, False), @@ -562,7 +562,7 @@ def test_pard_dynamic_draft_len(self, enable_max_concurrency, task = GSM8K(self.MODEL_NAME) task.evaluate(llm) - @pytest.mark.skip_less_device_memory(60000) + @skip_pre_hopper @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ (False, True), (True, False), @@ -1929,6 +1929,7 @@ def test_bfloat16_mtp_sa_global_pool(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm, extra_acc_spec="use_sa_spec") + @pytest.mark.skip_less_device_memory(60000) @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ (False, True), (True, False), @@ -2044,7 +2045,6 @@ def test_mtp_eagle_dynamic_draft_len(self, enable_max_concurrency, task = GSM8K(self.MODEL_NAME) task.evaluate(llm) - @pytest.mark.skip_less_device_memory(60000) @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ (False, True), (True, False), From 7b9cf719903678c927cc0d79c894728d9ba574d0 Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Mon, 6 Apr 2026 17:53:02 +0000 Subject: [PATCH 03/17] Fix CI Signed-off-by: Zheyu Fu --- .../_torch/pyexecutor/model_engine.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index e5dc6b3f9b77..7063b180f2ba 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -368,12 +368,10 @@ def __init__( self.without_logits = self.spec_config.spec_dec_mode.without_logits( ) or self.model_is_wrapped self.max_total_draft_tokens = spec_config.tokens_per_gen_step - 1 - # PARD uses 2K tokens per gen request (K accepted + K masks), so - # its per-request draft buffer width is 2K-1 = max_total_draft_tokens. - if spec_config.spec_dec_mode.is_pard(): - self.max_draft_len = self.max_total_draft_tokens - else: - self.max_draft_len = spec_config.max_draft_len + # Keep max_draft_len in logical K units for every spec mode. + # PARD's wider per-request storage (2K-1) lives in + # max_total_draft_tokens / original_max_total_draft_tokens. + self.max_draft_len = spec_config.max_draft_len # Mutable per-iteration draft length (updated each iteration when # dynamic draft length is enabled; otherwise stays fixed). self.runtime_draft_len = self.max_draft_len @@ -996,11 +994,14 @@ def _capture_generation_cuda_graphs(self, sparse_config = self.sparse_attention_config if sparse_config is not None and sparse_config.needs_separate_short_long_cuda_graphs( ): - # For short sequences, use the (seq_len_threshold - max_draft_len - 1) as the maximum sequence length - # to make sure all of the past and current input tokens are within the sequence length threshold. + # For short sequences, subtract the maximum runtime tokens consumed + # by a generation step so all current-step tokens stay within the + # sequence length threshold. PARD uses 2K tokens here, not K+1. + max_runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step( + self.max_draft_len) # For long sequences, use the default maximum sequence length. - max_seq_len = sparse_config.seq_len_threshold - ( - self.max_draft_len + 1) + max_seq_len = (sparse_config.seq_len_threshold - + max_runtime_tokens_per_gen_step) if max_seq_len < effective_max_seq_len: max_seq_len_list = [effective_max_seq_len, max_seq_len] else: From a1e51c064614d5eb82c98c31ee9674d95c14c6fd Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Fri, 10 Apr 2026 22:59:28 +0000 Subject: [PATCH 04/17] resolve comments and CI Signed-off-by: Zheyu Fu --- tensorrt_llm/_torch/speculative/mtp.py | 2 +- .../defs/accuracy/references/gsm8k.yaml | 4 +- .../defs/accuracy/test_llm_api_pytorch.py | 224 +++++------------- .../test_lists/qa/llm_function_core.txt | 24 +- 4 files changed, 75 insertions(+), 179 deletions(-) diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 36348ad3ebf7..cb54381d2d24 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -606,7 +606,7 @@ def unpack_sequence(packed_seq_cuda, seq_lens_cuda, seq_lens_cpu): input_ids, seq_lens, hidden_states, spec_metadata.mtp_hidden_states_ptrs, spec_metadata.mtp_past_tokens_ptrs, num_accepted_tokens, - runtime_draft_len, batch_size, num_contexts, hidden_size) + max_draft_len, batch_size, num_contexts, hidden_size) else: assert len(spec_metadata.request_ids) == batch_size mtp_past_hidden_states_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml index 5a91f2648515..026878c94470 100644 --- a/tests/integration/defs/accuracy/references/gsm8k.yaml +++ b/tests/integration/defs/accuracy/references/gsm8k.yaml @@ -16,6 +16,8 @@ meta-llama/Llama-3.1-8B-Instruct: - spec_dec_algo: PARD extra_acc_spec: use_sa_spec accuracy: 74.20 + - spec_dec_algo: Draft_Target + accuracy: 74.20 - quant_algo: FP8 accuracy: 74.30 - quant_algo: FP8 @@ -60,8 +62,6 @@ meta-llama/Llama-4-Scout-17B-16E-Instruct: accuracy: 89.45 deepseek-ai/DeepSeek-V3-Lite: - accuracy: 64.74 - - spec_dec_algo: Draft_Target - accuracy: 64.026 - quant_algo: NVFP4 accuracy: 63.71 - quant_algo: NVFP4 diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index bd1c705282ae..84c4d9a59500 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -59,7 +59,8 @@ def patched_start_mpi_pool(self): Eagle3DecodingConfig, KvCacheConfig, MoeConfig, MTPDecodingConfig, NGramDecodingConfig, PARDDecodingConfig, RocketSparseAttentionConfig, SADecodingConfig, SamplingParams, SchedulerConfig, - SkipSoftmaxAttentionConfig, SAEnhancerConfig, TorchCompileConfig, DraftTargetDecodingConfig) + SkipSoftmaxAttentionConfig, SAEnhancerConfig, TorchCompileConfig, + DraftTargetDecodingConfig) # isort: on from tensorrt_llm.quantization import QuantAlgo @@ -397,26 +398,11 @@ def test_eagle3_sa_global_pool(self): task.evaluate(llm, extra_acc_spec="use_sa_spec") @skip_pre_hopper - @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ - (False, True), - (True, False), - ]) - def test_eagle3_sa_dynamic_draft_len(self, enable_max_concurrency, - enable_draft_len_schedule): - max_concurrency = 100 if enable_max_concurrency else None - draft_len_schedule = { - 50: 4, - 200: 3, - 350: 2 - } if enable_draft_len_schedule else None - max_draft_len = 4 - cuda_graph_config = (CudaGraphConfig(max_batch_size=500) if - draft_len_schedule or max_concurrency is not None - else CudaGraphConfig()) + def test_eagle3_sa_dynamic_draft_len(self): pytorch_config = dict( max_batch_size=500, disable_overlap_scheduler=False, - cuda_graph_config=cuda_graph_config, + cuda_graph_config=(CudaGraphConfig(max_batch_size=500)), ) kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8) @@ -424,12 +410,15 @@ def test_eagle3_sa_dynamic_draft_len(self, enable_max_concurrency, target_model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct" spec_config = Eagle3DecodingConfig( - max_draft_len=max_draft_len, + max_draft_len=4, speculative_model=eagle_model_dir, eagle3_one_model=True, - use_sa_spec=True, - max_concurrency=max_concurrency, - draft_len_schedule=draft_len_schedule, + sa_config=SAEnhancerConfig(enable_global_pool=True), + draft_len_schedule={ + 50: 4, + 200: 3, + 350: 2 + }, ) with LLM(model=target_model_dir, @@ -525,22 +514,10 @@ def test_pard_sa_global_pool(self): task.evaluate(llm, extra_acc_spec="use_sa_spec") @skip_pre_hopper - @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ - (False, True), - (True, False), - ]) - def test_pard_dynamic_draft_len(self, enable_max_concurrency, - enable_draft_len_schedule): - max_concurrency = 100 if enable_max_concurrency else None - draft_len_schedule = { - 50: 4, - 200: 3, - 350: 2 - } if enable_draft_len_schedule else None + def test_pard_dynamic_draft_len(self): + draft_len_schedule = {50: 4, 200: 3, 350: 2} max_draft_len = 4 - cuda_graph_config = (CudaGraphConfig(max_batch_size=500) if - draft_len_schedule or max_concurrency is not None - else CudaGraphConfig()) + cuda_graph_config = CudaGraphConfig(max_batch_size=500) pytorch_config = dict( disable_overlap_scheduler=False, cuda_graph_config=cuda_graph_config, @@ -550,7 +527,6 @@ def test_pard_dynamic_draft_len(self, enable_max_concurrency, pard_config = PARDDecodingConfig( max_draft_len=max_draft_len, speculative_model=pard_model_dir, - max_concurrency=max_concurrency, draft_len_schedule=draft_len_schedule, ) with LLM(self.MODEL_PATH, @@ -563,23 +539,12 @@ def test_pard_dynamic_draft_len(self, enable_max_concurrency, task.evaluate(llm) @skip_pre_hopper - @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ - (False, True), - (True, False), - ]) - def test_pard_sa_dynamic_draft_len(self, enable_max_concurrency, - enable_draft_len_schedule): - max_concurrency = 100 if enable_max_concurrency else None - draft_len_schedule = { - 50: 4, - 200: 3, - 350: 2 - } if enable_draft_len_schedule else None + def test_pard_sa_dynamic_draft_len(self): + draft_len_schedule = {50: 4, 200: 3, 350: 2} max_draft_len = 4 - cuda_graph_config = (CudaGraphConfig(max_batch_size=500) if - draft_len_schedule or max_concurrency is not None - else CudaGraphConfig()) + cuda_graph_config = CudaGraphConfig(max_batch_size=500) pytorch_config = dict( + max_batch_size=500, disable_overlap_scheduler=False, cuda_graph_config=cuda_graph_config, ) @@ -588,8 +553,7 @@ def test_pard_sa_dynamic_draft_len(self, enable_max_concurrency, pard_config = PARDDecodingConfig( max_draft_len=max_draft_len, speculative_model=pard_model_dir, - use_sa_spec=True, - max_concurrency=max_concurrency, + sa_config=SAEnhancerConfig(enable_global_pool=True), draft_len_schedule=draft_len_schedule, ) with LLM(self.MODEL_PATH, @@ -659,22 +623,13 @@ def test_suffix_automaton(self, enable_global_pool): task.evaluate(llm) @skip_pre_hopper - @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ - (False, True), - (True, False), - ]) - def test_suffix_automaton_dynamic_draft_len(self, enable_max_concurrency, - enable_draft_len_schedule): - max_concurrency = 100 if enable_max_concurrency else None - draft_len_schedule = { - 50: 4, - 200: 3, - 350: 2 - } if enable_draft_len_schedule else None + def test_suffix_automaton_dynamic_draft_len(self): + draft_len_schedule = {50: 4, 200: 3, 350: 2} max_draft_len = 4 cuda_graph_config = CudaGraphConfig(max_batch_size=500) pytorch_config = dict( + max_batch_size=500, disable_overlap_scheduler=True, cuda_graph_config=cuda_graph_config, ) @@ -684,7 +639,7 @@ def test_suffix_automaton_dynamic_draft_len(self, enable_max_concurrency, spec_config = SADecodingConfig( max_draft_len=max_draft_len, max_matching_ngram_size=-1, - max_concurrency=max_concurrency, + enable_global_pool=True, draft_len_schedule=draft_len_schedule, ) @@ -693,8 +648,36 @@ def test_suffix_automaton_dynamic_draft_len(self, enable_max_concurrency, kv_cache_config=kv_cache_config, speculative_config=spec_config, enable_chunked_prefill=False, + max_num_tokens=8192) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + @skip_pre_hopper + def test_draft_target_dynamic_draft_len(self): + draft_len_schedule = {50: 4, 200: 3, 350: 2} + max_draft_len = 4 + cuda_graph_config = CudaGraphConfig(max_batch_size=500) + pytorch_config = dict( + disable_overlap_scheduler=True, + cuda_graph_config=cuda_graph_config, + ) + kv_cache_config = KvCacheConfig( + enable_block_reuse=False, + free_gpu_memory_fraction=0.6, + ) + + spec_config = DraftTargetDecodingConfig( + max_draft_len=max_draft_len, + speculative_model=self.MODEL_PATH, + draft_len_schedule=draft_len_schedule, + ) + + with LLM(model=self.MODEL_PATH, + **pytorch_config, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=False, max_num_tokens=8192, - max_batch_size=500) as llm: + speculative_config=spec_config) as llm: task = GSM8K(self.MODEL_NAME) task.evaluate(llm) @@ -1930,22 +1913,10 @@ def test_bfloat16_mtp_sa_global_pool(self): task.evaluate(llm, extra_acc_spec="use_sa_spec") @pytest.mark.skip_less_device_memory(60000) - @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ - (False, True), - (True, False), - ]) - def test_mtp_dynamic_draft_len(self, enable_max_concurrency, - enable_draft_len_schedule): - max_concurrency = 100 if enable_max_concurrency else None - draft_len_schedule = { - 50: 4, - 200: 3, - 350: 2 - } if enable_draft_len_schedule else None + def test_mtp_dynamic_draft_len(self): + draft_len_schedule = {50: 4, 200: 3, 350: 2} max_draft_len = 4 - cuda_graph_config = (CudaGraphConfig(max_batch_size=500) if - draft_len_schedule or max_concurrency is not None - else CudaGraphConfig()) + cuda_graph_config = CudaGraphConfig(max_batch_size=500) pytorch_config = dict( disable_overlap_scheduler=False, cuda_graph_config=cuda_graph_config, @@ -1954,7 +1925,6 @@ def test_mtp_dynamic_draft_len(self, enable_max_concurrency, mtp_config = MTPDecodingConfig( num_nextn_predict_layers=max_draft_len, max_draft_len=max_draft_len, - max_concurrency=max_concurrency, draft_len_schedule=draft_len_schedule, ) with LLM(self.MODEL_PATH, @@ -1967,24 +1937,13 @@ def test_mtp_dynamic_draft_len(self, enable_max_concurrency, task.evaluate(llm) @pytest.mark.skip_less_device_memory(60000) - @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ - (False, True), - (True, False), - ]) - def test_mtp_sa_dynamic_draft_len(self, enable_max_concurrency, - enable_draft_len_schedule): - """Accuracy test for MTP + SA with dynamic draft length or max_concurrency.""" - max_concurrency = 100 if enable_max_concurrency else None - draft_len_schedule = { - 50: 4, - 200: 3, - 350: 2 - } if enable_draft_len_schedule else None + def test_mtp_sa_dynamic_draft_len(self): + """Accuracy test for MTP + SA with dynamic draft length.""" + draft_len_schedule = {50: 4, 200: 3, 350: 2} max_draft_len = 4 - cuda_graph_config = (CudaGraphConfig(max_batch_size=500) if - draft_len_schedule or max_concurrency is not None - else CudaGraphConfig()) + cuda_graph_config = CudaGraphConfig(max_batch_size=500) pytorch_config = dict( + max_batch_size=500, disable_overlap_scheduler=False, cuda_graph_config=cuda_graph_config, ) @@ -1992,8 +1951,7 @@ def test_mtp_sa_dynamic_draft_len(self, enable_max_concurrency, mtp_config = MTPDecodingConfig( num_nextn_predict_layers=max_draft_len, max_draft_len=max_draft_len, - use_sa_spec=True, - max_concurrency=max_concurrency, + sa_config=SAEnhancerConfig(enable_global_pool=True), draft_len_schedule=draft_len_schedule, ) with LLM(self.MODEL_PATH, @@ -2006,22 +1964,10 @@ def test_mtp_sa_dynamic_draft_len(self, enable_max_concurrency, task.evaluate(llm, extra_acc_spec="use_sa_spec") @pytest.mark.skip_less_device_memory(60000) - @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ - (False, True), - (True, False), - ]) - def test_mtp_eagle_dynamic_draft_len(self, enable_max_concurrency, - enable_draft_len_schedule): - max_concurrency = 100 if enable_max_concurrency else None - draft_len_schedule = { - 50: 4, - 200: 3, - 350: 2 - } if enable_draft_len_schedule else None + def test_mtp_eagle_dynamic_draft_len(self): + draft_len_schedule = {50: 4, 200: 3, 350: 2} max_draft_len = 4 - cuda_graph_config = (CudaGraphConfig(max_batch_size=500) if - draft_len_schedule or max_concurrency is not None - else CudaGraphConfig()) + cuda_graph_config = CudaGraphConfig(max_batch_size=500) pytorch_config = dict( disable_overlap_scheduler=False, cuda_graph_config=cuda_graph_config, @@ -2033,7 +1979,6 @@ def test_mtp_eagle_dynamic_draft_len(self, enable_max_concurrency, max_draft_len=max_draft_len, use_mtp_vanilla=False, mtp_eagle_one_model=True, - max_concurrency=max_concurrency, draft_len_schedule=draft_len_schedule, ) with LLM(self.MODEL_PATH, @@ -2045,47 +1990,6 @@ def test_mtp_eagle_dynamic_draft_len(self, enable_max_concurrency, task = GSM8K(self.MODEL_NAME) task.evaluate(llm) - @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ - (False, True), - (True, False), - ]) - def test_draft_target_dynamic_draft_len(self, enable_max_concurrency, - enable_draft_len_schedule): - max_concurrency = 100 if enable_max_concurrency else None - draft_len_schedule = { - 50: 4, - 200: 3, - 350: 2 - } if enable_draft_len_schedule else None - max_draft_len = 4 - cuda_graph_config = (CudaGraphConfig(max_batch_size=500) if - draft_len_schedule or max_concurrency is not None - else CudaGraphConfig()) - pytorch_config = dict( - disable_overlap_scheduler=True, - cuda_graph_config=cuda_graph_config, - ) - kv_cache_config = KvCacheConfig( - enable_block_reuse=False, - free_gpu_memory_fraction=0.6, - ) - - spec_config = DraftTargetDecodingConfig( - max_draft_len=max_draft_len, - speculative_model=self.MODEL_PATH, - max_concurrency=max_concurrency, - draft_len_schedule=draft_len_schedule, - ) - - with LLM(model=self.MODEL_PATH, - **pytorch_config, - kv_cache_config=kv_cache_config, - enable_chunked_prefill=False, - max_num_tokens=8192, - speculative_config=spec_config) as llm: - task = GSM8K(self.MODEL_NAME) - task.evaluate(llm) - @pytest.mark.skip_less_device(4) @parametrize_with_ids("mtp_nextn", [0, 2]) def test_bfloat16_4gpus_kv_cache_aware_routing(self, mtp_nextn): diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index 0d15e7c284d8..246a6efc98e1 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -10,8 +10,7 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_a accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=True-eagle3_one_model=True-overlap_scheduler=True] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_sa accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_sa_global_pool -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_sa_dynamic_draft_len[enable_max_concurrency=False-enable_draft_len_schedule=True] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_sa_dynamic_draft_len[enable_max_concurrency=True-enable_draft_len_schedule=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_sa_dynamic_draft_len accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance] @@ -41,12 +40,10 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard[overlap_sch accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard[overlap_scheduler=False] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_sa accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_sa_global_pool -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_dynamic_draft_len[enable_max_concurrency=False-enable_draft_len_schedule=True] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_dynamic_draft_len[enable_max_concurrency=True-enable_draft_len_schedule=False] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_sa_dynamic_draft_len[enable_max_concurrency=False-enable_draft_len_schedule=True] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_sa_dynamic_draft_len[enable_max_concurrency=True-enable_draft_len_schedule=False] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_suffix_automaton_dynamic_draft_len[enable_max_concurrency=False-enable_draft_len_schedule=True] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_suffix_automaton_dynamic_draft_len[enable_max_concurrency=True-enable_draft_len_schedule=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_dynamic_draft_len +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_sa_dynamic_draft_len +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_suffix_automaton_dynamic_draft_len +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_draft_target_dynamic_draft_len accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_auto_dtype accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_fp8_prequantized accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_auto_dtype @@ -109,14 +106,9 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_python_sched accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_python_scheduler[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-enable_chunked_prefill=True] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_mtp_sa_global_pool accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_2_model_mtp -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_mtp_sa_dynamic_draft_len[enable_max_concurrency=False-enable_draft_len_schedule=True] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_mtp_sa_dynamic_draft_len[enable_max_concurrency=True-enable_draft_len_schedule=False] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_mtp_dynamic_draft_len[enable_max_concurrency=False-enable_draft_len_schedule=True] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_mtp_dynamic_draft_len[enable_max_concurrency=True-enable_draft_len_schedule=False] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_mtp_eagle_dynamic_draft_len[enable_max_concurrency=False-enable_draft_len_schedule=True] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_mtp_eagle_dynamic_draft_len[enable_max_concurrency=True-enable_draft_len_schedule=False] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_draft_target_dynamic_draft_len[enable_max_concurrency=False-enable_draft_len_schedule=True] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_draft_target_dynamic_draft_len[enable_max_concurrency=True-enable_draft_len_schedule=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_mtp_sa_dynamic_draft_len +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_mtp_dynamic_draft_len +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_mtp_eagle_dynamic_draft_len accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] From dfced744d41d1334f7e654be4790384049a755a1 Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Fri, 17 Apr 2026 07:04:38 +0000 Subject: [PATCH 05/17] Fix CI Signed-off-by: Zheyu Fu --- tensorrt_llm/_torch/attention_backend/trtllm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 847101b74f8d..979933e3df6f 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -1658,7 +1658,8 @@ def update_spec_dec_param( # So we create cache for position offsets and packed mask for each draft length to avoid reallocation. assert max_draft_len == max_total_draft_tokens, "max_draft_len should be equal to max_total_draft_tokens for linear tree" runtime_draft_token_buffer_width = ( - spec_metadata.runtime_tokens_per_gen_step - 1) + spec_metadata.runtime_tokens_per_gen_step - + 1 if spec_metadata is not None else max_draft_len) self.generate_spec_decoding_generation_length( runtime_draft_len=runtime_draft_token_buffer_width) self.spec_decoding_position_offsets = generate_spec_decoding_position_offsets( From 3a3e369909e627357f4f76bccea9f215b1ab47f3 Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Fri, 17 Apr 2026 07:19:18 +0000 Subject: [PATCH 06/17] Add comment on runtime_draft_token_buffer_width Signed-off-by: Zheyu Fu --- tensorrt_llm/_torch/attention_backend/trtllm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 979933e3df6f..145807d6454b 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -1657,6 +1657,7 @@ def update_spec_dec_param( # Dynamic draft length needs position offsets and packed mask to be shaped for each runtime draft length. # So we create cache for position offsets and packed mask for each draft length to avoid reallocation. assert max_draft_len == max_total_draft_tokens, "max_draft_len should be equal to max_total_draft_tokens for linear tree" + # For algos other than PARD, this equals runtime_draft_len (K); for PARD it's 2K-1. runtime_draft_token_buffer_width = ( spec_metadata.runtime_tokens_per_gen_step - 1 if spec_metadata is not None else max_draft_len) From 3dab7bc6120566c66a7697ca2e72b936fd9accf9 Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Thu, 23 Apr 2026 19:45:51 +0000 Subject: [PATCH 07/17] Fix pre-commit Signed-off-by: Zheyu Fu --- tensorrt_llm/_torch/pyexecutor/model_engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 5ba985c67b15..a10b48197b55 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -2502,7 +2502,8 @@ def _prepare_tp_inputs( len(position_ids) + runtime_tokens_per_gen_step))) position_ids.extend( list( - range(past_seen_token_num, past_seen_token_num + runtime_tokens_per_gen_step))) + range(past_seen_token_num, past_seen_token_num + + runtime_tokens_per_gen_step))) # previous tensor previous_batch_indices.append(previous_batch_idx) previous_pos_indices.extend([previous_batch_idx] * From a5f0454fa1769c7a64d50eaf498bc1cf9252f826 Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Fri, 24 Apr 2026 22:16:37 +0000 Subject: [PATCH 08/17] [TRTLLM-11556][fix] Align runtime_draft_len selection across writers runtime_draft_len means logical K for linear-tree modes (incl. PARD) and total tree tokens for tree decoding. _prepare_tp_inputs enforces this selection, but py_executor._handle_dynamic_draft_length and _get_graphs_to_capture wrote the other value, causing: - PARD test_pard[True/False] warmup: runtime_draft_len=7 (buffer width 2K-1) instead of K=4. pard.py:202 sliced draft_tokens by 13 elements from a 7-wide buffer: * B200: shape '[1, 13]' is invalid for input of size 7 * H100: cudaErrorIllegalAddress (same OOB read) - EAGLE3 dynamic tree: CUDA graph captured with draft_len=6 (max_draft_len) while _prepare_tp_inputs overrides to 60 (max_total_draft_tokens) at forward time. The captured graph and runtime forward disagree on per-request layout width, producing a shape-inconsistent descriptor at o_proj linear and CUBLAS_STATUS_EXECUTION_FAILED during capture. Both writers now use the inline ternary already at _prepare_tp_inputs:2436-2437: max_draft_len if spec_config.is_linear_tree else max_total_draft_tokens Three write sites cross-reference the shared invariant to prevent future drift. Validated with test_pard[True]/[False] passing (60s/50s, 2 passed). EAGLE3 dynamic tree requires PR-branch-only C++ kernels; covered by CI. Signed-off-by: Zheyu Fu --- tensorrt_llm/_torch/pyexecutor/model_engine.py | 7 ++++++- tensorrt_llm/_torch/pyexecutor/py_executor.py | 9 ++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index a10b48197b55..782e3093aa74 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -941,7 +941,12 @@ def _get_graphs_to_capture( return graphs # Case 3: Target model (two-model) or one-model without dynamic draft - draft_lengths = [self.max_draft_len] + # Match the runtime_draft_len semantics enforced in _prepare_tp_inputs: + # logical K for linear-tree modes, total tree tokens for tree decoding. + draft_lengths = [ + self.max_draft_len + if self.spec_config.is_linear_tree else self.max_total_draft_tokens + ] should_capture_no_spec = ( self.max_total_draft_tokens > 0 and not self.spec_config.spec_dec_mode.use_one_engine() diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 5929cb79fef7..0bd0661f8199 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1792,7 +1792,14 @@ def _handle_dynamic_draft_len(self, self.model_engine.runtime_draft_len = runtime_draft_len else: - self.model_engine.runtime_draft_len = self.model_engine.max_total_draft_tokens + # Linear-tree modes (incl. PARD) use logical K; tree decoding + # (e.g. EAGLE3 dynamic tree) uses total tree tokens. Same + # selection as _prepare_tp_inputs and _get_graphs_to_capture. + spec_config = self.model_engine.spec_config + self.model_engine.runtime_draft_len = ( + self.model_engine.max_draft_len + if spec_config is not None and spec_config.is_linear_tree else + self.model_engine.max_total_draft_tokens) def _can_queue(self, scheduled_batch): From 60804cf13c2c01d5b7a54d07370b7bff3a968d89 Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Mon, 27 Apr 2026 02:12:53 +0000 Subject: [PATCH 09/17] [TRTLLM-11556][fix] Guard _get_graphs_to_capture against spec_config=None The previous fix introduced `self.spec_config.is_linear_tree` in _get_graphs_to_capture Case 3 without a None guard. Non-spec models also fall through to Case 3 with spec_config=None, so capture crashed with `AttributeError: 'NoneType' object has no attribute 'is_linear_tree'` and cascaded as `RuntimeError: Executor worker returned error` in subprocesses. CI build 35684 attributed ~1145 of 1235 fails (90%+) to this single regression across non-spec PyTorch tests on every GPU class. Fix: short-circuit to max_draft_len (= 0 for non-spec) when spec_config is None. Matches the original pre-regression behaviour of `[self.max_draft_len]` for non-spec; spec paths unchanged. Signed-off-by: Zheyu Fu --- tensorrt_llm/_torch/pyexecutor/model_engine.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 4731210c2dac..b27d8212d938 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -953,9 +953,11 @@ def _get_graphs_to_capture( # Case 3: Target model (two-model) or one-model without dynamic draft # Match the runtime_draft_len semantics enforced in _prepare_tp_inputs: # logical K for linear-tree modes, total tree tokens for tree decoding. + # spec_config is None for non-spec models — fall back to max_draft_len (= 0). draft_lengths = [ - self.max_draft_len - if self.spec_config.is_linear_tree else self.max_total_draft_tokens + self.max_draft_len if + (self.spec_config is None or self.spec_config.is_linear_tree) else + self.max_total_draft_tokens ] should_capture_no_spec = ( self.max_total_draft_tokens > 0 From 90c90526708d63bc5cc036b83bca88e7d8361714 Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Mon, 27 Apr 2026 21:39:32 +0000 Subject: [PATCH 10/17] Run pre-commit Signed-off-by: Zheyu Fu --- tests/integration/defs/accuracy/test_llm_api_pytorch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 2340072e933e..0ee6663ff128 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -564,6 +564,7 @@ def test_pard_sa_dynamic_draft_len(self): speculative_config=pard_config) as llm: task = GSM8K(self.MODEL_NAME) task.evaluate(llm, extra_acc_spec="use_sa_spec") + def test_dflash(self): pytorch_config = dict( max_batch_size=8, From 22a4718a2c3ffa3a2da5f8930cc554c06b519089 Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Mon, 18 May 2026 03:05:01 +0000 Subject: [PATCH 11/17] Add DFlash support Signed-off-by: Zheyu Fu --- .../_torch/pyexecutor/model_engine.py | 10 +++++-- tensorrt_llm/_torch/pyexecutor/py_executor.py | 4 +-- tensorrt_llm/_torch/speculative/dflash.py | 19 +++++++++--- tensorrt_llm/_torch/speculative/interface.py | 3 +- tensorrt_llm/llmapi/llm_args.py | 4 +++ .../defs/accuracy/test_llm_api_pytorch.py | 29 +++++++++++++++++++ 6 files changed, 60 insertions(+), 9 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index a5a7efde2cdf..c611a79a01e0 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -432,8 +432,14 @@ def __init__( else: self.max_draft_len = spec_config.max_draft_len # Mutable per-iteration draft length (updated each iteration when - # dynamic draft length is enabled; otherwise stays fixed). - self.runtime_draft_len = self.max_draft_len + # dynamic draft length is enabled; otherwise stays fixed). For + # parallel-draft modes (PARD/DFlash), self.max_draft_len is the + # 2K-1 buffer width; the worker forward expects logical K, so + # initialize from spec_config.max_draft_len here. + if spec_config.spec_dec_mode.is_parallel_draft(): + self.runtime_draft_len = spec_config.max_draft_len + else: + self.runtime_draft_len = self.max_draft_len else: self.without_logits = False diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index b735b415fbc2..7f34a94af081 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -2107,8 +2107,8 @@ def _handle_dynamic_draft_len(self, DRAFT_BUFFER_PAD = 0 # Buffer sentinel, not PARD mask_token_id. for request in scheduled_batch.generation_requests: current_num_draft_tokens = len(request.py_draft_tokens) - if spec_dec_mode.is_pard(): - # special case as PARD carries 2K-1 draft tokens per request + if spec_dec_mode.is_parallel_draft(): + # special case: PARD/DFlash carry 2K-1 draft tokens per request runtime_draft_token_buffer_width = ( self.model_engine.spec_config. get_runtime_tokens_per_gen_step(runtime_draft_len) - 1) diff --git a/tensorrt_llm/_torch/speculative/dflash.py b/tensorrt_llm/_torch/speculative/dflash.py index 510f88fb5dad..57761bd4833b 100644 --- a/tensorrt_llm/_torch/speculative/dflash.py +++ b/tensorrt_llm/_torch/speculative/dflash.py @@ -364,7 +364,18 @@ def forward( num_gens = batch_size - num_contexts raw_logits = logits - K = self.max_draft_len + K = spec_metadata.runtime_draft_len + + if K == 0: + return self.skip_drafting( + input_ids, + position_ids, + hidden_states, + logits, + attn_metadata, + spec_metadata, + draft_model, + ) # Lazy init buffers and attach worker reference for prepare() self._lazy_init_ctx_buffers(draft_model, spec_metadata) @@ -481,7 +492,7 @@ def forward( ) vocab_size = gen_logits.shape[-1] - gen_logits = gen_logits.reshape(num_gens, self.max_draft_len, vocab_size) + gen_logits = gen_logits.reshape(num_gens, K, vocab_size) d2t = getattr(draft_model.model, "d2t", None) gen_draft_tokens = torch.argmax(gen_logits, dim=-1, keepdim=False).long() @@ -595,8 +606,8 @@ def prepare_1st_drafter_inputs( gen_num_accepted = num_accepted_tokens[num_contexts : num_contexts + num_gens] gen_accepted_tokens = accepted_tokens[num_contexts : num_contexts + num_gens, :] - total_tokens_per_req = self._draft_tokens_per_req # 2K - K = self.max_draft_len + K = spec_metadata.runtime_draft_len + total_tokens_per_req = 2 * K # Get captured multi-layer hidden states from spec_metadata captured_hs = spec_metadata.get_hidden_states(total_target_tokens) diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index ed8d22a2ff8d..342ad728935c 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -296,7 +296,8 @@ def support_capturable_guided_decoder(self): def support_dynamic_draft_len(self): return self.is_mtp_one_model() or self.is_eagle3_one_model( - ) or self.is_pard() or self.is_draft_target_one_model() or self.is_sa() + ) or self.is_pard() or self.is_dflash( + ) or self.is_draft_target_one_model() or self.is_sa() def has_draft_model(self): return self.is_eagle3() or self.is_draft_target() or self.is_mtp_eagle() diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index f3b88f7807b9..1f3ba4bbe557 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1813,6 +1813,10 @@ def tokens_per_gen_step(self) -> int: """DFlash needs 2K tokens per gen request: K+1 accepted + K-1 masks.""" return 2 * self.max_draft_len + def get_runtime_tokens_per_gen_step(self, runtime_draft_len: int) -> int: + """DFlash needs 2K runtime tokens per gen request for logical draft length K.""" + return 1 if runtime_draft_len == 0 else 2 * runtime_draft_len + def supports_backend(self, backend: str) -> bool: return backend == "pytorch" diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index f03384add05a..5dc20d38fb0d 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -638,6 +638,35 @@ def test_dflash(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + def test_dflash_dynamic_draft_len(self): + # DFlash uses a Qwen3-style draft with q/k_norm and 8K-wide cross-attn + # context, so the per-layer rmsnorm row count scales as + # B * max_ctx * num_kv_heads. Very large batches (e.g. 500) push that + # past the flashinfer rmsnorm kernel's stable range; cap at 200. + draft_len_schedule = {50: 4, 100: 3, 150: 2} + max_draft_len = 4 + pytorch_config = dict( + max_batch_size=200, + disable_overlap_scheduler=False, + cuda_graph_config=CudaGraphConfig(max_batch_size=200, + enable_padding=True), + ) + kv_cache_config = KvCacheConfig(enable_block_reuse=False, + free_gpu_memory_fraction=0.6) + dflash_model_dir = f"{llm_models_root()}/LLaMA3.1-8B-Instruct-DFlash-UltraChat" + target_model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct" + spec_config = DFlashDecodingConfig( + max_draft_len=max_draft_len, + speculative_model=dflash_model_dir, + draft_len_schedule=draft_len_schedule, + ) + with LLM(model=target_model_dir, + **pytorch_config, + kv_cache_config=kv_cache_config, + speculative_config=spec_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + @skip_pre_hopper def test_ngram(self): max_bs = 16 From 9e212f4c79f5508b3288516e4f3c2b9f2e009738 Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Mon, 18 May 2026 03:16:22 +0000 Subject: [PATCH 12/17] minor changes/ Signed-off-by: Zheyu Fu --- tensorrt_llm/_torch/pyexecutor/model_engine.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index c611a79a01e0..a617e13d2a14 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -427,19 +427,13 @@ def __init__( self.max_total_draft_tokens = spec_config.tokens_per_gen_step - 1 # PARD/DFlash use 2K tokens per gen request (K accepted + K masks), so # their per-request draft buffer width is 2K-1 = max_total_draft_tokens. + # Runtime draft length is updated each iteration when dynamic draft length is enabled; otherwise stays fixed. if spec_config.spec_dec_mode.is_parallel_draft(): self.max_draft_len = self.max_total_draft_tokens + self.runtime_draft_len = self.max_total_draft_tokens else: self.max_draft_len = spec_config.max_draft_len - # Mutable per-iteration draft length (updated each iteration when - # dynamic draft length is enabled; otherwise stays fixed). For - # parallel-draft modes (PARD/DFlash), self.max_draft_len is the - # 2K-1 buffer width; the worker forward expects logical K, so - # initialize from spec_config.max_draft_len here. - if spec_config.spec_dec_mode.is_parallel_draft(): self.runtime_draft_len = spec_config.max_draft_len - else: - self.runtime_draft_len = self.max_draft_len else: self.without_logits = False From 183ced8876b14182873181dc2fa7b880161989d8 Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Tue, 19 May 2026 04:37:45 +0000 Subject: [PATCH 13/17] [TRTLLM-11556][fix] Restore logical K for parallel-draft runtime_draft_len MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The earlier "minor changes/" commit set self.max_draft_len and self.runtime_draft_len to max_total_draft_tokens (= 2K-1 for PARD) for parallel-draft modes, but every reader of these attributes — pard.py forward, py_executor dynamic-draft path, _get_graphs_to_capture, _create_cuda_graph_warmup_request, resource_manager — expects the logical K. The mismatch surfaces during warmup: pard.py:202: shape '[1, 13]' is invalid for input of size 7 (or as an async CUDA illegal memory access at modeling_speculative.py:1740 on H100). Reverting to spec_config.max_draft_len for both fields restores the contract; the per-request 2K-1 buffer width is already carried by max_total_draft_tokens and computed via get_runtime_tokens_per_gen_step where needed. Verified test_pard[True] and test_pard[False] pass locally (132s / 131s). Signed-off-by: Zheyu Fu --- tensorrt_llm/_torch/pyexecutor/model_engine.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 639bbf878c4d..ed17e6d527f0 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -425,19 +425,8 @@ def __init__( self.without_logits = self.spec_config.spec_dec_mode.without_logits( ) or self.model_is_wrapped self.max_total_draft_tokens = spec_config.tokens_per_gen_step - 1 - # Parallel-draft modes (PARD, DFlash) size their per-request draft - # buffer by tokens_per_gen_step - 1 so the engine reserves exactly - # one slot per draft token the target will verify. PARD still uses - # 2K tokens per gen req (K drafts + K mask fillers); DFlash was - # reduced to K+1 (K drafts + 1 bonus) - the spec config's - # tokens_per_gen_step carries the per-algorithm width. - # Runtime draft length is updated each iteration when dynamic draft length is enabled; otherwise stays fixed. - if spec_config.spec_dec_mode.is_parallel_draft(): - self.max_draft_len = self.max_total_draft_tokens - self.runtime_draft_len = self.max_total_draft_tokens - else: - self.max_draft_len = spec_config.max_draft_len - self.runtime_draft_len = spec_config.max_draft_len + self.max_draft_len = spec_config.max_draft_len + self.runtime_draft_len = spec_config.max_draft_len else: self.without_logits = False From 44d025d28ad36c6641c2a10b00170cd6da64db6d Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Tue, 19 May 2026 04:48:20 +0000 Subject: [PATCH 14/17] [None][test] Add test_dflash_dynamic_draft_len to QA llm_function_core Cover the dynamic-draft-length DFlash path in the QA CI lane, alongside the existing PARD/EAGLE3/SA dynamic-draft-length entries. Signed-off-by: Zheyu Fu --- tests/integration/test_lists/qa/llm_function_core.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index bd9dc2723c56..bf44c859892e 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -588,6 +588,7 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[use_temperature=False-attn_backend=TRTLLM] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[use_temperature=True-attn_backend=FLASHINFER] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[use_temperature=True-attn_backend=TRTLLM] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_dflash_dynamic_draft_len accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_draft_target_dynamic_draft_len accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_dummy_load_format accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=False-eagle3_one_model=False-overlap_scheduler=False] From f22d6fbfaabdbba5a1ae560f66db85eec4de7c4c Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Wed, 20 May 2026 19:42:28 +0000 Subject: [PATCH 15/17] fix precommit Signed-off-by: Zheyu Fu --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index f8bcde43c7b5..2a345b64072f 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -60,7 +60,7 @@ from .hang_detector import HangDetector from .kv_cache_transceiver import KvCacheTransceiver from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState, - LlmResponse) + LlmResponse, get_draft_token_length) from .mamba_cache_manager import MambaHybridCacheManager from .model_engine import ModelEngine from .perf_metrics_manager import PerfMetricsManager From 5a54cd5d71ed3b0ff126a5a95f0da0027e549bfe Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Wed, 20 May 2026 20:41:40 +0000 Subject: [PATCH 16/17] [None][chore] collapse 24 single-line docstrings in llm_args.py (D200) Signed-off-by: Zheyu Fu --- tensorrt_llm/llmapi/llm_args.py | 94 +++++++++------------------------ 1 file changed, 24 insertions(+), 70 deletions(-) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 7db9a2aad395..3d7838464ef9 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -107,9 +107,7 @@ def Field(default: Any = ..., class CudaGraphConfig(StrictBaseModel): - """ - Configuration for CUDA graphs. - """ + """Configuration for CUDA graphs.""" # List of batch sizes to create CUDA graphs for. batch_sizes: Optional[List[int]] = Field( default=None, @@ -253,9 +251,7 @@ class GuidedDecodingBackend(Enum): class BaseSparseAttentionConfig(StrictBaseModel): - """ - Configuration for sparse attention. - """ + """Configuration for sparse attention.""" algorithm: str seq_len_threshold: Optional[int] = Field( @@ -285,9 +281,7 @@ def needs_separate_short_long_cuda_graphs(self) -> bool: class RocketSparseAttentionConfig(BaseSparseAttentionConfig): - """ - Configuration for RocketKV sparse attention. - """ + """Configuration for RocketKV sparse attention.""" algorithm: Literal["rocket"] = "rocket" window_size: Optional[int] = Field( default=32, description="The window size for RocketKV.") @@ -312,9 +306,7 @@ def get_indices_block_size(self) -> int: class DeepSeekSparseAttentionConfig(BaseSparseAttentionConfig): - """ - Configuration for DeepSeek Sparse Attention. - """ + """Configuration for DeepSeek Sparse Attention.""" algorithm: Literal["dsa"] = "dsa" index_n_heads: Optional[int] = Field( default=None, description="The number of heads for the indexer.") @@ -405,9 +397,7 @@ def needs_separate_short_long_cuda_graphs(self) -> bool: class SkipSoftmaxAttentionConfig(BaseSparseAttentionConfig): - """ - Configuration for skip softmax attention. - """ + """Configuration for skip softmax attention.""" algorithm: Literal["skip_softmax"] = "skip_softmax" threshold_scale_factor: Optional[Union[float, Dict[str, float]]] = Field( default=None, @@ -563,9 +553,7 @@ def slot_end(self) -> int: def get_layer_initial_global_assignments( self, layer_idx: int) -> Optional[List[int]]: - """ - Retrieves the initial global assignments for a specific layer. - """ + """Retrieves the initial global assignments for a specific layer.""" if self.initial_global_assignments is None: return None @@ -589,9 +577,7 @@ def get_layer_initial_global_assignments( class MoeConfig(StrictBaseModel): - """ - Configuration for MoE. - """ + """Configuration for MoE.""" backend: Literal[ "AUTO", "CUTLASS", "CUTEDSL", "WIDEEP", "TRTLLM", "DEEPGEMM", "DENSEGEMM", "VANILLA", "TRITON"] = Field( @@ -634,9 +620,7 @@ class MoeConfig(StrictBaseModel): class Nvfp4GemmConfig(StrictBaseModel): - """ - Configuration for NVFP4 GEMM backend selection. - """ + """Configuration for NVFP4 GEMM backend selection.""" allowed_backends: List[Nvfp4Backend] = Field( default_factory=lambda: ['cutlass', 'cublaslt', 'cuda_core'], min_length=1, @@ -647,9 +631,7 @@ class Nvfp4GemmConfig(StrictBaseModel): class AttentionDpConfig(StrictBaseModel): - """ - Configuration for attention DP. - """ + """Configuration for attention DP.""" enable_balance: bool = Field(default=False, description="Whether to enable balance.") timeout_iters: int = Field( @@ -708,9 +690,7 @@ def validate_attention_dp_config(self) -> 'AttentionDpConfig': class CpConfig(StrictBaseModel): - """ - Configuration for context parallelism. - """ + """Configuration for context parallelism.""" # TODO: given that multiple fields here are only used with specific cp_types, consider # making this a Pydantic discriminated union. cp_type: CpType = Field(default=CpType.ULYSSES, @@ -816,9 +796,7 @@ def to_mapping(self) -> Mapping: class CalibConfig(StrictBaseModel): - """ - Calibration configuration. - """ + """Calibration configuration.""" device: Literal['cuda', 'cpu'] = Field(default='cuda', description="The device to run calibration.") @@ -1102,9 +1080,7 @@ def _resolve_preset(self) -> "KvCacheConnectorConfig": class LayerwiseBenchmarksConfig(StrictBaseModel): - """ - Configuration for layer-wise benchmarks calibration. - """ + """Configuration for layer-wise benchmarks calibration.""" calibration_mode: Literal["NONE", "MARK", "COLLECT"] = Field( default="NONE", description= @@ -1488,9 +1464,7 @@ def set_max_total_draft_tokens(self): class NGramDecodingConfig(DecodingBaseConfig): - """ - Configuration for NGram drafter speculative decoding. - """ + """Configuration for NGram drafter speculative decoding.""" decoding_type: Literal["NGram"] = "NGram" max_matching_ngram_size: PositiveInt = Field( default=2, @@ -2020,8 +1994,7 @@ class ExecutorMemoryType(StrEnum): @dataclass class _SleepConfigDefaultFactory: - """Picklable replacement for ``lambda: default_mode`` in SleepConfig's defaultdict. - """ + """Picklable replacement for ``lambda: default_mode`` in SleepConfig's defaultdict.""" default_mode: Any @@ -2030,8 +2003,7 @@ def __call__(self) -> Any: class SleepConfig(StrictBaseModel): - """Configuration for the LLM sleep/wakeup feature. - """ + """Configuration for the LLM sleep/wakeup feature.""" restore_modes: dict[ ExecutorMemoryType, Literal["NONE", "MEMSET", "CPU", "PINNED"] @@ -2275,9 +2247,7 @@ class PybindMirrorMeta(type(PybindMirror)): class PybindMirrorEnumMeta(EnumMeta, PybindMirrorMeta): - """ - Combined metaclass for Enum and PybindMirror. This is crucial. - """ + """Combined metaclass for Enum and PybindMirror. This is crucial.""" @PybindMirror.mirror_pybind_enum(_BatchingType) @@ -2381,9 +2351,7 @@ def _to_pybind(self): @PybindMirror.mirror_pybind_fields(_PeftCacheConfig) class PeftCacheConfig(StrictBaseModel, PybindMirror): - """ - Configuration for the PEFT cache. - """ + """Configuration for the PEFT cache.""" num_host_module_layer: int = Field( default=0, description= @@ -2451,9 +2419,7 @@ def _to_pybind(self): @PybindMirror.mirror_pybind_fields(_LookaheadDecodingConfig) class LookaheadDecodingConfig(DecodingBaseConfig, PybindMirror): - """ - Configuration for lookahead speculative decoding. - """ + """Configuration for lookahead speculative decoding.""" decoding_type: Literal["Lookahead"] = "Lookahead" max_window_size: PositiveInt = Field( @@ -2556,9 +2522,7 @@ class ReorderRequestPolicyConfig(StrictBaseModel): @PybindMirror.mirror_pybind_fields(_KvCacheConfig) class KvCacheConfig(StrictBaseModel, PybindMirror): - """ - Configuration for the KV cache. - """ + """Configuration for the KV cache.""" enable_block_reuse: bool = Field( default=True, description= @@ -2771,9 +2735,7 @@ def validate_max_util_for_resume(cls, v: float): @PybindMirror.mirror_pybind_fields(_ExtendedRuntimePerfKnobConfig) class ExtendedRuntimePerfKnobConfig(StrictBaseModel, PybindMirror): - """ - Configuration for extended runtime performance knobs. - """ + """Configuration for extended runtime performance knobs.""" multi_block_mode: bool = Field( default=True, description="Whether to use multi-block mode.") @@ -2802,9 +2764,7 @@ def _to_pybind(self): @PybindMirror.mirror_pybind_fields(_CacheTransceiverConfig) class CacheTransceiverConfig(StrictBaseModel, PybindMirror): - """ - Configuration for the cache transceiver. - """ + """Configuration for the cache transceiver.""" backend: Optional[Literal[ "DEFAULT", "UCX", "NIXL", "MOONCAKE", "MPI"]] = Field( @@ -2914,9 +2874,7 @@ class DwdpConfig(StrictBaseModel): class BaseLlmArgs(StrictBaseModel): - """ - Base class for both TorchLlmArgs and TrtLlmArgs. It contains all the arguments that are common to both. - """ + """Base class for both TorchLlmArgs and TrtLlmArgs. It contains all the arguments that are common to both.""" model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") # Explicit arguments @@ -3431,9 +3389,7 @@ class TrtLlmArgs(BaseLlmArgs): @model_validator(mode="after") def init_build_config(self): - """ - Creating a default BuildConfig if none is provided - """ + """Creating a default BuildConfig if none is provided.""" build_config = getattr(self, "build_config", None) if build_config is None: kwargs = {} @@ -3758,9 +3714,7 @@ class SamplerType(StrEnum): class TorchCompileConfig(StrictBaseModel): - """ - Configuration for torch.compile. - """ + """Configuration for torch.compile.""" enable_fullgraph: bool = Field( default=True, description="Enable full graph compilation in torch.compile.") From c3d32d304120d309f1d6fd7504688c040167c3a5 Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Sat, 23 May 2026 07:17:30 +0000 Subject: [PATCH 17/17] [TRTLLM-11556][test] Gate dynamic-draft spec tests to Blackwell+ The dynamic-draft-length speculative decoding tests (Eagle3/PARD/DFlash/SA /draft_target/MTP variants) were validated end-to-end on B200, where all 11 produce correct GSM8K accuracy after the runtime_draft_len fix in 183ced8876. On older arches the same paths trip an FMHA cubin coverage gap that falls back to unfused MHA with O(bs * max_seq^2) workspace, which is unrelated to this PR's logic but blocks the tests from running. Tighten the arch gate from skip_pre_hopper -> skip_pre_blackwell on the five existing dynamic-draft tests, and add the same decorator to the four that previously had no arch gate (dflash, mtp, mtp_sa, mtp_eagle). H100 CI support can be re-enabled later if the FMHA cubin coverage on sm_90 lands. Also two test-config fixes uncovered while validating on B200: * test_pard_sa_dynamic_draft_len: bump max_num_tokens 8192 -> 16384 so the SA-enhanced PARD path's larger per-step token budget doesn't trip total_num_tokens > max_num_tokens on long GSM8K prompts (matches the intermittent failure visible in CI history for the same shape). * test_dflash_dynamic_draft_len: set max_seq_len=8192 explicitly so the DFlash per-slot context K+V buffer is sized against an 8K ceiling rather than the target Llama-3.1's 131K max_position_embeddings, keeping the bs=200 reservation well under 180 GiB on B200/GB200. Signed-off-by: Zheyu Fu --- .../defs/accuracy/test_llm_api_pytorch.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 4b35a9ad2d11..9c505ef0f178 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -447,7 +447,7 @@ def test_eagle3_sa_global_pool(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm, extra_acc_spec="use_sa_spec") - @skip_pre_hopper + @skip_pre_blackwell def test_eagle3_sa_dynamic_draft_len(self): pytorch_config = dict( max_batch_size=500, @@ -563,7 +563,7 @@ def test_pard_sa_global_pool(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm, extra_acc_spec="use_sa_spec") - @skip_pre_hopper + @skip_pre_blackwell def test_pard_dynamic_draft_len(self): draft_len_schedule = {50: 4, 200: 3, 350: 2} max_draft_len = 4 @@ -588,7 +588,7 @@ def test_pard_dynamic_draft_len(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) - @skip_pre_hopper + @skip_pre_blackwell def test_pard_sa_dynamic_draft_len(self): draft_len_schedule = {50: 4, 200: 3, 350: 2} max_draft_len = 4 @@ -609,7 +609,7 @@ def test_pard_sa_dynamic_draft_len(self): with LLM(self.MODEL_PATH, kv_cache_config=kv_cache_config, enable_chunked_prefill=False, - max_num_tokens=8192, + max_num_tokens=16384, **pytorch_config, speculative_config=pard_config) as llm: task = GSM8K(self.MODEL_NAME) @@ -638,6 +638,7 @@ def test_dflash(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + @skip_pre_blackwell def test_dflash_dynamic_draft_len(self): # DFlash uses a Qwen3-style draft with q/k_norm and 8K-wide cross-attn # context, so the per-layer rmsnorm row count scales as @@ -661,6 +662,7 @@ def test_dflash_dynamic_draft_len(self): draft_len_schedule=draft_len_schedule, ) with LLM(model=target_model_dir, + max_seq_len=8192, **pytorch_config, kv_cache_config=kv_cache_config, speculative_config=spec_config) as llm: @@ -724,7 +726,7 @@ def test_suffix_automaton(self, enable_global_pool): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) - @skip_pre_hopper + @skip_pre_blackwell def test_suffix_automaton_dynamic_draft_len(self): draft_len_schedule = {50: 4, 200: 3, 350: 2} max_draft_len = 4 @@ -754,7 +756,7 @@ def test_suffix_automaton_dynamic_draft_len(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) - @skip_pre_hopper + @skip_pre_blackwell def test_draft_target_dynamic_draft_len(self): draft_len_schedule = {50: 4, 200: 3, 350: 2} max_draft_len = 4 @@ -1994,6 +1996,7 @@ def test_bfloat16_mtp_sa_global_pool(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm, extra_acc_spec="use_sa_spec") + @skip_pre_blackwell @pytest.mark.skip_less_device_memory(60000) def test_mtp_dynamic_draft_len(self): draft_len_schedule = {50: 4, 200: 3, 350: 2} @@ -2018,6 +2021,7 @@ def test_mtp_dynamic_draft_len(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + @skip_pre_blackwell @pytest.mark.skip_less_device_memory(60000) def test_mtp_sa_dynamic_draft_len(self): """Accuracy test for MTP + SA with dynamic draft length.""" @@ -2045,6 +2049,7 @@ def test_mtp_sa_dynamic_draft_len(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm, extra_acc_spec="use_sa_spec") + @skip_pre_blackwell @pytest.mark.skip_less_device_memory(60000) def test_mtp_eagle_dynamic_draft_len(self): draft_len_schedule = {50: 4, 200: 3, 350: 2}