Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
7353cbc
xSupport MTP, MTP-Eagle, PARD.
zheyuf Mar 16, 2026
d05dc84
Add SM constraint
zheyuf Apr 5, 2026
7b9cf71
Fix CI
zheyuf Apr 6, 2026
a1e51c0
resolve comments and CI
zheyuf Apr 10, 2026
dfced74
Fix CI
zheyuf Apr 17, 2026
3a3e369
Add comment on runtime_draft_token_buffer_width
zheyuf Apr 17, 2026
a28fa52
Merge remote-tracking branch 'upstream/main' into MTP_PARD_0315
zheyuf Apr 21, 2026
bb66ad7
Merge branch 'main' into MTP_PARD_0315
zheyuf Apr 22, 2026
dafb7c6
Merge branch 'main' into MTP_PARD_0315
zheyuf Apr 23, 2026
3dab7bc
Fix pre-commit
zheyuf Apr 23, 2026
a5f0454
[TRTLLM-11556][fix] Align runtime_draft_len selection across writers
zheyuf Apr 24, 2026
a4dd019
Merge branch 'main' into MTP_PARD_0315
zheyuf Apr 24, 2026
60804cf
[TRTLLM-11556][fix] Guard _get_graphs_to_capture against spec_config=…
zheyuf Apr 27, 2026
eeee802
Merge branch 'main' into MTP_PARD_0315
zheyuf Apr 27, 2026
3ff379e
Merge branch 'main' into MTP_PARD_0315
zheyuf Apr 27, 2026
90c9052
Run pre-commit
zheyuf Apr 27, 2026
a98f79d
Merge branch 'main' into MTP_PARD_0315
zheyuf Apr 27, 2026
4f50e44
Merge branch 'main' into MTP_PARD_0315
zheyuf Apr 29, 2026
b9c3b61
Merge branch 'main' into MTP_PARD_0315
zheyuf Apr 30, 2026
5af0403
Solve conflicts
zheyuf May 15, 2026
22a4718
Add DFlash support
zheyuf May 18, 2026
9e212f4
minor changes/
zheyuf May 18, 2026
69e2403
Merge branch 'main' of github.com:zheyuf/TensorRT-LLM into MTP_PARD_0315
zheyuf May 18, 2026
183ced8
[TRTLLM-11556][fix] Restore logical K for parallel-draft runtime_draf…
zheyuf May 19, 2026
53c0c43
Merge branch 'main' into MTP_PARD_0315
zheyuf May 19, 2026
44d025d
[None][test] Add test_dflash_dynamic_draft_len to QA llm_function_core
zheyuf May 19, 2026
8cb12a9
Merge branch 'main' into MTP_PARD_0315
zheyuf May 19, 2026
a1a7be1
Merge branch 'main' into MTP_PARD_0315
zheyuf May 20, 2026
f22d6fb
fix precommit
zheyuf May 20, 2026
5a54cd5
[None][chore] collapse 24 single-line docstrings in llm_args.py (D200)
zheyuf May 20, 2026
125ec30
Merge branch 'main' into MTP_PARD_0315
zheyuf May 21, 2026
bb6c016
Merge branch 'main' into MTP_PARD_0315
zheyuf May 21, 2026
72e5a07
Merge branch 'main' into MTP_PARD_0315
zheyuf May 21, 2026
417fd56
Merge branch 'main' into MTP_PARD_0315
zheyuf May 22, 2026
253042d
Merge branch 'main' into MTP_PARD_0315
zheyuf May 22, 2026
c3d32d3
[TRTLLM-11556][test] Gate dynamic-draft spec tests to Blackwell+
zheyuf May 23, 2026
b8c9855
Merge remote-tracking branch 'upstream/main' into MTP_PARD_0315
zheyuf May 26, 2026
6870ec0
Merge branch 'main' into MTP_PARD_0315
zheyuf May 26, 2026
d277f56
Merge branch 'main' into MTP_PARD_0315
zheyuf May 26, 2026
c4a972a
Merge branch 'main' into MTP_PARD_0315
zheyuf May 27, 2026
82b7565
Merge branch 'main' into MTP_PARD_0315
zheyuf May 27, 2026
7e2e155
Merge branch 'main' into MTP_PARD_0315
zheyuf May 27, 2026
dc2932d
Merge branch 'main' into MTP_PARD_0315
zheyuf May 28, 2026
4c5342a
Merge branch 'main' into MTP_PARD_0315
zheyuf May 28, 2026
fbd7a2c
Merge branch 'main' into MTP_PARD_0315
zheyuf May 28, 2026
c008fd2
Merge branch 'main' into MTP_PARD_0315
zheyuf May 28, 2026
e6e5a3e
Merge branch 'main' into MTP_PARD_0315
zheyuf May 29, 2026
488bd53
Merge branch 'main' into MTP_PARD_0315
zheyuf May 29, 2026
65c98ff
Merge branch 'main' into MTP_PARD_0315
zheyuf May 30, 2026
4754af4
Merge branch 'main' into MTP_PARD_0315
zheyuf May 30, 2026
fda6e83
Merge branch 'main' into MTP_PARD_0315
zheyuf May 30, 2026
51f09b3
Merge branch 'main' into MTP_PARD_0315
zheyuf May 31, 2026
3df81e9
Merge branch 'main' into MTP_PARD_0315
zheyuf Jun 1, 2026
48e5ab5
Merge branch 'main' into MTP_PARD_0315
zheyuf Jun 1, 2026
e396769
Merge branch 'main' into MTP_PARD_0315
zheyuf Jun 2, 2026
05de92d
Merge branch 'main' into MTP_PARD_0315
zheyuf Jun 2, 2026
4562dfe
Merge branch 'main' into MTP_PARD_0315
zheyuf Jun 3, 2026
e280f83
Merge branch 'main' into MTP_PARD_0315
zheyuf Jun 4, 2026
7761c7e
Merge branch 'main' into MTP_PARD_0315
zheyuf Jun 4, 2026
5b28a25
Merge remote-tracking branch 'upstream/main' into MTP_PARD_0315
zheyuf Jun 11, 2026
591c292
Merge branch 'main' into MTP_PARD_0315
zheyuf Jun 11, 2026
54dec26
Merge branch 'main' into MTP_PARD_0315
zheyuf Jun 12, 2026
dea4dd5
Merge branch 'main' into MTP_PARD_0315
zheyuf Jun 13, 2026
47cf224
Merge remote-tracking branch 'upstream/main' into MTP_PARD_0315
zheyuf Jun 22, 2026
f648547
Merge branch 'main' into MTP_PARD_0315
zheyuf Jun 22, 2026
8203e4f
Merge branch 'main' into MTP_PARD_0315
zheyuf Jun 23, 2026
4b2d446
Merge branch 'main' into MTP_PARD_0315
zheyuf Jun 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,15 +1096,16 @@ def update_spec_dec_param(
# Dynamic draft length needs position offsets and packed mask to be shaped for each runtime draft length.
# So we create cache for position offsets and packed mask for each draft length to avoid reallocation.
assert max_draft_len == max_total_draft_tokens, "max_draft_len should be equal to max_total_draft_tokens for linear tree"
runtime_draft_len = (spec_metadata.runtime_draft_len
if spec_metadata is not None else
max_draft_len)
# For algos other than PARD, this equals runtime_draft_len (K); for PARD it's 2K-1.
runtime_draft_token_buffer_width = (
spec_metadata.runtime_tokens_per_gen_step -
1 if spec_metadata is not None else max_draft_len)
self.generate_spec_decoding_generation_length(
runtime_draft_len=runtime_draft_len)
runtime_draft_len=runtime_draft_token_buffer_width)
self.spec_decoding_position_offsets = generate_spec_decoding_position_offsets(
self.max_num_requests, runtime_draft_len)
self.max_num_requests, runtime_draft_token_buffer_width)
self.spec_decoding_packed_mask = generate_spec_decoding_packed_mask(
self.max_num_requests, runtime_draft_len)
self.max_num_requests, runtime_draft_token_buffer_width)
Comment thread
mikeiovine marked this conversation as resolved.

