Skip to content

Commit 2b6c0b5

Browse files
committed
Support MTP, MTP-Eagle, PARD.
Signed-off-by: Zheyu Fu <zheyuf@NVIDIA.com>

Clear naming.

Signed-off-by: Zheyu Fu <zheyuf@NVIDIA.com>

Fix CI

Signed-off-by: Zheyu Fu <zheyuf@NVIDIA.com>
1 parent ff07459 commit 2b6c0b5

11 files changed

Lines changed: 343 additions & 130 deletions

File tree

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1599,15 +1599,14 @@ def update_spec_dec_param(
15991599
# Dynamic draft length needs position offsets and packed mask to be shaped for each runtime draft length.
16001600
# So we create a cache of position offsets and packed masks for each draft length to avoid reallocation.
16011601
assert max_draft_len == max_total_draft_tokens, "max_draft_len should be equal to max_total_draft_tokens for linear tree"
1602-
runtime_draft_len = (spec_metadata.runtime_draft_len
1603-
if spec_metadata is not None else
1604-
max_draft_len)
1602+
runtime_draft_token_buffer_width = (
1603+
spec_metadata.runtime_tokens_per_gen_step - 1)
16051604
self.generate_spec_decoding_generation_length(
1606-
runtime_draft_len=runtime_draft_len)
1605+
runtime_draft_len=runtime_draft_token_buffer_width)
16071606
self.spec_decoding_position_offsets = generate_spec_decoding_position_offsets(
1608-
self.max_num_requests, runtime_draft_len)
1607+
self.max_num_requests, runtime_draft_token_buffer_width)
16091608
self.spec_decoding_packed_mask = generate_spec_decoding_packed_mask(
1610-
self.max_num_requests, runtime_draft_len)
1609+
self.max_num_requests, runtime_draft_token_buffer_width)
16111610

