Skip to content

Commit b8ec221

Browse files
committed
xSupport MTP, MTP-Eagle, PARD.
Signed-off-by: Zheyu Fu <zheyuf@NVIDIA.com> Clear naming. Signed-off-by: Zheyu Fu <zheyuf@NVIDIA.com> Fix CI Signed-off-by: Zheyu Fu <zheyuf@NVIDIA.com> Add draft_target support Signed-off-by: Zheyu Fu <zheyuf@NVIDIA.com> Add SA and SA hybrid support Signed-off-by: Zheyu Fu <zheyuf@NVIDIA.com>
1 parent 1d31029 commit b8ec221

15 files changed

Lines changed: 636 additions & 151 deletions

File tree

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1651,15 +1651,14 @@ def update_spec_dec_param(
16511651
# Dynamic draft length needs position offsets and packed mask to be shaped for each runtime draft length.
16521652
# So we create cache for position offsets and packed mask for each draft length to avoid reallocation.
16531653
assert max_draft_len == max_total_draft_tokens, "max_draft_len should be equal to max_total_draft_tokens for linear tree"
1654-
runtime_draft_len = (spec_metadata.runtime_draft_len
1655-
if spec_metadata is not None else
1656-
max_draft_len)
1654+
runtime_draft_token_buffer_width = (
1655+
spec_metadata.runtime_tokens_per_gen_step - 1)
16571656
self.generate_spec_decoding_generation_length(
1658-
runtime_draft_len=runtime_draft_len)
1657+
runtime_draft_len=runtime_draft_token_buffer_width)
16591658
self.spec_decoding_position_offsets = generate_spec_decoding_position_offsets(
1660-
self.max_num_requests, runtime_draft_len)
1659+
self.max_num_requests, runtime_draft_token_buffer_width)
16611660
self.spec_decoding_packed_mask = generate_spec_decoding_packed_mask(
1662-
self.max_num_requests, runtime_draft_len)
1661+
self.max_num_requests, runtime_draft_token_buffer_width)
16631662

