diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 93e1a2bfe4c4..f2e20fa9c76c 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -1096,15 +1096,16 @@ def update_spec_dec_param( # Dynamic draft length needs position offsets and packed mask to be shaped for each runtime draft length. # So we create cache for position offsets and packed mask for each draft length to avoid reallocation. assert max_draft_len == max_total_draft_tokens, "max_draft_len should be equal to max_total_draft_tokens for linear tree" - runtime_draft_len = (spec_metadata.runtime_draft_len - if spec_metadata is not None else - max_draft_len) + # For algos other than PARD, this equals runtime_draft_len (K); for PARD it's 2K-1. + runtime_draft_token_buffer_width = ( + spec_metadata.runtime_tokens_per_gen_step - + 1 if spec_metadata is not None else max_draft_len) self.generate_spec_decoding_generation_length( - runtime_draft_len=runtime_draft_len) + runtime_draft_len=runtime_draft_token_buffer_width) self.spec_decoding_position_offsets = generate_spec_decoding_position_offsets( - self.max_num_requests, runtime_draft_len) + self.max_num_requests, runtime_draft_token_buffer_width) self.spec_decoding_packed_mask = generate_spec_decoding_packed_mask( - self.max_num_requests, runtime_draft_len) + self.max_num_requests, runtime_draft_token_buffer_width) self.update_position_offsets_for_cpp(cpp_query_len) diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py index 478639035c62..49634c95b524 100644 --- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py +++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py @@ -128,8 +128,10 @@ def __init__(self, config: CUDAGraphRunnerConfig): def _create_shared_static_tensors(self): """Allocates static tensors sized for the largest possible batch.""" - max_draft_len = self.config.original_max_total_draft_tokens if self.config.spec_config is not None else 0 - token_per_request = max_draft_len + 1 + runtime_draft_token_buffer_width = ( + self.config.original_max_total_draft_tokens + if self.config.spec_config is not None else 0) + token_per_request = runtime_draft_token_buffer_width + 1 max_total_tokens = (self.max_supported_batch_size * self.max_beam_width * token_per_request) max_total_tokens = min(max_total_tokens, self.config.max_num_tokens) @@ -486,6 +488,11 @@ def _get_padded_batch(self, batch: ScheduledRequests, if padding_size + batch.batch_size > self.config.batch_size: return 0 + runtime_tokens_per_gen_step = ( + self.spec_config.get_runtime_tokens_per_gen_step(runtime_draft_len) + if self.spec_config is not None else 1 + runtime_draft_len) + runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1 + # No padding if it would create too many concurrent requests. # This is not strictly required, but we should probably # respect the requirement just in case that changes in the future. @@ -503,7 +510,7 @@ def _get_padded_batch(self, batch: ScheduledRequests, dummy_request = kv_cache_manager.add_dummy_requests( [dummy_request_id], is_gen=True, - max_num_draft_tokens=runtime_draft_len, + max_num_draft_tokens=runtime_draft_token_buffer_width, use_mrope=self.config.use_mrope, max_beam_width=self.config.max_beam_width, draft_kv_cache_manager=draft_kv_cache_manager) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index ff1e5f8a5306..21c5b172a614 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -277,7 +277,9 @@ def __init__( spec_config.tokens_per_gen_step - 1) if spec_config is not None else 0 # Saved before zeroing for draft models; used by update_spec_dec_param. - self._spec_dec_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0 + self._spec_dec_max_total_draft_tokens = ( + spec_config.max_total_draft_tokens + if spec_config is not None else 0) # Dynamic tree draft loop produces up to K * max_draft_len tokens, # which may exceed max_total_draft_tokens. Use the larger value for @@ -497,6 +499,8 @@ def __init__( self.llm_args.attn_backend, sparse_attention_config=self.sparse_attention_config) + self.get_runtime_tokens_per_gen_step = spec_config.get_runtime_tokens_per_gen_step if spec_config is not None else lambda runtime_draft_len: 1 + if self.is_spec_decode: update_spec_config_from_model_config(self.spec_config, self.model.config) @@ -518,19 +522,8 @@ def __init__( self.without_logits = self.spec_config.spec_dec_mode.without_logits( ) or self.model_is_wrapped self.max_total_draft_tokens = spec_config.tokens_per_gen_step - 1 - # Parallel-draft modes (PARD, DFlash) size their per-request draft - # buffer by tokens_per_gen_step - 1 so the engine reserves exactly - # one slot per draft token the target will verify. PARD still uses - # 2K tokens per gen req (K drafts + K mask fillers); DFlash was - # reduced to K+1 (K drafts + 1 bonus) - the spec config's - # tokens_per_gen_step carries the per-algorithm width. - if spec_config.spec_dec_mode.is_parallel_draft(): - self.max_draft_len = self.max_total_draft_tokens - else: - self.max_draft_len = spec_config.max_draft_len - # Mutable per-iteration draft length (updated each iteration when - # dynamic draft length is enabled; otherwise stays fixed). - self.runtime_draft_len = self.max_draft_len + self.max_draft_len = spec_config.max_draft_len + self.runtime_draft_len = spec_config.max_draft_len else: self.without_logits = False @@ -1281,12 +1274,38 @@ def _get_graphs_to_capture( ): graphs = [(graph_bs, draft_len) for graph_bs, draft_len in self._dynamic_draft_len_mapping.items()] + # Workaround for dynamic draft length: + # capture the maximum speculative graph shape up front. Dynamic draft length + # breaks the previous assumption that attention workspace demand can be safely + # ordered by batch size alone; a later graph shape may require a larger shared + # graph workspace, and resizing that workspace can change its data_ptr and + # invalidate pointers captured by earlier graphs, causing illegal memory access + # on replay. + # + # This adds the overhead of one extra captured graph, and that graph is not + # expected to be used by the normal schedule-driven dynamic draft-length path. + # + # Follow-up first-principles fix: + # query or precompute the exact attention workspace requirement for all + # reachable graph shapes, pre-size the shared graph workspace once without + # capturing an extra graph, and avoid resizing it in graph mode afterward. + max_spec_graph = (max(cuda_graph_batch_sizes), + self.original_max_draft_len) + if max_spec_graph not in graphs: + graphs.append(max_spec_graph) logger.info(f"Dynamic draft length enabled for one-model path. " f"Capturing {len(graphs)} graphs: {graphs}") return graphs # Case 3: Target model (two-model) or one-model without dynamic draft - draft_lengths = [self.max_total_draft_tokens] + # Match the runtime_draft_len semantics enforced in _prepare_tp_inputs: + # logical K for linear-tree modes, total tree tokens for tree decoding. + # spec_config is None for non-spec models — fall back to max_draft_len (= 0). + draft_lengths = [ + self.max_draft_len if + (self.spec_config is None or self.spec_config.is_linear_tree) else + self.max_total_draft_tokens + ] should_capture_no_spec = ( self.max_total_draft_tokens > 0 and not self.spec_config.spec_dec_mode.use_one_engine() @@ -1338,11 +1357,14 @@ def _capture_generation_cuda_graphs(self, sparse_config = self.sparse_attention_config if (isinstance(sparse_config, SeqLenAwareSparseAttentionConfig) and sparse_config.needs_separate_short_long_cuda_graphs()): - # For short sequences, use the (seq_len_threshold - max_draft_len - 1) as the maximum sequence length - # to make sure all of the past and current input tokens are within the sequence length threshold. + # For short sequences, subtract the maximum runtime tokens consumed + # by a generation step so all current-step tokens stay within the + # sequence length threshold. PARD uses 2K tokens here, not K+1. + max_runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step( + self.max_draft_len) # For long sequences, use the default maximum sequence length. - max_seq_len = sparse_config.seq_len_threshold - ( - self.max_draft_len + 1) + max_seq_len = (sparse_config.seq_len_threshold - + max_runtime_tokens_per_gen_step) if max_seq_len < effective_max_seq_len: max_seq_len_list = [effective_max_seq_len, max_seq_len] else: @@ -1652,12 +1674,15 @@ def _create_cuda_graph_warmup_request( result = ScheduledRequests() num_extra_decoding_steps = self._get_num_extra_decoding_steps() + runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step( + draft_len) + runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1 # Add (batch_size - 1) dummy requests with seq_len=1. requests = kv_cache_manager.add_dummy_requests( list(range(batch_size - 1)), is_gen=True, - max_num_draft_tokens=draft_len, + max_num_draft_tokens=runtime_draft_token_buffer_width, kv_reserve_draft_tokens=self.max_draft_loop_tokens, use_mrope=self.use_mrope, max_beam_width=self.max_beam_width, @@ -1710,7 +1735,7 @@ def _create_cuda_graph_warmup_request( request_ids=[batch_size - 1], token_nums=[token_num], is_gen=True, - max_num_draft_tokens=draft_len, + max_num_draft_tokens=runtime_draft_token_buffer_width, kv_reserve_draft_tokens=self.max_draft_loop_tokens, use_mrope=self.use_mrope, max_beam_width=self.max_beam_width, @@ -2685,8 +2710,10 @@ def _update_target_input_tensors( non_blocking=True) # Prepare draft tokens + num_draft_tokens_per_extend_request = num_tokens_per_extend_request - 1 self.draft_tokens_cuda[:previous_batch_draft_tokens].copy_( - next_draft_tokens_device[previous_slots, :].flatten(), + next_draft_tokens_device[ + previous_slots, :num_draft_tokens_per_extend_request].flatten(), non_blocking=True) # Compute kv_len_offsets and update offset tensors @@ -2722,8 +2749,10 @@ def _apply_incremental_update_target( # Pre-compute constants extend_requests = scheduled_requests.generation_requests num_extend_requests = len(extend_requests) - num_tokens_per_extend_request = self.runtime_draft_len + 1 spec_config = self.spec_config + num_tokens_per_extend_request = self.get_runtime_tokens_per_gen_step( + self.runtime_draft_len) + runtime_draft_token_buffer_width = num_tokens_per_extend_request - 1 prompt_lengths = torch.empty(num_extend_requests, dtype=torch.int, @@ -2785,7 +2814,8 @@ def _apply_incremental_update_target( prompt_lengths = prompt_lengths.tolist() num_cached_tokens_per_seq = num_cached_tokens_per_seq.tolist() - previous_batch_draft_tokens = num_extend_reqeust_wo_dummy * self.runtime_draft_len + previous_batch_draft_tokens = (num_extend_reqeust_wo_dummy * + runtime_draft_token_buffer_width) self._update_target_input_tensors( num_accepted_tokens_device=num_accepted_tokens_device, @@ -3144,6 +3174,9 @@ def append_cross_attention_state(request: LlmRequest, # will contain previous batch indices of generation requests previous_batch_indices = [] previous_pos_indices = [] + runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step( + self.runtime_draft_len) + runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1 for request in extend_requests: request_ids.append(request.py_request_id) request_accepted_path[ @@ -3192,31 +3225,31 @@ def append_cross_attention_state(request: LlmRequest, previous_batch_idx = request.py_batch_idx request.py_batch_idx = request.py_seq_slot - sequence_lengths.append(1 + self.runtime_draft_len) + sequence_lengths.append(runtime_tokens_per_gen_step) num_accepted_draft_tokens.append( request.py_num_accepted_draft_tokens) past_seen_token_num = request.max_beam_num_tokens - 1 - draft_lens.append(self.runtime_draft_len) + draft_lens.append(runtime_draft_token_buffer_width) gather_ids.extend( list( range(len(position_ids), - len(position_ids) + 1 + self.runtime_draft_len))) + len(position_ids) + runtime_tokens_per_gen_step))) position_ids.extend( list( - range(past_seen_token_num, past_seen_token_num + 1 + - self.runtime_draft_len))) + range(past_seen_token_num, past_seen_token_num + + runtime_tokens_per_gen_step))) # previous tensor previous_batch_indices.append(previous_batch_idx) previous_pos_indices.extend([previous_batch_idx] * - (1 + self.runtime_draft_len)) + runtime_tokens_per_gen_step) num_cached_tokens_per_seq.append(past_seen_token_num + - self.runtime_draft_len + 1) + runtime_tokens_per_gen_step) request.cached_tokens = num_cached_tokens_per_seq[-1] if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx( self.attn_backend) and spec_config.is_linear_tree: - prompt_lengths.append(1 + self.runtime_draft_len) + prompt_lengths.append(runtime_tokens_per_gen_step) else: prompt_lengths.append(request.py_prompt_len) @@ -3554,30 +3587,36 @@ def previous_seq_slots_device(): # Initialize these two values to zeros self.previous_pos_id_offsets_cuda *= 0 self.previous_kv_lens_offsets_cuda *= 0 + runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step( + self.runtime_draft_len) + runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1 if previous_batch_len > 0: previous_slots = previous_seq_slots_device() # previous input ids - previous_batch_tokens = previous_batch_len * ( - 1 + self.runtime_draft_len) + previous_batch_tokens = (previous_batch_len * + runtime_tokens_per_gen_step) new_tokens = new_tokens_device.transpose( 0, - 1)[previous_slots, :(1 + self.runtime_draft_len)].flatten() + 1)[previous_slots, :runtime_tokens_per_gen_step].flatten() self.input_ids_cuda[num_tokens:num_tokens + previous_batch_tokens].copy_( new_tokens, non_blocking=True) # previous draft tokens - previous_batch_draft_tokens = previous_batch_len * self.runtime_draft_len - if self.runtime_draft_len > 0: - self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens + - previous_batch_draft_tokens].copy_( - next_draft_tokens_device[ - previous_slots, :self. - runtime_draft_len].flatten(), - non_blocking=True) + previous_batch_draft_tokens = (previous_batch_len * + runtime_draft_token_buffer_width) + if runtime_draft_token_buffer_width > 0: + self.draft_tokens_cuda[ + num_draft_tokens:num_draft_tokens + + previous_batch_draft_tokens].copy_( + next_draft_tokens_device[ + previous_slots, : + runtime_draft_token_buffer_width].flatten(), + non_blocking=True) # prepare data for the preprocess inputs - kv_len_offsets_device = new_tokens_lens_device - self.runtime_draft_len - 1 + kv_len_offsets_device = (new_tokens_lens_device - + runtime_tokens_per_gen_step) previous_pos_indices_host = torch.tensor( previous_pos_indices, dtype=torch.int, @@ -3603,8 +3642,8 @@ def previous_seq_slots_device(): extend_dummy_requests) self.previous_pos_id_offsets_cuda[ (num_extend_reqeust_wo_dummy - previous_batch_len) * - (1 + self.runtime_draft_len):num_extend_reqeust_wo_dummy * - (1 + self.runtime_draft_len)].copy_( + runtime_tokens_per_gen_step:num_extend_reqeust_wo_dummy * + runtime_tokens_per_gen_step].copy_( new_tokens_lens_device[self.previous_pos_indices_cuda[ 0:previous_batch_tokens]], non_blocking=True) @@ -4876,6 +4915,8 @@ def forward(self, # Propagate runtime_draft_len (already set on self by py_executor) # to spec_metadata so downstream code (eagle3, interface, trtllm) can read it. spec_metadata.runtime_draft_len = self.runtime_draft_len + spec_metadata.runtime_tokens_per_gen_step = ( + self.get_runtime_tokens_per_gen_step(self.runtime_draft_len)) # Parallel-draft modes advertise a per-gen-step width via # tokens_per_gen_step (PARD: 2K, DFlash: K+1). Pass diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 904964acb8a3..2cd98d7bd73e 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -2521,26 +2521,53 @@ def _handle_dynamic_draft_len(self, from tensorrt_llm._torch.speculative.utils import \ get_draft_len_for_batch_size + spec_dec_mode = self.model_engine.spec_config.spec_dec_mode + # 1. Resolve runtime draft length from schedule runtime_draft_len = get_draft_len_for_batch_size( self.model_engine.spec_config.draft_len_schedule, scheduled_batch.batch_size, self.model_engine.max_draft_len) - # 2. Pad or truncate draft tokens to the resolved length - PADDING_TOKEN = 0 + DRAFT_BUFFER_PAD = 0 # Buffer sentinel, not PARD mask_token_id. for request in scheduled_batch.generation_requests: - current_draft_len = len(request.py_draft_tokens) - if current_draft_len < runtime_draft_len: - padding_needed = runtime_draft_len - current_draft_len - request.py_draft_tokens.extend([PADDING_TOKEN] * - padding_needed) - elif current_draft_len > runtime_draft_len: - request.py_draft_tokens = request.py_draft_tokens[: - runtime_draft_len] + current_num_draft_tokens = len(request.py_draft_tokens) + if spec_dec_mode.is_pard(): + # special case: PARD carries 2K-1 draft tokens per request + runtime_draft_token_buffer_width = ( + self.model_engine.spec_config. + get_runtime_tokens_per_gen_step(runtime_draft_len) - 1) + current_runtime_draft_len = ( + current_num_draft_tokens + + 1) // 2 if current_num_draft_tokens > 0 else 0 + real_draft_tokens = request.py_draft_tokens[:min( + current_runtime_draft_len, runtime_draft_len)] + real_draft_tokens.extend( + [DRAFT_BUFFER_PAD] * + (runtime_draft_len - len(real_draft_tokens))) + request.py_draft_tokens = real_draft_tokens + [ + DRAFT_BUFFER_PAD + ] * (runtime_draft_token_buffer_width - + len(real_draft_tokens)) + else: + if current_num_draft_tokens < runtime_draft_len: + padding_needed = (runtime_draft_len - + current_num_draft_tokens) + request.py_draft_tokens.extend([DRAFT_BUFFER_PAD] * + padding_needed) + elif current_num_draft_tokens > runtime_draft_len: + request.py_draft_tokens = request.py_draft_tokens[: + runtime_draft_len] self.model_engine.runtime_draft_len = runtime_draft_len else: - self.model_engine.runtime_draft_len = self.model_engine.max_total_draft_tokens + # Linear-tree modes (incl. PARD) use logical K; tree decoding + # (e.g. EAGLE3 dynamic tree) uses total tree tokens. Same + # selection as _prepare_tp_inputs and _get_graphs_to_capture. + spec_config = self.model_engine.spec_config + self.model_engine.runtime_draft_len = ( + self.model_engine.max_draft_len + if spec_config is not None and spec_config.is_linear_tree else + self.model_engine.max_total_draft_tokens) def _can_queue(self, scheduled_batch): diff --git a/tensorrt_llm/_torch/speculative/dflash.py b/tensorrt_llm/_torch/speculative/dflash.py index 1e06dddf7877..efa31fa2cc1e 100644 --- a/tensorrt_llm/_torch/speculative/dflash.py +++ b/tensorrt_llm/_torch/speculative/dflash.py @@ -383,7 +383,18 @@ def forward( num_gens = batch_size - num_contexts raw_logits = logits - K = self.max_draft_len + K = spec_metadata.runtime_draft_len + + if K == 0: + return self.skip_drafting( + input_ids, + position_ids, + hidden_states, + logits, + attn_metadata, + spec_metadata, + draft_model, + ) # Lazy init buffers and attach worker reference for prepare() self._lazy_init_ctx_buffers(draft_model, spec_metadata, attn_metadata) @@ -485,7 +496,7 @@ def forward( ) vocab_size = gen_logits.shape[-1] - gen_logits = gen_logits.reshape(num_gens, self.max_draft_len, vocab_size) + gen_logits = gen_logits.reshape(num_gens, K, vocab_size) d2t = getattr(draft_model.model, "d2t", None) gen_draft_tokens = torch.argmax(gen_logits, dim=-1, keepdim=False).long() @@ -583,7 +594,7 @@ def prepare_1st_drafter_inputs( gen_accepted_tokens = accepted_tokens[num_contexts : num_contexts + num_gens, :] total_tokens_per_req = self._draft_tokens_per_req # K+1 - K = self.max_draft_len + K = spec_metadata.runtime_draft_len # Get captured multi-layer hidden states from spec_metadata captured_hs = spec_metadata.get_hidden_states(total_target_tokens) diff --git a/tensorrt_llm/_torch/speculative/draft_target.py b/tensorrt_llm/_torch/speculative/draft_target.py index c026b6d5b290..fe6dc7d1e4c7 100644 --- a/tensorrt_llm/_torch/speculative/draft_target.py +++ b/tensorrt_llm/_torch/speculative/draft_target.py @@ -70,7 +70,7 @@ def prepare(self): num_seqs, dtype=torch.int, device="cpu", pin_memory=prefer_pinned() ) self.batch_indices_cuda[:num_seqs].copy_(batch_indices, non_blocking=True) - self.num_tokens -= self.num_generations * self.max_draft_len + self.num_tokens -= self.num_generations * self.runtime_draft_len self.is_spec_dec_tree = False self.is_spec_dec_dynamic_tree = False @@ -131,10 +131,11 @@ def _update_kv_after_first_draft_step( num_accepted_tokens: torch.Tensor, num_contexts: int, batch_size: int, + runtime_draft_len: int, ): if hasattr(attn_metadata, "kv_lens_cuda"): attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= ( - self.max_draft_len - num_accepted_tokens[num_contexts:batch_size] + runtime_draft_len - num_accepted_tokens[num_contexts:batch_size] ) attn_metadata.kv_lens_cuda[:num_contexts] += 1 @@ -175,6 +176,18 @@ def forward( batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts num_gens = batch_size - num_contexts + runtime_draft_len = spec_metadata.runtime_draft_len + + if runtime_draft_len == 0: + return self.skip_drafting( + input_ids, + position_ids, + hidden_states, + logits, + attn_metadata, + spec_metadata, + draft_model, + ) raw_logits = logits @@ -204,10 +217,10 @@ def forward( draft_kv_cache_manager = self.get_draft_kv_cache_manager(resource_manager) with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager): - for i in range(self.max_draft_len): + for i in range(runtime_draft_len): if i == 0: start_ids_gen = ( - spec_metadata.batch_indices_cuda[:num_gens] * (self.max_draft_len + 1) + spec_metadata.batch_indices_cuda[:num_gens] * (runtime_draft_len + 1) ).long() gather_ids_gen = ( start_ids_gen @@ -260,7 +273,11 @@ def forward( attn_metadata.host_request_types[: attn_metadata.num_contexts].fill_(1) attn_metadata.num_contexts = 0 self._update_kv_after_first_draft_step( - attn_metadata, num_accepted_tokens, num_contexts, batch_size + attn_metadata, + num_accepted_tokens, + num_contexts, + batch_size, + runtime_draft_len, ) else: self._update_kv_for_chained_draft_step(attn_metadata, batch_size) @@ -306,13 +323,14 @@ def sample_and_accept_draft_tokens( batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts num_gens = batch_size - num_contexts + runtime_draft_len = spec_metadata.runtime_draft_len if spec_metadata.draft_tokens is None: draft_tokens = torch.zeros( - (num_gens, self.max_draft_len), dtype=torch.int, device=logits.device + (num_gens, runtime_draft_len), dtype=torch.int, device=logits.device ) else: - draft_tokens = spec_metadata.draft_tokens.reshape(num_gens, self.max_draft_len) + draft_tokens = spec_metadata.draft_tokens.reshape(num_gens, runtime_draft_len) return self._sample_and_accept_draft_tokens_base( logits, draft_tokens, num_contexts, batch_size, spec_metadata @@ -337,6 +355,7 @@ def prepare_1st_drafter_inputs( num_contexts = attn_metadata.num_contexts batch_size = attn_metadata.num_seqs num_gens = batch_size - num_contexts + runtime_draft_len = spec_metadata.runtime_draft_len if num_contexts > 0: input_ids_ctx = self._prepare_context_input_ids( @@ -350,7 +369,9 @@ def prepare_1st_drafter_inputs( input_ids_ctx = torch.empty(0, dtype=torch.int32, device="cuda") if num_gens > 0: - input_ids_gen = accepted_tokens[num_contexts:, :].flatten().to(torch.int32) + input_ids_gen = ( + accepted_tokens[num_contexts:, : runtime_draft_len + 1].flatten().to(torch.int32) + ) else: input_ids_gen = torch.empty(0, dtype=torch.int32, device="cuda") diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 68dac4c5e9f0..1e20d8fbc9b1 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -519,7 +519,7 @@ def prepare(self): if sa_manager is not None: gen_request_ids = self.request_ids[num_seqs - self.num_generations:] if gen_request_ids: - sa_manager.prepare(gen_request_ids, self.max_draft_len) + sa_manager.prepare(gen_request_ids, self.runtime_draft_len) def maybe_capture_hidden_states( self, @@ -700,7 +700,7 @@ def forward(self, num_accepted_tokens=num_accepted_tokens, num_gens=num_gens, num_contexts=num_contexts, - max_draft_len=self.max_draft_len, + max_draft_len=runtime_draft_len, ) # Save the old attn_metadata and spec_metadata diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index a6232866cc36..865dae305b47 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -315,8 +315,9 @@ def support_capturable_guided_decoder(self): ) or self.is_external_drafter() or self.is_sa() def support_dynamic_draft_len(self): - # TODO: expand to all one-model algorithms - return self.is_eagle3_one_model() or self.is_mtp_eagle_one_model() + return self.is_mtp_one_model() or self.is_eagle3_one_model( + ) or self.is_mtp_eagle_one_model() or self.is_pard() or self.is_dflash( + ) or self.is_draft_target_one_model() or self.is_sa() def has_draft_model(self): return self.is_eagle3() or self.is_draft_target() or self.is_mtp_eagle() @@ -455,6 +456,9 @@ class SpecMetadata: # draft_len_schedule. Otherwise it equals max_draft_len (the static max). # Always set by model_engine.forward() before any downstream code reads it. runtime_draft_len: int = 0 + # Total runtime tokens per generation request for the current iteration, + # Normally, it equals 1 + runtime_draft_len. But for PARD, it equals 2 * runtime_draft_len. + runtime_tokens_per_gen_step: int = 1 # Auto-detected per step from populated sampling params: # True if every request is greedy (no temp/top_k/top_p) and we can take @@ -1113,9 +1117,8 @@ def _sample_and_accept_draft_tokens_base( num_accepted_tokens: [batch_size] - Number of accepted tokens per request """ # Derive draft length from the actual draft_tokens shape rather than - # spec_metadata.runtime_draft_len, because they can differ: PARD sets - # runtime_draft_len = 2K-1 for input sizing but only passes K draft - # tokens for acceptance; + # spec_metadata.runtime_draft_len, because callers may slice a wider + # runtime token layout down to the K draft tokens used for acceptance. runtime_draft_len = draft_tokens.shape[-1] num_gens = batch_size - num_contexts diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index dda345844c19..9718a48cfe02 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -238,7 +238,7 @@ def prepare(self): num_contexts = num_seqs - self.num_generations gen_request_ids = self.request_ids[num_contexts:] if gen_request_ids: - sa_manager.prepare(gen_request_ids, self.max_draft_len) + sa_manager.prepare(gen_request_ids, self.runtime_draft_len) class MTPSampler(SpecSamplerBase): @@ -351,7 +351,7 @@ def forward( - hidden states: H_E, H_F, H_G, H_H (H_H is invalid) Draft model: MTP1: - # For generation request, `mtp_num_modules` of tokens will be used as input. + # For generation request, `runtime_draft_len` tokens are used as input. - input tokens: FGX - input hidden states: H_E, H_F, H_G - KV cache: (BCDE) + FGX @@ -402,6 +402,13 @@ def forward( - new generated draft tokens: UVQ ''' + runtime_draft_len = spec_metadata.runtime_draft_len + # skip the draft forward if the runtime draft length is 0 + if runtime_draft_len == 0: + return self.skip_drafting(input_ids, position_ids, hidden_states, + logits, attn_metadata, spec_metadata, + draft_model) + batch_size = attn_metadata.num_seqs raw_logits = logits @@ -432,7 +439,8 @@ def forward( # update attn metadata if attn_metadata is not None: - self.change_attn_metadata(num_accepted_tokens, attn_metadata) + self.change_attn_metadata(num_accepted_tokens, attn_metadata, + spec_metadata) # Run MTP layers to predict draft tokens next_draft_tokens = [] @@ -443,7 +451,8 @@ def forward( resource_manager) with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager): - for i, mtp_layer in enumerate(draft_model.mtp_layers): + for i, mtp_layer in enumerate( + draft_model.mtp_layers[:runtime_draft_len]): if self.guided_decoder is not None: new_tokens = draft_inputs['input_ids'][last_tokens_idx] self.guided_decoder.add_draft_batch(new_tokens, @@ -516,17 +525,17 @@ def skip_forward( resource_manager=None, ): batch_size = attn_metadata.num_seqs - mtp_num_modules = self.spec_config.max_draft_len - accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)), + runtime_draft_len = spec_metadata.runtime_draft_len + accepted_tokens = torch.empty((batch_size, (runtime_draft_len + 1)), dtype=torch.int, device=logits.device) num_accepted_tokens = torch.ones(batch_size, dtype=torch.int, device=logits.device) - next_draft_tokens = torch.empty((batch_size, mtp_num_modules), + next_draft_tokens = torch.empty((batch_size, runtime_draft_len), dtype=torch.int, device=logits.device) - next_new_tokens = torch.empty((batch_size, (mtp_num_modules + 1)), + next_new_tokens = torch.empty((batch_size, (runtime_draft_len + 1)), dtype=torch.int, device=logits.device) return { @@ -599,14 +608,15 @@ def unpack_sequence(packed_seq_cuda, seq_lens_cuda, seq_lens_cpu): seq_lens = attn_metadata.seq_lens_cuda seq_lens_cpu = attn_metadata.seq_lens hidden_size = hidden_states.shape[-1] - mtp_num_modules = self.spec_config.max_draft_len + runtime_draft_len = spec_metadata.runtime_draft_len + max_draft_len = self.spec_config.max_draft_len if self.is_thop: _, _ = torch.ops.trtllm.mtp_update_hidden_states_op( input_ids, seq_lens, hidden_states, spec_metadata.mtp_hidden_states_ptrs, spec_metadata.mtp_past_tokens_ptrs, num_accepted_tokens, - mtp_num_modules, batch_size, num_contexts, hidden_size) + max_draft_len, batch_size, num_contexts, hidden_size) else: assert len(spec_metadata.request_ids) == batch_size mtp_past_hidden_states_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool @@ -635,7 +645,7 @@ def unpack_sequence(packed_seq_cuda, seq_lens_cuda, seq_lens_cpu): dim=1) ctx_batch_idx = spec_metadata.batch_indices_cuda[:num_contexts] row_indices_ctx = ctx_batch_idx.unsqueeze(1).expand( - -1, mtp_num_modules) + -1, max_draft_len) col_indices_ctx = (seq_lens_ctx.unsqueeze(1) + spec_metadata.draft_token_indices_cuda) new_mtp_past_tokens.append(cat_tokens_ctx[row_indices_ctx, @@ -646,10 +656,10 @@ def unpack_sequence(packed_seq_cuda, seq_lens_cuda, seq_lens_cpu): # generation if num_gens > 0: unpacked_input_ids_gen = input_ids[num_ctx_tokens:].reshape( - num_gens, mtp_num_modules + 1).int() + num_gens, runtime_draft_len + 1).int() hidden_states_gen = hidden_states[num_ctx_tokens:, :] unpacked_hidden_states_gen = hidden_states_gen.reshape( - num_gens, mtp_num_modules + 1, hidden_size) + num_gens, runtime_draft_len + 1, hidden_size) cat_tokens_gen = torch.cat( (mtp_tokens[num_contexts:], unpacked_input_ids_gen), dim=1) cat_hidden_states_gen = torch.cat( @@ -658,10 +668,10 @@ def unpack_sequence(packed_seq_cuda, seq_lens_cuda, seq_lens_cpu): dim=1) gen_batch_idx = spec_metadata.batch_indices_cuda[:num_gens] row_indices_gen = gen_batch_idx.unsqueeze(1).expand( - -1, mtp_num_modules) + -1, max_draft_len) col_indices_gen = ( num_accepted_tokens[num_contexts:].unsqueeze(1) + - spec_metadata.draft_token_indices_cuda) + spec_metadata.draft_token_indices_cuda[:max_draft_len]) new_mtp_past_tokens.append(cat_tokens_gen[row_indices_gen, col_indices_gen]) new_mtp_past_hidden_states.append( @@ -676,17 +686,17 @@ def unpack_sequence(packed_seq_cuda, seq_lens_cuda, seq_lens_cpu): new_mtp_past_hidden_states) @torch.compile(options={"max-autotune": True}) - def topk_kernel(self, gen_logprobs, num_gens, mtp_num_modules, + def topk_kernel(self, gen_logprobs, num_gens, runtime_draft_len, spec_metadata): topk_value, topk_indices = torch.topk(gen_logprobs, k=self.spec_config.relaxed_topk, dim=-1) - topk_indices = topk_indices.reshape(num_gens, mtp_num_modules + 1, + topk_indices = topk_indices.reshape(num_gens, runtime_draft_len + 1, self.spec_config.relaxed_topk) - topk_value = topk_value.reshape(num_gens, mtp_num_modules + 1, + topk_value = topk_value.reshape(num_gens, runtime_draft_len + 1, self.spec_config.relaxed_topk) draft_tokens = spec_metadata.draft_tokens.reshape( - num_gens, mtp_num_modules) + num_gens, runtime_draft_len) return topk_value, topk_indices, draft_tokens @torch.compile(options={"max-autotune": True}) @@ -771,14 +781,14 @@ def sample_and_accept_draft_tokens( batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts num_gens = batch_size - num_contexts - mtp_num_modules = self.spec_config.max_draft_len + runtime_draft_len = spec_metadata.runtime_draft_len if logits.dim() == 1: logits = logits.unsqueeze(0) # The return buffer if self.spec_config.use_relaxed_acceptance_for_thinking or not self.is_thop: - accepted_tokens = torch.ones((batch_size, (mtp_num_modules + 1)), + accepted_tokens = torch.ones((batch_size, (runtime_draft_len + 1)), dtype=torch.int, device=logits.device) num_accepted_tokens = torch.ones(batch_size, @@ -814,41 +824,40 @@ def sample_and_accept_draft_tokens( # generation gen_logprobs = self.process_generation_logits(logits, num_contexts) topk_value, topk_indices, draft_tokens = self.topk_kernel( - gen_logprobs, num_gens, mtp_num_modules, spec_metadata) + gen_logprobs, num_gens, runtime_draft_len, spec_metadata) accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_relaxed_acceptance_op( spec_metadata.slot_ids, topk_value, topk_indices, draft_tokens, mtp_relaxed_delta_pool, num_accepted_tokens, accepted_tokens, - mtp_num_modules, batch_size, num_contexts, + runtime_draft_len, batch_size, num_contexts, self.spec_config.relaxed_topk, self.spec_config.relaxed_delta, self.spec_config.begin_thinking_phase_token, self.spec_config.end_thinking_phase_token) # Apply force override for relaxed acceptance path num_accepted_tokens = self._apply_force_accepted_tokens( - num_accepted_tokens, num_contexts, - spec_metadata.runtime_draft_len) + num_accepted_tokens, num_contexts, runtime_draft_len) # Strict acceptance else: if self.is_thop: # Temporary buffer target_tokens_cache = torch.zeros(batch_size * - (mtp_num_modules + 1), + (runtime_draft_len + 1), dtype=torch.int, device=logits.device) accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_sampling_and_accepted_draft_tokens_op( logits, spec_metadata.draft_tokens, target_tokens_cache, - mtp_num_modules, batch_size, num_contexts, logits.shape[-1]) + runtime_draft_len, batch_size, num_contexts, + logits.shape[-1]) # Apply force override for THOP path num_accepted_tokens = self._apply_force_accepted_tokens( - num_accepted_tokens, num_contexts, - spec_metadata.runtime_draft_len) + num_accepted_tokens, num_contexts, runtime_draft_len) else: # Reshape draft tokens for base implementation draft_tokens = spec_metadata.draft_tokens.reshape( - num_gens, mtp_num_modules) + num_gens, runtime_draft_len) # Use base implementation for strict acceptance accepted_tokens, num_accepted_tokens = self._sample_and_accept_draft_tokens_base( @@ -865,16 +874,17 @@ def sample_and_accept_draft_tokens( num_accepted_tokens=num_accepted_tokens, num_gens=num_gens, num_contexts=num_contexts, - max_draft_len=mtp_num_modules, + max_draft_len=runtime_draft_len, ) return accepted_tokens, num_accepted_tokens def change_attn_metadata(self, num_accepted_tokens: torch.Tensor, - attn_metadata: AttentionMetadata): + attn_metadata: AttentionMetadata, + spec_metadata: MTPSpecMetadata): self._prepare_attn_metadata_for_spec_dec(attn_metadata) batch_size = attn_metadata.num_seqs - mtp_num_modules = self.spec_config.max_draft_len + runtime_draft_len = spec_metadata.runtime_draft_len num_contexts = attn_metadata.num_contexts attn_metadata._seq_lens[num_contexts:batch_size] -= 1 @@ -886,7 +896,7 @@ def change_attn_metadata(self, num_accepted_tokens: torch.Tensor, # buffer once the graph has been captured also - this will invalidate # the graph and force an expensive recapture. attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= ( - mtp_num_modules + 1 - + runtime_draft_len + 1 - num_accepted_tokens[num_contexts:batch_size]) # A generation request's KV length can never be smaller than the # number of query tokens (mtp_num_modules) the draft layer processes @@ -898,14 +908,14 @@ def change_attn_metadata(self, num_accepted_tokens: torch.Tensor, # computes a negative effective KV range and triggers an illegal # memory access (e.g. Step3p7 MTP with dense sliding-window attention). attn_metadata.kv_lens_cuda[num_contexts:batch_size].clamp_( - min=mtp_num_modules) + min=runtime_draft_len) attn_metadata.on_update_kv_lens() if attn_metadata.kv_cache_params is not None and not attn_metadata.is_cuda_graph: for i in range(num_contexts, batch_size): # used for vanilla MLA, list on cpu attn_metadata.kv_cache_params.num_cached_tokens_per_seq[ - i] -= mtp_num_modules + 1 - num_accepted_tokens[i].item() + i] -= runtime_draft_len + 1 - num_accepted_tokens[i].item() def prepare_drafter_inputs( self, @@ -923,8 +933,8 @@ def prepare_drafter_inputs( Args: input_ids: torch.IntTensor [num_tokens] - The input ids of all requests. Flattened. - num_tokens = sum(all prompts) + num_generation * (mtp_num_modules + 1) + The input ids of all requests. Flatten. + num_tokens = sum(all prompts) + num_generation * (runtime_draft_len + 1) position_ids: torch.IntTensor [num_tokens] for RoPE models, or [3, 1, num_tokens] for MRoPE @@ -951,8 +961,8 @@ def prepare_drafter_inputs( Returns: draft_inputs input_ids: torch.Tensor [num_tokens] - The new input ids of all requests. Flattened. - num_tokens = sum(all prompts) + num_generation * (mtp_num_modules) + The new input ids of all requests. Flatten. + num_tokens = sum(all prompts) + num_generation * (runtime_draft_len) position_ids: torch.Tensor [num_tokens] for RoPE models, or [3, 1, num_tokens] for MRoPE @@ -975,7 +985,7 @@ def prepare_drafter_inputs( num_gens = batch_size - num_contexts mtp_past_hidden_states_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool mtp_past_tokens_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool - mtp_num_modules = self.spec_config.max_draft_len + runtime_draft_len = spec_metadata.runtime_draft_len if self.is_thop: # Temporary buffer @@ -997,7 +1007,7 @@ def prepare_drafter_inputs( spec_metadata.mtp_hidden_states_ptrs, spec_metadata.mtp_past_tokens_ptrs, hidden_states, accepted_tokens, num_accepted_tokens, return_input_ids, - return_hidden_states, mtp_num_modules, batch_size, + return_hidden_states, runtime_draft_len, batch_size, num_contexts, hidden_size) else: @@ -1022,11 +1032,18 @@ def prepare_drafter_inputs( accepted_tokens_gen = accepted_tokens[num_contexts:, :] input_ids_gen = accepted_tokens_gen[gen_batch_idx, gen_token_idx].unsqueeze(1) - input_ids_gen = torch.concat( - [mtp_past_tokens_pool[slot_ids][:, 1:], input_ids_gen], - dim=1) + + if runtime_draft_len > 1: + history_tokens = mtp_past_tokens_pool[slot_ids][:, -( + runtime_draft_len - 1):] + else: + history_tokens = torch.empty((num_gens, 0), + dtype=torch.int, + device=input_ids.device) + input_ids_gen = torch.concat([history_tokens, input_ids_gen], + dim=1) hidden_states_gen = mtp_past_hidden_states_pool[ - slot_ids].flatten(0, 1) + slot_ids][:, -runtime_draft_len:, :].flatten(0, 1) return_input_ids_list.append(input_ids_gen.flatten(0, 1)) return_hidden_states_list.append(hidden_states_gen) # Concatenate into continuous buffers @@ -1045,17 +1062,17 @@ def prepare_drafter_inputs( position_ids, slice(num_ctx_tokens, None)) if position_ids_gen.dim() == 1: position_ids_gen = position_ids_gen.reshape( - num_gens, mtp_num_modules + 1)[:, -mtp_num_modules:] + num_gens, runtime_draft_len + 1)[:, -runtime_draft_len:] position_ids_gen = position_ids_gen - ( - 1 + mtp_num_modules - + 1 + runtime_draft_len - num_accepted_tokens[num_contexts:].unsqueeze(1)) position_ids_list.append(position_ids_gen.flatten()) else: leading_shape = position_ids_gen.shape[:-1] position_ids_gen = position_ids_gen.reshape( *leading_shape, num_gens, - mtp_num_modules + 1)[..., -mtp_num_modules:] - position_ids_delta = (1 + mtp_num_modules - + runtime_draft_len + 1)[..., -runtime_draft_len:] + position_ids_delta = (1 + runtime_draft_len - num_accepted_tokens[num_contexts:]).view( *((1, ) * len(leading_shape)), num_gens, 1) diff --git a/tensorrt_llm/_torch/speculative/pard.py b/tensorrt_llm/_torch/speculative/pard.py index 628de48022d7..34a4509314cd 100644 --- a/tensorrt_llm/_torch/speculative/pard.py +++ b/tensorrt_llm/_torch/speculative/pard.py @@ -48,7 +48,7 @@ def prepare(self): if sa_manager is not None: gen_request_ids = self.request_ids[num_seqs - self.num_generations :] if gen_request_ids: - sa_manager.prepare(gen_request_ids, self.max_draft_len) + sa_manager.prepare(gen_request_ids, self.runtime_draft_len) def _get_sa_manager(self): """Get SA manager from spec_resource_manager. @@ -99,15 +99,6 @@ def __init__( def max_draft_len(self) -> int: return self.spec_config.max_draft_len - @property - def _draft_tokens_per_req(self) -> int: - """Total tokens per gen request in the draft forward. - - Uses 2K to fit all accepted tokens (up to K+1) plus K-1 mask tokens, - ensuring K unique predictions regardless of how many tokens were accepted. - """ - return 2 * self.max_draft_len - def _prepare_attn_metadata_for_pard(self, attn_metadata, spec_metadata): """ Save attn_metadata fields that PARD modifies during forward. @@ -190,13 +181,25 @@ def forward( num_gens = batch_size - num_contexts raw_logits = logits - K = self.max_draft_len + K = spec_metadata.runtime_draft_len + + if K == 0: + return self.skip_drafting( + input_ids, + position_ids, + hidden_states, + logits, + attn_metadata, + spec_metadata, + draft_model, + ) self._execute_guided_decoder_if_present(logits) # draft_tokens buffer has (2K-1) entries per gen request; extract the K real drafts if num_gens > 0: - draft_tokens = spec_metadata.draft_tokens.reshape(num_gens, 2 * K - 1)[:, :K] + draft_tokens = spec_metadata.draft_tokens[: num_gens * (2 * K - 1)] + draft_tokens = draft_tokens.reshape(num_gens, 2 * K - 1)[:, :K] else: draft_tokens = spec_metadata.draft_tokens.reshape(0, K) @@ -262,14 +265,13 @@ def forward( gen_start_idx = attn_metadata.num_ctx_tokens request_bases = ( - torch.arange(num_gens, dtype=torch.long, device="cuda") - * self._draft_tokens_per_req + torch.arange(num_gens, dtype=torch.long, device="cuda") * (2 * K) + gen_start_idx ) gen_num_accepted = num_accepted_tokens[num_contexts:batch_size].long() base_offsets = gen_num_accepted - 1 # M = bonus position - offsets = torch.arange(self.max_draft_len, dtype=torch.long, device="cuda") + offsets = torch.arange(K, dtype=torch.long, device="cuda") gen_gather_ids = ( request_bases.unsqueeze(1) + base_offsets.unsqueeze(1) + offsets.unsqueeze(0) @@ -281,7 +283,7 @@ def forward( ) vocab_size = gen_logits.shape[-1] - gen_logits = gen_logits.reshape(num_gens, self.max_draft_len, vocab_size) + gen_logits = gen_logits.reshape(num_gens, K, vocab_size) # Use torch.argmax directly to avoid cute_argmax stride issues d2t = getattr(draft_model.model, "d2t", None) @@ -384,6 +386,8 @@ def prepare_1st_drafter_inputs( num_contexts = attn_metadata.num_contexts batch_size = attn_metadata.num_seqs num_gens = batch_size - num_contexts + runtime_draft_len = spec_metadata.runtime_draft_len + total_tokens_per_req = 2 * runtime_draft_len if ( hasattr(self.spec_config, "mask_token_id") @@ -412,8 +416,6 @@ def prepare_1st_drafter_inputs( gen_num_accepted = num_accepted_tokens[num_contexts : num_contexts + num_gens] gen_accepted_tokens = accepted_tokens[num_contexts : num_contexts + num_gens, :] - total_tokens_per_req = self._draft_tokens_per_req # 2K - # Start with all mask tokens request_ids_2d = torch.full( (num_gens, total_tokens_per_req), @@ -452,9 +454,9 @@ def prepare_1st_drafter_inputs( - total_tokens_per_req ) else: - gen_pos_starts = position_ids[ - attn_metadata.num_ctx_tokens :: self._draft_tokens_per_req - ][:num_gens] + gen_pos_starts = position_ids[attn_metadata.num_ctx_tokens :: total_tokens_per_req][ + :num_gens + ] offsets = torch.arange(total_tokens_per_req, dtype=torch.int32, device="cuda") position_ids_gen = (gen_pos_starts.unsqueeze(1) + offsets.unsqueeze(0)).flatten() diff --git a/tensorrt_llm/_torch/speculative/sa_worker.py b/tensorrt_llm/_torch/speculative/sa_worker.py index f014423686f5..90cd617cef2a 100644 --- a/tensorrt_llm/_torch/speculative/sa_worker.py +++ b/tensorrt_llm/_torch/speculative/sa_worker.py @@ -82,7 +82,7 @@ def prepare(self) -> None: self.batch_indices_cuda[:num_seqs].copy_(batch_indices, non_blocking=True) if self.sa_manager is not None: - self.sa_manager.prepare(self.request_ids, self.max_draft_len) + self.sa_manager.prepare(self.request_ids, self.runtime_draft_len) else: raise ValueError("SA manager is not set") @@ -160,6 +160,18 @@ def forward( batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts raw_logits = logits + runtime_draft_len = spec_metadata.runtime_draft_len + + if runtime_draft_len == 0: + return self.skip_drafting( + input_ids, + position_ids, + hidden_states, + logits, + attn_metadata, + spec_metadata, + draft_model, + ) self._execute_guided_decoder_if_present(logits) @@ -206,6 +218,7 @@ def _sample_and_accept_draft_tokens( batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts num_gens = batch_size - num_contexts + runtime_draft_len = spec_metadata.runtime_draft_len # Get draft tokens from spec_metadata (set during prepare) draft_tokens = spec_metadata.draft_tokens @@ -219,15 +232,15 @@ def _sample_and_accept_draft_tokens( if draft_tokens is None or draft_tokens.numel() == 0: use_zeros = True elif draft_tokens.dim() == 1: - # 1D tensor - try to reshape to [num_gens, max_draft_len] - expected_size = num_gens * self.max_draft_len + # 1D tensor - try to reshape to [num_gens, runtime_draft_len] + expected_size = num_gens * runtime_draft_len if draft_tokens.numel() == expected_size and num_gens > 0: - draft_tokens = draft_tokens.reshape(num_gens, self.max_draft_len) + draft_tokens = draft_tokens.reshape(num_gens, runtime_draft_len) else: use_zeros = True elif draft_tokens.dim() == 2: # 2D tensor - check shape - if draft_tokens.shape[-1] != self.max_draft_len: + if draft_tokens.shape[-1] != runtime_draft_len: use_zeros = True else: # Slice to get only generation requests' draft tokens @@ -238,7 +251,9 @@ def _sample_and_accept_draft_tokens( if use_zeros: # No valid draft tokens - create zeros for generation requests draft_tokens = torch.zeros( - (num_gens, self.max_draft_len), dtype=torch.int32, device=logits.device + (num_gens, runtime_draft_len), + dtype=torch.int32, + device=logits.device, ) # Use base implementation for sampling and acceptance @@ -275,7 +290,7 @@ def _generate_draft_tokens( """ sa_manager = spec_metadata.sa_manager request_ids = spec_metadata.request_ids - max_draft_len = self._max_draft_len + runtime_draft_len = spec_metadata.runtime_draft_len if sa_manager is None or request_ids is None: # No SA manager available, throw error @@ -286,7 +301,7 @@ def _generate_draft_tokens( request_ids, accepted_tokens, num_accepted_tokens, - max_draft_len, + runtime_draft_len, max_ngram_size=self._max_matching_ngram_size, ) else: @@ -294,7 +309,7 @@ def _generate_draft_tokens( request_ids, accepted_tokens, num_accepted_tokens, - max_draft_len, + runtime_draft_len, max_ngram_size=self._max_matching_ngram_size, ) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 609e11c5a27a..abd6e28a4dc8 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1457,6 +1457,10 @@ def tokens_per_gen_step(self) -> int: """Total tokens per gen request in one spec dec iteration (including golden token).""" return 1 + self.max_total_draft_tokens + def get_runtime_tokens_per_gen_step(self, runtime_draft_len: int) -> int: + """Total tokens per gen request for the current runtime draft length.""" + return 1 + runtime_draft_len + def num_capture_layers(self) -> int: return 0 @@ -2176,6 +2180,10 @@ def tokens_per_gen_step(self) -> int: """PARD needs 2K tokens per gen request: K+1 accepted + K-1 masks.""" return 2 * self.max_draft_len + def get_runtime_tokens_per_gen_step(self, runtime_draft_len: int) -> int: + """PARD needs 2K runtime tokens per gen request for logical draft length K.""" + return 1 if runtime_draft_len == 0 else 2 * runtime_draft_len + def supports_backend(self, backend: str) -> bool: return backend == "pytorch" @@ -2230,6 +2238,10 @@ def tokens_per_gen_step(self) -> int: """ return self.max_draft_len + 1 + def get_runtime_tokens_per_gen_step(self, runtime_draft_len: int) -> int: + """DFlash needs K+1 runtime tokens per gen request (K drafts + 1 bonus).""" + return 1 + runtime_draft_len + def supports_backend(self, backend: str) -> bool: return backend == "pytorch" diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml index 1e9f7ff123aa..4d7c0290ba47 100644 --- a/tests/integration/defs/accuracy/references/gsm8k.yaml +++ b/tests/integration/defs/accuracy/references/gsm8k.yaml @@ -16,6 +16,8 @@ meta-llama/Llama-3.1-8B-Instruct: - spec_dec_algo: PARD extra_acc_spec: use_sa_spec accuracy: 74.20 + - spec_dec_algo: Draft_Target + accuracy: 74.20 - spec_dec_algo: DFlash accuracy: 74.20 - quant_algo: FP8 diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index ff21b81babe7..2bdae70dcc2f 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -30,8 +30,8 @@ # isort: off from tensorrt_llm.llmapi import ( AttentionDpConfig, CudaGraphConfig, DeepSeekSparseAttentionConfig, - DFlashDecodingConfig, Eagle3DecodingConfig, KvCacheConfig, - MiniMaxM3SparseAttentionConfig, MoeConfig, MTPDecodingConfig, + DFlashDecodingConfig, DraftTargetDecodingConfig, Eagle3DecodingConfig, + KvCacheConfig, MiniMaxM3SparseAttentionConfig, MoeConfig, MTPDecodingConfig, NGramDecodingConfig, PARDDecodingConfig, RocketSparseAttentionConfig, SADecodingConfig, SamplingParams, SchedulerConfig, SkipSoftmaxAttentionConfig, SAEnhancerConfig, TorchCompileConfig) @@ -446,6 +446,39 @@ def test_eagle3_sa_global_pool(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm, extra_acc_spec="use_sa_spec") + @skip_pre_blackwell + def test_eagle3_sa_dynamic_draft_len(self): + pytorch_config = dict( + max_batch_size=500, + disable_overlap_scheduler=False, + cuda_graph_config=(CudaGraphConfig(max_batch_size=500)), + ) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8) + + eagle_model_dir = f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B" + target_model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct" + + spec_config = Eagle3DecodingConfig( + max_draft_len=4, + speculative_model=eagle_model_dir, + eagle3_one_model=True, + sa_config=SAEnhancerConfig(enable_global_pool=True), + draft_len_schedule={ + 50: 4, + 200: 3, + 350: 2 + }, + ) + + with LLM(model=target_model_dir, + **pytorch_config, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=False, + max_num_tokens=8192, + speculative_config=spec_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm, extra_acc_spec="use_sa_spec") + @skip_pre_hopper @parametrize_with_ids("overlap_scheduler", [True, False]) def test_pard(self, overlap_scheduler): @@ -529,7 +562,58 @@ def test_pard_sa_global_pool(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm, extra_acc_spec="use_sa_spec") - @skip_pre_hopper + @skip_pre_blackwell + def test_pard_dynamic_draft_len(self): + draft_len_schedule = {50: 4, 200: 3, 350: 2} + max_draft_len = 4 + cuda_graph_config = CudaGraphConfig(max_batch_size=500) + pytorch_config = dict( + disable_overlap_scheduler=False, + cuda_graph_config=cuda_graph_config, + ) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) + pard_model_dir = f"{llm_models_root()}/PARD-Llama-3.2-1B" + pard_config = PARDDecodingConfig( + max_draft_len=max_draft_len, + speculative_model=pard_model_dir, + draft_len_schedule=draft_len_schedule, + ) + with LLM(self.MODEL_PATH, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=False, + max_num_tokens=8192, + **pytorch_config, + speculative_config=pard_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + @skip_pre_blackwell + def test_pard_sa_dynamic_draft_len(self): + draft_len_schedule = {50: 4, 200: 3, 350: 2} + max_draft_len = 4 + cuda_graph_config = CudaGraphConfig(max_batch_size=500) + pytorch_config = dict( + max_batch_size=500, + disable_overlap_scheduler=False, + cuda_graph_config=cuda_graph_config, + ) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) + pard_model_dir = f"{llm_models_root()}/PARD-Llama-3.2-1B" + pard_config = PARDDecodingConfig( + max_draft_len=max_draft_len, + speculative_model=pard_model_dir, + sa_config=SAEnhancerConfig(enable_global_pool=True), + draft_len_schedule=draft_len_schedule, + ) + with LLM(self.MODEL_PATH, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=False, + max_num_tokens=16384, + **pytorch_config, + speculative_config=pard_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm, extra_acc_spec="use_sa_spec") + def test_dflash(self): pytorch_config = dict( max_batch_size=8, @@ -553,6 +637,37 @@ def test_dflash(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + @skip_pre_blackwell + def test_dflash_dynamic_draft_len(self): + # DFlash uses a Qwen3-style draft with q/k_norm and 8K-wide cross-attn + # context, so the per-layer rmsnorm row count scales as + # B * max_ctx * num_kv_heads. Very large batches (e.g. 500) push that + # past the flashinfer rmsnorm kernel's stable range; cap at 200. + draft_len_schedule = {50: 4, 100: 3, 150: 2} + max_draft_len = 4 + pytorch_config = dict( + max_batch_size=200, + disable_overlap_scheduler=False, + cuda_graph_config=CudaGraphConfig(max_batch_size=200, + enable_padding=True), + ) + kv_cache_config = KvCacheConfig(enable_block_reuse=False, + free_gpu_memory_fraction=0.6) + dflash_model_dir = f"{llm_models_root()}/LLaMA3.1-8B-Instruct-DFlash-UltraChat" + target_model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct" + spec_config = DFlashDecodingConfig( + max_draft_len=max_draft_len, + speculative_model=dflash_model_dir, + draft_len_schedule=draft_len_schedule, + ) + with LLM(model=target_model_dir, + max_seq_len=8192, + **pytorch_config, + kv_cache_config=kv_cache_config, + speculative_config=spec_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + @skip_pre_hopper def test_ngram(self): max_bs = 16 @@ -610,6 +725,65 @@ def test_suffix_automaton(self, enable_global_pool): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + @skip_pre_blackwell + def test_suffix_automaton_dynamic_draft_len(self): + draft_len_schedule = {50: 4, 200: 3, 350: 2} + max_draft_len = 4 + cuda_graph_config = CudaGraphConfig(max_batch_size=500) + + pytorch_config = dict( + max_batch_size=500, + disable_overlap_scheduler=True, + cuda_graph_config=cuda_graph_config, + ) + + kv_cache_config = KvCacheConfig(enable_block_reuse=False, + free_gpu_memory_fraction=0.8) + spec_config = SADecodingConfig( + max_draft_len=max_draft_len, + max_matching_ngram_size=-1, + enable_global_pool=True, + draft_len_schedule=draft_len_schedule, + ) + + with LLM(model=self.MODEL_PATH, + **pytorch_config, + kv_cache_config=kv_cache_config, + speculative_config=spec_config, + enable_chunked_prefill=False, + max_num_tokens=8192) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + @skip_pre_blackwell + def test_draft_target_dynamic_draft_len(self): + draft_len_schedule = {50: 4, 200: 3, 350: 2} + max_draft_len = 4 + cuda_graph_config = CudaGraphConfig(max_batch_size=500) + pytorch_config = dict( + disable_overlap_scheduler=True, + cuda_graph_config=cuda_graph_config, + ) + kv_cache_config = KvCacheConfig( + enable_block_reuse=False, + free_gpu_memory_fraction=0.6, + ) + + spec_config = DraftTargetDecodingConfig( + max_draft_len=max_draft_len, + speculative_model=self.MODEL_PATH, + draft_len_schedule=draft_len_schedule, + ) + + with LLM(model=self.MODEL_PATH, + **pytorch_config, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=False, + max_num_tokens=8192, + speculative_config=spec_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + @skip_pre_blackwell @parametrize_with_ids("torch_compile", [False, True]) @parametrize_with_ids("attn_backend", ["TRTLLM"]) @@ -1658,6 +1832,87 @@ def test_bfloat16_mtp_sa_global_pool(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm, extra_acc_spec="use_sa_spec") + @skip_pre_blackwell + @pytest.mark.skip_less_device_memory(60000) + def test_mtp_dynamic_draft_len(self): + draft_len_schedule = {50: 4, 200: 3, 350: 2} + max_draft_len = 4 + cuda_graph_config = CudaGraphConfig(max_batch_size=500) + pytorch_config = dict( + disable_overlap_scheduler=False, + cuda_graph_config=cuda_graph_config, + ) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) + mtp_config = MTPDecodingConfig( + num_nextn_predict_layers=max_draft_len, + max_draft_len=max_draft_len, + draft_len_schedule=draft_len_schedule, + ) + with LLM(self.MODEL_PATH, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=False, + max_num_tokens=8192, + **pytorch_config, + speculative_config=mtp_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + @skip_pre_blackwell + @pytest.mark.skip_less_device_memory(60000) + def test_mtp_sa_dynamic_draft_len(self): + """Accuracy test for MTP + SA with dynamic draft length.""" + draft_len_schedule = {50: 4, 200: 3, 350: 2} + max_draft_len = 4 + cuda_graph_config = CudaGraphConfig(max_batch_size=500) + pytorch_config = dict( + max_batch_size=500, + disable_overlap_scheduler=False, + cuda_graph_config=cuda_graph_config, + ) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) + mtp_config = MTPDecodingConfig( + num_nextn_predict_layers=max_draft_len, + max_draft_len=max_draft_len, + sa_config=SAEnhancerConfig(enable_global_pool=True), + draft_len_schedule=draft_len_schedule, + ) + with LLM(self.MODEL_PATH, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=False, + max_num_tokens=8192, + **pytorch_config, + speculative_config=mtp_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm, extra_acc_spec="use_sa_spec") + + @skip_pre_blackwell + @pytest.mark.skip_less_device_memory(60000) + def test_mtp_eagle_dynamic_draft_len(self): + draft_len_schedule = {50: 4, 200: 3, 350: 2} + max_draft_len = 4 + cuda_graph_config = CudaGraphConfig(max_batch_size=500) + pytorch_config = dict( + disable_overlap_scheduler=False, + cuda_graph_config=cuda_graph_config, + ) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) + # Force MTP-Eagle one-model path. + mtp_config = MTPDecodingConfig( + num_nextn_predict_layers=max_draft_len, + max_draft_len=max_draft_len, + use_mtp_vanilla=False, + mtp_eagle_one_model=True, + draft_len_schedule=draft_len_schedule, + ) + with LLM(self.MODEL_PATH, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=False, + max_num_tokens=8192, + **pytorch_config, + speculative_config=mtp_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + @pytest.mark.skip_less_device(4) @parametrize_with_ids("mtp_nextn", [0, 2]) def test_bfloat16_4gpus_kv_cache_aware_routing(self, mtp_nextn): diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index 8bd7453fda7a..05e44318043a 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -387,6 +387,9 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding_4gpus accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding_4gpus[llguidance-mtp_nextn=2] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding_4gpus[xgrammar-mtp_nextn=0] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding_4gpus[xgrammar-mtp_nextn=2] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_mtp_dynamic_draft_len +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_mtp_eagle_dynamic_draft_len +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_mtp_sa_dynamic_draft_len accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=fp8-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=none-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=nvfp4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True] @@ -600,6 +603,8 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[use_temperature=False-attn_backend=TRTLLM] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[use_temperature=True-attn_backend=FLASHINFER] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[use_temperature=True-attn_backend=TRTLLM] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_dflash_dynamic_draft_len +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_draft_target_dynamic_draft_len accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_dummy_load_format accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=False-eagle3_one_model=False-overlap_scheduler=False] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=False-eagle3_one_model=True-overlap_scheduler=True] @@ -607,6 +612,7 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_a accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_rejection_dynamic_tree_smoke[no_dynamic_tree] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_rejection_dynamic_tree_smoke[dynamic_tree] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_sa +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_sa_dynamic_draft_len accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_sa_global_pool accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=False-attn_backend=FLASHINFER-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=False-attn_backend=FLASHINFER-torch_compile=True] @@ -659,8 +665,11 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_nvfp4_kv[v2_kv_c accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_nvfp4_kv[v2_kv_cache=True-attn_backend=TRTLLM-torch_compile=True] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard[overlap_scheduler=False] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard[overlap_scheduler=True] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_dynamic_draft_len accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_sa +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_sa_dynamic_draft_len accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_sa_global_pool +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_suffix_automaton_dynamic_draft_len accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_auto_dtype accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_fp8_prequantized accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False-enable_gemm_allreduce_fusion=False] diff --git a/tests/unittest/_torch/speculative/hw_agnostic/test_mtp.py b/tests/unittest/_torch/speculative/hw_agnostic/test_mtp.py index acb34fc704e6..d987ed442b5d 100644 --- a/tests/unittest/_torch/speculative/hw_agnostic/test_mtp.py +++ b/tests/unittest/_torch/speculative/hw_agnostic/test_mtp.py @@ -374,7 +374,7 @@ def test_sample_and_accept_draft_tokens( mtp_num_modules=mtp_num_modules, ) spec_metadata.draft_tokens = draft_tokens - + spec_metadata.runtime_draft_len = mtp_num_modules # mtp worker mtpworker = MTPWorker(spec_config) @@ -1114,6 +1114,7 @@ def test_mtp_update_mtp_hidden_states( mtp_hidden_states_tensor_pool ) spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool = mtp_tokens_tensor_pool + spec_metadata.runtime_draft_len = num_nextn_predict_layers spec_metadata.prepare() mtpworker = MTPWorker(spec_config) @@ -1706,6 +1707,7 @@ def test_prepare_drafter_inputs( mtp_hidden_states_tensor_pool ) spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool = mtp_tokens_tensor_pool + spec_metadata.runtime_draft_len = num_nextn_predict_layers spec_metadata.prepare() mtpworker = MTPWorker(spec_config)