16121611
def generate_spec_decoding_generation_length(self, runtime_draft_len):
16131612
self.spec_decoding_generation_lengths[:self.max_num_requests].fill_(

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,10 @@ def __init__(self, config: CUDAGraphRunnerConfig):
119119

120120
def _create_shared_static_tensors(self):
121121
"""Allocates static tensors sized for the largest possible batch."""
122-
max_draft_len = self.config.original_max_total_draft_tokens if self.config.spec_config is not None else 0
123-
token_per_request = max_draft_len + 1
122+
runtime_draft_token_buffer_width = (
123+
self.config.original_max_total_draft_tokens
124+
if self.config.spec_config is not None else 0)
125+
token_per_request = runtime_draft_token_buffer_width + 1
124126
max_total_tokens = (self.max_supported_batch_size *
125127
self.max_beam_width * token_per_request)
126128
max_total_tokens = min(max_total_tokens, self.config.max_num_tokens)
@@ -444,6 +446,11 @@ def _get_padded_batch(self, batch: ScheduledRequests,
444446
if padding_size + batch.batch_size > self.config.batch_size:
445447
return 0
446448

449+
runtime_tokens_per_gen_step = (
450+
self.spec_config.get_runtime_tokens_per_gen_step(runtime_draft_len)
451+
if self.spec_config is not None else 1 + runtime_draft_len)
452+
runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1
453+
447454
# No padding if it would create too many concurrent requests.
448455
# This is not strictly required, but we should probably
449456
# respect the requirement just in case that changes in the future.
@@ -461,7 +468,7 @@ def _get_padded_batch(self, batch: ScheduledRequests,
461468
dummy_request = kv_cache_manager.add_dummy_requests(
462469
[dummy_request_id],
463470
is_gen=True,
464-
max_num_draft_tokens=runtime_draft_len,
471+
max_num_draft_tokens=runtime_draft_token_buffer_width,
465472
use_mrope=self.config.use_mrope,
466473
max_beam_width=self.config.max_beam_width,
467474
draft_kv_cache_manager=draft_kv_cache_manager)

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 62 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,9 @@ def __init__(
180180
spec_config.tokens_per_gen_step -
181181
1) if spec_config is not None else 0
182182
# Saved before zeroing for draft models; used by update_spec_dec_param.
183-
self._spec_dec_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0
183+
self._spec_dec_max_total_draft_tokens = (
184+
spec_config.max_total_draft_tokens
185+
if spec_config is not None else 0)
184186

185187
preserve_wrapped_eagle3_widths = (spec_config is not None
186188
and is_draft_model
@@ -334,11 +336,14 @@ def __init__(
334336
self.llm_args.attn_backend,
335337
sparse_attn_config=self.sparse_attention_config)
336338

339+
self.get_runtime_tokens_per_gen_step = spec_config.get_runtime_tokens_per_gen_step if spec_config is not None else lambda runtime_draft_len: 1
340+
337341
if self.is_spec_decode:
338342
self.spec_metadata = None
339343
update_spec_config_from_model_config(self.spec_config,
340344
self.model.config)
341-
max_num_draft_tokens = self.original_max_total_draft_tokens * self.batch_size
345+
max_num_draft_tokens = (self.original_max_total_draft_tokens *
346+
self.batch_size)
342347
self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ),
343348
dtype=torch.int,
344349
device='cuda')
@@ -869,7 +874,7 @@ def _get_graphs_to_capture(
869874
return graphs
870875

871876
# Case 3: Target model (two-model) or one-model without dynamic draft
872-
draft_lengths = [self.max_total_draft_tokens]
877+
draft_lengths = [self.max_draft_len]
873878
should_capture_no_spec = (
874879
self.max_total_draft_tokens > 0
875880
and not self.spec_config.spec_dec_mode.use_one_engine()
@@ -1194,12 +1199,15 @@ def _create_cuda_graph_warmup_request(
11941199

11951200
result = ScheduledRequests()
11961201
num_extra_decoding_steps = self._get_num_extra_decoding_steps()
1202+
runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step(
1203+
draft_len)
1204+
runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1
11971205

11981206
# Add (batch_size - 1) dummy requests with seq_len=1.
11991207
requests = kv_cache_manager.add_dummy_requests(
12001208
list(range(batch_size - 1)),
12011209
is_gen=True,
1202-
max_num_draft_tokens=draft_len,
1210+
max_num_draft_tokens=runtime_draft_token_buffer_width,
12031211
use_mrope=self.use_mrope,
12041212
max_beam_width=self.max_beam_width,
12051213
num_extra_decoding_steps=num_extra_decoding_steps,
@@ -1216,26 +1224,29 @@ def _create_cuda_graph_warmup_request(
12161224
available_tokens = kv_cache_manager.get_num_available_tokens(
12171225
token_num_upper_bound=max_seq_len,
12181226
batch_size=batch_size,
1219-
max_num_draft_tokens=draft_len)
1227+
max_num_draft_tokens=runtime_draft_token_buffer_width)
12201228

12211229
# Also consider draft KV cache capacity when it exists
12221230
if draft_kv_cache_manager is not None:
12231231
draft_available_tokens = draft_kv_cache_manager.get_num_available_tokens(
12241232
batch_size=batch_size,
12251233
token_num_upper_bound=max_seq_len,
1226-
max_num_draft_tokens=draft_len)
1234+
max_num_draft_tokens=runtime_draft_token_buffer_width)
12271235
available_tokens = min(available_tokens, draft_available_tokens)
12281236

12291237
token_num = max(
12301238
1,
12311239
min(
1232-
available_tokens, max_seq_len - 1 -
1233-
get_num_extra_kv_tokens(self.spec_config) - draft_len))
1240+
available_tokens,
1241+
max_seq_len - 1 - get_num_extra_kv_tokens(self.spec_config) -
1242+
runtime_draft_token_buffer_width))
12341243
model_config = self.model.model_config.pretrained_config
12351244
max_position_embeddings = getattr(model_config,
12361245
'max_position_embeddings', None)
12371246
if max_position_embeddings is not None:
1238-
token_num = min(token_num, max_position_embeddings - draft_len)
1247+
token_num = min(
1248+
token_num,
1249+
max_position_embeddings - runtime_draft_token_buffer_width)
12391250

12401251
assert token_num > num_extra_decoding_steps, (
12411252
"Cannot fuse drafting loop. Not enough KV cache space for all draft tokens."
@@ -1246,7 +1257,7 @@ def _create_cuda_graph_warmup_request(
12461257
request_ids=[batch_size - 1],
12471258
token_nums=[token_num],
12481259
is_gen=True,
1249-
max_num_draft_tokens=draft_len,
1260+
max_num_draft_tokens=runtime_draft_token_buffer_width,
12501261
use_mrope=self.use_mrope,
12511262
max_beam_width=self.max_beam_width,
12521263
num_extra_decoding_steps=num_extra_decoding_steps,
@@ -1968,8 +1979,10 @@ def _update_target_input_tensors(
19681979
non_blocking=True)
19691980

19701981
# Prepare draft tokens
1982+
num_draft_tokens_per_extend_request = num_tokens_per_extend_request - 1
19711983
self.draft_tokens_cuda[:previous_batch_draft_tokens].copy_(
1972-
next_draft_tokens_device[previous_slots, :].flatten(),
1984+
next_draft_tokens_device[
1985+
previous_slots, :num_draft_tokens_per_extend_request].flatten(),
19731986
non_blocking=True)
19741987

19751988
# Compute kv_len_offsets and update offset tensors
@@ -2005,8 +2018,10 @@ def _apply_incremental_update_target(
20052018
# Pre-compute constants
20062019
extend_requests = scheduled_requests.generation_requests
20072020
num_extend_requests = len(extend_requests)
2008-
num_tokens_per_extend_request = self.runtime_draft_len + 1
20092021
spec_config = self.spec_config
2022+
num_tokens_per_extend_request = self.get_runtime_tokens_per_gen_step(
2023+
self.runtime_draft_len)
2024+
runtime_draft_token_buffer_width = num_tokens_per_extend_request - 1
20102025

20112026
prompt_lengths = torch.empty(num_extend_requests,
20122027
dtype=torch.int,
@@ -2068,7 +2083,8 @@ def _apply_incremental_update_target(
20682083
prompt_lengths = prompt_lengths.tolist()
20692084
num_cached_tokens_per_seq = num_cached_tokens_per_seq.tolist()
20702085

2071-
previous_batch_draft_tokens = num_extend_reqeust_wo_dummy * self.runtime_draft_len
2086+
previous_batch_draft_tokens = (num_extend_reqeust_wo_dummy *
2087+
runtime_draft_token_buffer_width)
20722088

20732089
self._update_target_input_tensors(
20742090
num_accepted_tokens_device=num_accepted_tokens_device,
@@ -2347,6 +2363,9 @@ def _prepare_tp_inputs(
23472363
# will contain previous batch indices of generation requests
23482364
previous_batch_indices = []
23492365
previous_pos_indices = []
2366+
runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step(
2367+
self.runtime_draft_len)
2368+
runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1
23502369
for request in extend_requests:
23512370
request_ids.append(request.py_request_id)
23522371
request_accepted_path[
@@ -2405,16 +2424,16 @@ def _prepare_tp_inputs(
24052424
previous_batch_idx = request.py_batch_idx
24062425
request.py_batch_idx = request.py_seq_slot
24072426

2408-
sequence_lengths.append(1 + self.runtime_draft_len)
2427+
sequence_lengths.append(runtime_tokens_per_gen_step)
24092428
num_accepted_draft_tokens.append(
24102429
request.py_num_accepted_draft_tokens)
24112430
past_seen_token_num = request.max_beam_num_tokens - 1
24122431

2413-
draft_lens.append(self.runtime_draft_len)
2432+
draft_lens.append(runtime_draft_token_buffer_width)
24142433
gather_ids.extend(
24152434
list(
24162435
range(len(position_ids),
2417-
len(position_ids) + 1 + self.runtime_draft_len)))
2436+
len(position_ids) + runtime_tokens_per_gen_step)))
24182437
# For the target model + tree decoding
24192438
if not self.is_draft_model and not spec_config.is_linear_tree:
24202439
assert spec_tree_manager is not None
@@ -2427,19 +2446,19 @@ def _prepare_tp_inputs(
24272446
position_ids.extend(
24282447
list(
24292448
range(
2430-
past_seen_token_num, past_seen_token_num + 1 +
2431-
self.runtime_draft_len)))
2449+
past_seen_token_num, past_seen_token_num +
2450+
runtime_tokens_per_gen_step)))
24322451
# previous tensor
24332452
previous_batch_indices.append(previous_batch_idx)
24342453
previous_pos_indices.extend([previous_batch_idx] *
2435-
(1 + self.runtime_draft_len))
2454+
runtime_tokens_per_gen_step)
24362455

24372456
num_cached_tokens_per_seq.append(past_seen_token_num +
2438-
self.runtime_draft_len + 1)
2457+
runtime_tokens_per_gen_step)
24392458
request.cached_tokens = num_cached_tokens_per_seq[-1]
24402459
if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx(
24412460
self.attn_backend) and spec_config.is_linear_tree:
2442-
prompt_lengths.append(1 + self.runtime_draft_len)
2461+
prompt_lengths.append(runtime_tokens_per_gen_step)
24432462
else:
24442463
prompt_lengths.append(request.py_prompt_len)
24452464

@@ -2740,30 +2759,36 @@ def previous_seq_slots_device():
27402759
# Initialize these two values to zeros
27412760
self.previous_pos_id_offsets_cuda *= 0
27422761
self.previous_kv_lens_offsets_cuda *= 0
2762+
runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step(
2763+
self.runtime_draft_len)
2764+
runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1
27432765

27442766
if previous_batch_len > 0:
27452767
previous_slots = previous_seq_slots_device()
27462768
# previous input ids
2747-
previous_batch_tokens = previous_batch_len * (
2748-
1 + self.runtime_draft_len)
2769+
previous_batch_tokens = (previous_batch_len *
2770+
runtime_tokens_per_gen_step)
27492771
new_tokens = new_tokens_device.transpose(
27502772
0,
2751-
1)[previous_slots, :(1 + self.runtime_draft_len)].flatten()
2773+
1)[previous_slots, :runtime_tokens_per_gen_step].flatten()
27522774
self.input_ids_cuda[num_tokens:num_tokens +
27532775
previous_batch_tokens].copy_(
27542776
new_tokens, non_blocking=True)
27552777

27562778
# previous draft tokens
2757-
previous_batch_draft_tokens = previous_batch_len * self.runtime_draft_len
2758-
if self.runtime_draft_len > 0:
2759-
self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens +
2760-
previous_batch_draft_tokens].copy_(
2761-
next_draft_tokens_device[
2762-
previous_slots, :self.
2763-
runtime_draft_len].flatten(),
2764-
non_blocking=True)
2779+
previous_batch_draft_tokens = (previous_batch_len *
2780+
runtime_draft_token_buffer_width)
2781+
if runtime_draft_token_buffer_width > 0:
2782+
self.draft_tokens_cuda[
2783+
num_draft_tokens:num_draft_tokens +
2784+
previous_batch_draft_tokens].copy_(
2785+
next_draft_tokens_device[
2786+
previous_slots, :
2787+
runtime_draft_token_buffer_width].flatten(),
2788+
non_blocking=True)
27652789
# prepare data for the preprocess inputs
2766-
kv_len_offsets_device = new_tokens_lens_device - self.runtime_draft_len - 1
2790+
kv_len_offsets_device = (new_tokens_lens_device -
2791+
runtime_tokens_per_gen_step)
27672792
previous_pos_indices_host = torch.tensor(
27682793
previous_pos_indices,
27692794
dtype=torch.int,
@@ -2789,8 +2814,8 @@ def previous_seq_slots_device():
27892814
extend_dummy_requests)
27902815
self.previous_pos_id_offsets_cuda[
27912816
(num_extend_reqeust_wo_dummy - previous_batch_len) *
2792-
(1 + self.runtime_draft_len):num_extend_reqeust_wo_dummy *
2793-
(1 + self.runtime_draft_len)].copy_(
2817+
runtime_tokens_per_gen_step:num_extend_reqeust_wo_dummy *
2818+
runtime_tokens_per_gen_step].copy_(
27942819
new_tokens_lens_device[self.previous_pos_indices_cuda[
27952820
0:previous_batch_tokens]],
27962821
non_blocking=True)
@@ -3626,6 +3651,8 @@ def forward(self,
36263651
# Propagate runtime_draft_len (already set on self by py_executor)
36273652
# to spec_metadata so downstream code (eagle3, interface, trtllm) can read it.
36283653
spec_metadata.runtime_draft_len = self.runtime_draft_len
3654+
spec_metadata.runtime_tokens_per_gen_step = (
3655+
self.get_runtime_tokens_per_gen_step(self.runtime_draft_len))
36293656

36303657
attn_metadata.update_spec_dec_param(
36313658
batch_size=scheduled_requests.batch_size,

0 commit comments

Comments
 (0)