16641663
def generate_spec_decoding_generation_length(self, runtime_draft_len):
16651664
self.spec_decoding_generation_lengths[:self.max_num_requests].fill_(

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,10 @@ def __init__(self, config: CUDAGraphRunnerConfig):
118118

119119
def _create_shared_static_tensors(self):
120120
"""Allocates static tensors sized for the largest possible batch."""
121-
max_draft_len = self.config.original_max_total_draft_tokens if self.config.spec_config is not None else 0
122-
token_per_request = max_draft_len + 1
121+
runtime_draft_token_buffer_width = (
122+
self.config.original_max_total_draft_tokens
123+
if self.config.spec_config is not None else 0)
124+
token_per_request = runtime_draft_token_buffer_width + 1
123125
max_total_tokens = (self.max_supported_batch_size *
124126
self.max_beam_width * token_per_request)
125127
max_total_tokens = min(max_total_tokens, self.config.max_num_tokens)
@@ -443,6 +445,11 @@ def _get_padded_batch(self, batch: ScheduledRequests,
443445
if padding_size + batch.batch_size > self.config.batch_size:
444446
return 0
445447

448+
runtime_tokens_per_gen_step = (
449+
self.spec_config.get_runtime_tokens_per_gen_step(runtime_draft_len)
450+
if self.spec_config is not None else 1 + runtime_draft_len)
451+
runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1
452+
446453
# No padding if it would create too many concurrent requests.
447454
# This is not strictly required, but we should probably
448455
# respect the requirement just in case that changes in the future.
@@ -460,7 +467,7 @@ def _get_padded_batch(self, batch: ScheduledRequests,
460467
dummy_request = kv_cache_manager.add_dummy_requests(
461468
[dummy_request_id],
462469
is_gen=True,
463-
max_num_draft_tokens=runtime_draft_len,
470+
max_num_draft_tokens=runtime_draft_token_buffer_width,
464471
use_mrope=self.config.use_mrope,
465472
max_beam_width=self.config.max_beam_width,
466473
draft_kv_cache_manager=draft_kv_cache_manager)

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 81 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,9 @@ def __init__(
182182
spec_config.tokens_per_gen_step -
183183
1) if spec_config is not None else 0
184184
# Saved before zeroing for draft models; used by update_spec_dec_param.
185-
self._spec_dec_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0
185+
self._spec_dec_max_total_draft_tokens = (
186+
spec_config.max_total_draft_tokens
187+
if spec_config is not None else 0)
186188

187189
preserve_wrapped_eagle3_widths = (spec_config is not None
188190
and is_draft_model
@@ -341,11 +343,14 @@ def __init__(
341343
self.llm_args.attn_backend,
342344
sparse_attn_config=self.sparse_attention_config)
343345

346+
self.get_runtime_tokens_per_gen_step = spec_config.get_runtime_tokens_per_gen_step if spec_config is not None else lambda runtime_draft_len: 1
347+
344348
if self.is_spec_decode:
345349
self.spec_metadata = None
346350
update_spec_config_from_model_config(self.spec_config,
347351
self.model.config)
348-
max_num_draft_tokens = self.original_max_total_draft_tokens * self.batch_size
352+
max_num_draft_tokens = (self.original_max_total_draft_tokens *
353+
self.batch_size)
349354
self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ),
350355
dtype=torch.int,
351356
device='cuda')
@@ -876,12 +881,31 @@ def _get_graphs_to_capture(
876881
):
877882
graphs = [(graph_bs, draft_len) for graph_bs, draft_len in
878883
self._dynamic_draft_len_mapping.items()]
884+
# Workaround for dynamic draft length:
885+
# capture the maximum speculative graph shape up front. Dynamic draft length
886+
# breaks the previous assumption that attention workspace demand can be safely
887+
# ordered by batch size alone; a later graph shape may require a larger shared
888+
# graph workspace, and resizing that workspace can change its data_ptr and
889+
# invalidate pointers captured by earlier graphs, causing illegal memory access
890+
# on replay.
891+
#
892+
# This adds the overhead of one extra captured graph, and that graph is not
893+
# expected to be used by the normal schedule-driven dynamic draft-length path.
894+
#
895+
# Follow-up first-principles fix:
896+
# query or precompute the exact attention workspace requirement for all
897+
# reachable graph shapes, pre-size the shared graph workspace once without
898+
# capturing an extra graph, and avoid resizing it in graph mode afterward.
899+
max_spec_graph = (max(cuda_graph_batch_sizes),
900+
self.original_max_draft_len)
901+
if max_spec_graph not in graphs:
902+
graphs.append(max_spec_graph)
879903
logger.info(f"Dynamic draft length enabled for one-model path. "
880904
f"Capturing {len(graphs)} graphs: {graphs}")
881905
return graphs
882906

883907
# Case 3: Target model (two-model) or one-model without dynamic draft
884-
draft_lengths = [self.max_total_draft_tokens]
908+
draft_lengths = [self.max_draft_len]
885909
should_capture_no_spec = (
886910
self.max_total_draft_tokens > 0
887911
and not self.spec_config.spec_dec_mode.use_one_engine()
@@ -1206,12 +1230,15 @@ def _create_cuda_graph_warmup_request(
12061230

12071231
result = ScheduledRequests()
12081232
num_extra_decoding_steps = self._get_num_extra_decoding_steps()
1233+
runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step(
1234+
draft_len)
1235+
runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1
12091236

12101237
# Add (batch_size - 1) dummy requests with seq_len=1.
12111238
requests = kv_cache_manager.add_dummy_requests(
12121239
list(range(batch_size - 1)),
12131240
is_gen=True,
1214-
max_num_draft_tokens=draft_len,
1241+
max_num_draft_tokens=runtime_draft_token_buffer_width,
12151242
use_mrope=self.use_mrope,
12161243
max_beam_width=self.max_beam_width,
12171244
num_extra_decoding_steps=num_extra_decoding_steps,
@@ -1228,26 +1255,29 @@ def _create_cuda_graph_warmup_request(
12281255
available_tokens = kv_cache_manager.get_num_available_tokens(
12291256
token_num_upper_bound=max_seq_len,
12301257
batch_size=batch_size,
1231-
max_num_draft_tokens=draft_len)
1258+
max_num_draft_tokens=runtime_draft_token_buffer_width)
12321259

12331260
# Also consider draft KV cache capacity when it exists
12341261
if draft_kv_cache_manager is not None:
12351262
draft_available_tokens = draft_kv_cache_manager.get_num_available_tokens(
12361263
batch_size=batch_size,
12371264
token_num_upper_bound=max_seq_len,
1238-
max_num_draft_tokens=draft_len)
1265+
max_num_draft_tokens=runtime_draft_token_buffer_width)
12391266
available_tokens = min(available_tokens, draft_available_tokens)
12401267

12411268
token_num = max(
12421269
1,
12431270
min(
1244-
available_tokens, max_seq_len - 1 -
1245-
get_num_extra_kv_tokens(self.spec_config) - draft_len))
1271+
available_tokens,
1272+
max_seq_len - 1 - get_num_extra_kv_tokens(self.spec_config) -
1273+
runtime_draft_token_buffer_width))
12461274
model_config = self.model.model_config.pretrained_config
12471275
max_position_embeddings = getattr(model_config,
12481276
'max_position_embeddings', None)
12491277
if max_position_embeddings is not None:
1250-
token_num = min(token_num, max_position_embeddings - draft_len)
1278+
token_num = min(
1279+
token_num,
1280+
max_position_embeddings - runtime_draft_token_buffer_width)
12511281

12521282
assert token_num > num_extra_decoding_steps, (
12531283
"Cannot fuse drafting loop. Not enough KV cache space for all draft tokens."
@@ -1258,7 +1288,7 @@ def _create_cuda_graph_warmup_request(
12581288
request_ids=[batch_size - 1],
12591289
token_nums=[token_num],
12601290
is_gen=True,
1261-
max_num_draft_tokens=draft_len,
1291+
max_num_draft_tokens=runtime_draft_token_buffer_width,
12621292
use_mrope=self.use_mrope,
12631293
max_beam_width=self.max_beam_width,
12641294
num_extra_decoding_steps=num_extra_decoding_steps,
@@ -1985,8 +2015,10 @@ def _update_target_input_tensors(
19852015
non_blocking=True)
19862016

19872017
# Prepare draft tokens
2018+
num_draft_tokens_per_extend_request = num_tokens_per_extend_request - 1
19882019
self.draft_tokens_cuda[:previous_batch_draft_tokens].copy_(
1989-
next_draft_tokens_device[previous_slots, :].flatten(),
2020+
next_draft_tokens_device[
2021+
previous_slots, :num_draft_tokens_per_extend_request].flatten(),
19902022
non_blocking=True)
19912023

19922024
# Compute kv_len_offsets and update offset tensors
@@ -2022,8 +2054,10 @@ def _apply_incremental_update_target(
20222054
# Pre-compute constants
20232055
extend_requests = scheduled_requests.generation_requests
20242056
num_extend_requests = len(extend_requests)
2025-
num_tokens_per_extend_request = self.runtime_draft_len + 1
20262057
spec_config = self.spec_config
2058+
num_tokens_per_extend_request = self.get_runtime_tokens_per_gen_step(
2059+
self.runtime_draft_len)
2060+
runtime_draft_token_buffer_width = num_tokens_per_extend_request - 1
20272061

20282062
prompt_lengths = torch.empty(num_extend_requests,
20292063
dtype=torch.int,
@@ -2085,7 +2119,8 @@ def _apply_incremental_update_target(
20852119
prompt_lengths = prompt_lengths.tolist()
20862120
num_cached_tokens_per_seq = num_cached_tokens_per_seq.tolist()
20872121

2088-
previous_batch_draft_tokens = num_extend_reqeust_wo_dummy * self.runtime_draft_len
2122+
previous_batch_draft_tokens = (num_extend_reqeust_wo_dummy *
2123+
runtime_draft_token_buffer_width)
20892124

20902125
self._update_target_input_tensors(
20912126
num_accepted_tokens_device=num_accepted_tokens_device,
@@ -2368,6 +2403,9 @@ def _prepare_tp_inputs(
23682403
# will contain previous batch indices of generation requests
23692404
previous_batch_indices = []
23702405
previous_pos_indices = []
2406+
runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step(
2407+
self.runtime_draft_len)
2408+
runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1
23712409
for request in extend_requests:
23722410
request_ids.append(request.py_request_id)
23732411
request_accepted_path[
@@ -2426,16 +2464,16 @@ def _prepare_tp_inputs(
24262464
previous_batch_idx = request.py_batch_idx
24272465
request.py_batch_idx = request.py_seq_slot
24282466

2429-
sequence_lengths.append(1 + self.runtime_draft_len)
2467+
sequence_lengths.append(runtime_tokens_per_gen_step)
24302468
num_accepted_draft_tokens.append(
24312469
request.py_num_accepted_draft_tokens)
24322470
past_seen_token_num = request.max_beam_num_tokens - 1
24332471

2434-
draft_lens.append(self.runtime_draft_len)
2472+
draft_lens.append(runtime_draft_token_buffer_width)
24352473
gather_ids.extend(
24362474
list(
24372475
range(len(position_ids),
2438-
len(position_ids) + 1 + self.runtime_draft_len)))
2476+
len(position_ids) + runtime_tokens_per_gen_step)))
24392477
# For the target model + tree decoding
24402478
if not self.is_draft_model and not spec_config.is_linear_tree:
24412479
assert spec_tree_manager is not None
@@ -2448,19 +2486,19 @@ def _prepare_tp_inputs(
24482486
position_ids.extend(
24492487
list(
24502488
range(
2451-
past_seen_token_num, past_seen_token_num + 1 +
2452-
self.runtime_draft_len)))
2489+
past_seen_token_num, past_seen_token_num +
2490+
runtime_tokens_per_gen_step)))
24532491
# previous tensor
24542492
previous_batch_indices.append(previous_batch_idx)
24552493
previous_pos_indices.extend([previous_batch_idx] *
2456-
(1 + self.runtime_draft_len))
2494+
runtime_tokens_per_gen_step)
24572495

24582496
num_cached_tokens_per_seq.append(past_seen_token_num +
2459-
self.runtime_draft_len + 1)
2497+
runtime_tokens_per_gen_step)
24602498
request.cached_tokens = num_cached_tokens_per_seq[-1]
24612499
if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx(
24622500
self.attn_backend) and spec_config.is_linear_tree:
2463-
prompt_lengths.append(1 + self.runtime_draft_len)
2501+
prompt_lengths.append(runtime_tokens_per_gen_step)
24642502
else:
24652503
prompt_lengths.append(request.py_prompt_len)
24662504

@@ -2765,30 +2803,36 @@ def previous_seq_slots_device():
27652803
# Initialize these two values to zeros
27662804
self.previous_pos_id_offsets_cuda *= 0
27672805
self.previous_kv_lens_offsets_cuda *= 0
2806+
runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step(
2807+
self.runtime_draft_len)
2808+
runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1
27682809

27692810
if previous_batch_len > 0:
27702811
previous_slots = previous_seq_slots_device()
27712812
# previous input ids
2772-
previous_batch_tokens = previous_batch_len * (
2773-
1 + self.runtime_draft_len)
2813+
previous_batch_tokens = (previous_batch_len *
2814+
runtime_tokens_per_gen_step)
27742815
new_tokens = new_tokens_device.transpose(
27752816
0,
2776-
1)[previous_slots, :(1 + self.runtime_draft_len)].flatten()
2817+
1)[previous_slots, :runtime_tokens_per_gen_step].flatten()
27772818
self.input_ids_cuda[num_tokens:num_tokens +
27782819
previous_batch_tokens].copy_(
27792820
new_tokens, non_blocking=True)
27802821

27812822
# previous draft tokens
2782-
previous_batch_draft_tokens = previous_batch_len * self.runtime_draft_len
2783-
if self.runtime_draft_len > 0:
2784-
self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens +
2785-
previous_batch_draft_tokens].copy_(
2786-
next_draft_tokens_device[
2787-
previous_slots, :self.
2788-
runtime_draft_len].flatten(),
2789-
non_blocking=True)
2823+
previous_batch_draft_tokens = (previous_batch_len *
2824+
runtime_draft_token_buffer_width)
2825+
if runtime_draft_token_buffer_width > 0:
2826+
self.draft_tokens_cuda[
2827+
num_draft_tokens:num_draft_tokens +
2828+
previous_batch_draft_tokens].copy_(
2829+
next_draft_tokens_device[
2830+
previous_slots, :
2831+
runtime_draft_token_buffer_width].flatten(),
2832+
non_blocking=True)
27902833
# prepare data for the preprocess inputs
2791-
kv_len_offsets_device = new_tokens_lens_device - self.runtime_draft_len - 1
2834+
kv_len_offsets_device = (new_tokens_lens_device -
2835+
runtime_tokens_per_gen_step)
27922836
previous_pos_indices_host = torch.tensor(
27932837
previous_pos_indices,
27942838
dtype=torch.int,
@@ -2814,8 +2858,8 @@ def previous_seq_slots_device():
28142858
extend_dummy_requests)
28152859
self.previous_pos_id_offsets_cuda[
28162860
(num_extend_reqeust_wo_dummy - previous_batch_len) *
2817-
(1 + self.runtime_draft_len):num_extend_reqeust_wo_dummy *
2818-
(1 + self.runtime_draft_len)].copy_(
2861+
runtime_tokens_per_gen_step:num_extend_reqeust_wo_dummy *
2862+
runtime_tokens_per_gen_step].copy_(
28192863
new_tokens_lens_device[self.previous_pos_indices_cuda[
28202864
0:previous_batch_tokens]],
28212865
non_blocking=True)
@@ -3679,6 +3723,8 @@ def forward(self,
36793723
# Propagate runtime_draft_len (already set on self by py_executor)
36803724
# to spec_metadata so downstream code (eagle3, interface, trtllm) can read it.
36813725
spec_metadata.runtime_draft_len = self.runtime_draft_len
3726+
spec_metadata.runtime_tokens_per_gen_step = (
3727+
self.get_runtime_tokens_per_gen_step(self.runtime_draft_len))
36823728

36833729
# PARD has 2K tokens per gen request, not K+1. Pass 2K-1
36843730
# so generation_lengths = 2K and the XQA kernel computes

0 commit comments

Comments
 (0)