self.update_position_offsets_for_cpp(cpp_query_len)

Expand Down
13 changes: 10 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,10 @@ def __init__(self, config: CUDAGraphRunnerConfig):

def _create_shared_static_tensors(self):
"""Allocates static tensors sized for the largest possible batch."""
max_draft_len = self.config.original_max_total_draft_tokens if self.config.spec_config is not None else 0
token_per_request = max_draft_len + 1
runtime_draft_token_buffer_width = (
self.config.original_max_total_draft_tokens
if self.config.spec_config is not None else 0)
token_per_request = runtime_draft_token_buffer_width + 1
max_total_tokens = (self.max_supported_batch_size *
self.max_beam_width * token_per_request)
max_total_tokens = min(max_total_tokens, self.config.max_num_tokens)
Expand Down Expand Up @@ -486,6 +488,11 @@ def _get_padded_batch(self, batch: ScheduledRequests,
if padding_size + batch.batch_size > self.config.batch_size:
return 0

runtime_tokens_per_gen_step = (
self.spec_config.get_runtime_tokens_per_gen_step(runtime_draft_len)
if self.spec_config is not None else 1 + runtime_draft_len)
runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1

# No padding if it would create too many concurrent requests.
# This is not strictly required, but we should probably
# respect the requirement just in case that changes in the future.
Expand All @@ -503,7 +510,7 @@ def _get_padded_batch(self, batch: ScheduledRequests,
dummy_request = kv_cache_manager.add_dummy_requests(
[dummy_request_id],
is_gen=True,
max_num_draft_tokens=runtime_draft_len,
max_num_draft_tokens=runtime_draft_token_buffer_width,
use_mrope=self.config.use_mrope,
max_beam_width=self.config.max_beam_width,
draft_kv_cache_manager=draft_kv_cache_manager)
Expand Down
133 changes: 87 additions & 46 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py

Large diffs are not rendered by default.

49 changes: 38 additions & 11 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2521,26 +2521,53 @@ def _handle_dynamic_draft_len(self,
from tensorrt_llm._torch.speculative.utils import \
get_draft_len_for_batch_size

spec_dec_mode = self.model_engine.spec_config.spec_dec_mode

# 1. Resolve runtime draft length from schedule
runtime_draft_len = get_draft_len_for_batch_size(
self.model_engine.spec_config.draft_len_schedule,
scheduled_batch.batch_size, self.model_engine.max_draft_len)

# 2. Pad or truncate draft tokens to the resolved length
PADDING_TOKEN = 0
DRAFT_BUFFER_PAD = 0 # Buffer sentinel, not PARD mask_token_id.
for request in scheduled_batch.generation_requests:
current_draft_len = len(request.py_draft_tokens)
if current_draft_len < runtime_draft_len:
padding_needed = runtime_draft_len - current_draft_len
request.py_draft_tokens.extend([PADDING_TOKEN] *
padding_needed)
elif current_draft_len > runtime_draft_len:
request.py_draft_tokens = request.py_draft_tokens[:
runtime_draft_len]
current_num_draft_tokens = len(request.py_draft_tokens)
if spec_dec_mode.is_pard():
# special case: PARD carries 2K-1 draft tokens per request
runtime_draft_token_buffer_width = (
self.model_engine.spec_config.
get_runtime_tokens_per_gen_step(runtime_draft_len) - 1)
current_runtime_draft_len = (
current_num_draft_tokens +
1) // 2 if current_num_draft_tokens > 0 else 0
real_draft_tokens = request.py_draft_tokens[:min(
current_runtime_draft_len, runtime_draft_len)]
real_draft_tokens.extend(
[DRAFT_BUFFER_PAD] *
(runtime_draft_len - len(real_draft_tokens)))
request.py_draft_tokens = real_draft_tokens + [
DRAFT_BUFFER_PAD
] * (runtime_draft_token_buffer_width -
len(real_draft_tokens))
else:
if current_num_draft_tokens < runtime_draft_len:
padding_needed = (runtime_draft_len -
current_num_draft_tokens)
request.py_draft_tokens.extend([DRAFT_BUFFER_PAD] *
padding_needed)
elif current_num_draft_tokens > runtime_draft_len:
request.py_draft_tokens = request.py_draft_tokens[:
runtime_draft_len]

self.model_engine.runtime_draft_len = runtime_draft_len
else:
self.model_engine.runtime_draft_len = self.model_engine.max_total_draft_tokens
# Linear-tree modes (incl. PARD) use logical K; tree decoding
# (e.g. EAGLE3 dynamic tree) uses total tree tokens. Same
# selection as _prepare_tp_inputs and _get_graphs_to_capture.
spec_config = self.model_engine.spec_config
self.model_engine.runtime_draft_len = (
self.model_engine.max_draft_len
if spec_config is not None and spec_config.is_linear_tree else
self.model_engine.max_total_draft_tokens)

def _can_queue(self, scheduled_batch):

Expand Down
17 changes: 14 additions & 3 deletions tensorrt_llm/_torch/speculative/dflash.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,18 @@ def forward(
num_gens = batch_size - num_contexts

raw_logits = logits
K = self.max_draft_len
K = spec_metadata.runtime_draft_len

if K == 0:
return self.skip_drafting(
input_ids,
position_ids,
hidden_states,
logits,
attn_metadata,
spec_metadata,
draft_model,
)

# Lazy init buffers and attach worker reference for prepare()
self._lazy_init_ctx_buffers(draft_model, spec_metadata, attn_metadata)
Expand Down Expand Up @@ -485,7 +496,7 @@ def forward(
)

vocab_size = gen_logits.shape[-1]
gen_logits = gen_logits.reshape(num_gens, self.max_draft_len, vocab_size)
gen_logits = gen_logits.reshape(num_gens, K, vocab_size)

d2t = getattr(draft_model.model, "d2t", None)
gen_draft_tokens = torch.argmax(gen_logits, dim=-1, keepdim=False).long()
Expand Down Expand Up @@ -583,7 +594,7 @@ def prepare_1st_drafter_inputs(
gen_accepted_tokens = accepted_tokens[num_contexts : num_contexts + num_gens, :]

total_tokens_per_req = self._draft_tokens_per_req # K+1
K = self.max_draft_len
K = spec_metadata.runtime_draft_len

# Get captured multi-layer hidden states from spec_metadata
captured_hs = spec_metadata.get_hidden_states(total_target_tokens)
Expand Down
37 changes: 29 additions & 8 deletions tensorrt_llm/_torch/speculative/draft_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def prepare(self):
num_seqs, dtype=torch.int, device="cpu", pin_memory=prefer_pinned()
)
self.batch_indices_cuda[:num_seqs].copy_(batch_indices, non_blocking=True)
self.num_tokens -= self.num_generations * self.max_draft_len
self.num_tokens -= self.num_generations * self.runtime_draft_len
self.is_spec_dec_tree = False
self.is_spec_dec_dynamic_tree = False

Expand Down Expand Up @@ -131,10 +131,11 @@ def _update_kv_after_first_draft_step(
num_accepted_tokens: torch.Tensor,
num_contexts: int,
batch_size: int,
runtime_draft_len: int,
):
if hasattr(attn_metadata, "kv_lens_cuda"):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.max_draft_len - num_accepted_tokens[num_contexts:batch_size]
runtime_draft_len - num_accepted_tokens[num_contexts:batch_size]
)
attn_metadata.kv_lens_cuda[:num_contexts] += 1

Expand Down Expand Up @@ -175,6 +176,18 @@ def forward(
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
num_gens = batch_size - num_contexts
runtime_draft_len = spec_metadata.runtime_draft_len

if runtime_draft_len == 0:
return self.skip_drafting(
input_ids,
position_ids,
hidden_states,
logits,
attn_metadata,
spec_metadata,
draft_model,
)

raw_logits = logits

Expand Down Expand Up @@ -204,10 +217,10 @@ def forward(
draft_kv_cache_manager = self.get_draft_kv_cache_manager(resource_manager)

with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager):
for i in range(self.max_draft_len):
for i in range(runtime_draft_len):
if i == 0:
start_ids_gen = (
spec_metadata.batch_indices_cuda[:num_gens] * (self.max_draft_len + 1)
spec_metadata.batch_indices_cuda[:num_gens] * (runtime_draft_len + 1)
).long()
gather_ids_gen = (
start_ids_gen
Expand Down Expand Up @@ -260,7 +273,11 @@ def forward(
attn_metadata.host_request_types[: attn_metadata.num_contexts].fill_(1)
attn_metadata.num_contexts = 0
self._update_kv_after_first_draft_step(
attn_metadata, num_accepted_tokens, num_contexts, batch_size
attn_metadata,
num_accepted_tokens,
num_contexts,
batch_size,
runtime_draft_len,
)
else:
self._update_kv_for_chained_draft_step(attn_metadata, batch_size)
Expand Down Expand Up @@ -306,13 +323,14 @@ def sample_and_accept_draft_tokens(
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
num_gens = batch_size - num_contexts
runtime_draft_len = spec_metadata.runtime_draft_len

if spec_metadata.draft_tokens is None:
draft_tokens = torch.zeros(
(num_gens, self.max_draft_len), dtype=torch.int, device=logits.device
(num_gens, runtime_draft_len), dtype=torch.int, device=logits.device
)
else:
draft_tokens = spec_metadata.draft_tokens.reshape(num_gens, self.max_draft_len)
draft_tokens = spec_metadata.draft_tokens.reshape(num_gens, runtime_draft_len)

return self._sample_and_accept_draft_tokens_base(
logits, draft_tokens, num_contexts, batch_size, spec_metadata
Expand All @@ -337,6 +355,7 @@ def prepare_1st_drafter_inputs(
num_contexts = attn_metadata.num_contexts
batch_size = attn_metadata.num_seqs
num_gens = batch_size - num_contexts
runtime_draft_len = spec_metadata.runtime_draft_len

if num_contexts > 0:
input_ids_ctx = self._prepare_context_input_ids(
Expand All @@ -350,7 +369,9 @@ def prepare_1st_drafter_inputs(
input_ids_ctx = torch.empty(0, dtype=torch.int32, device="cuda")

if num_gens > 0:
input_ids_gen = accepted_tokens[num_contexts:, :].flatten().to(torch.int32)
input_ids_gen = (
accepted_tokens[num_contexts:, : runtime_draft_len + 1].flatten().to(torch.int32)
)
else:
input_ids_gen = torch.empty(0, dtype=torch.int32, device="cuda")

Expand Down
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/speculative/eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def prepare(self):
if sa_manager is not None:
gen_request_ids = self.request_ids[num_seqs - self.num_generations:]
if gen_request_ids:
sa_manager.prepare(gen_request_ids, self.max_draft_len)
sa_manager.prepare(gen_request_ids, self.runtime_draft_len)

def maybe_capture_hidden_states(
self,
Expand Down Expand Up @@ -700,7 +700,7 @@ def forward(self,
num_accepted_tokens=num_accepted_tokens,
num_gens=num_gens,
num_contexts=num_contexts,
max_draft_len=self.max_draft_len,
max_draft_len=runtime_draft_len,
)

# Save the old attn_metadata and spec_metadata
Expand Down
13 changes: 8 additions & 5 deletions tensorrt_llm/_torch/speculative/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,9 @@ def support_capturable_guided_decoder(self):
) or self.is_external_drafter() or self.is_sa()

def support_dynamic_draft_len(self):
# TODO: expand to all one-model algorithms
return self.is_eagle3_one_model() or self.is_mtp_eagle_one_model()
return self.is_mtp_one_model() or self.is_eagle3_one_model(
) or self.is_mtp_eagle_one_model() or self.is_pard() or self.is_dflash(
) or self.is_draft_target_one_model() or self.is_sa()

def has_draft_model(self):
return self.is_eagle3() or self.is_draft_target() or self.is_mtp_eagle()
Expand Down Expand Up @@ -455,6 +456,9 @@ class SpecMetadata:
# draft_len_schedule. Otherwise it equals max_draft_len (the static max).
# Always set by model_engine.forward() before any downstream code reads it.
runtime_draft_len: int = 0
# Total runtime tokens per generation request for the current iteration,
# Normally, it equals 1 + runtime_draft_len. But for PARD, it equals 2 * runtime_draft_len.
runtime_tokens_per_gen_step: int = 1
Comment thread
zheyuf marked this conversation as resolved.

# Auto-detected per step from populated sampling params:
# True if every request is greedy (no temp/top_k/top_p) and we can take
Expand Down Expand Up @@ -1113,9 +1117,8 @@ def _sample_and_accept_draft_tokens_base(
num_accepted_tokens: [batch_size] - Number of accepted tokens per request
"""
# Derive draft length from the actual draft_tokens shape rather than
# spec_metadata.runtime_draft_len, because they can differ: PARD sets
# runtime_draft_len = 2K-1 for input sizing but only passes K draft
# tokens for acceptance;
# spec_metadata.runtime_draft_len, because callers may slice a wider
# runtime token layout down to the K draft tokens used for acceptance.
runtime_draft_len = draft_tokens.shape[-1]
num_gens = batch_size - num_contexts

Expand Down
Loading
Loading