
Commit d3a16b3

[TRTLLM-11045][feat] Integrate SA with EAGLE3 and PARD (#11878)
Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com>
1 parent ae2dc3d commit d3a16b3

14 files changed: 518 additions & 98 deletions


docs/source/features/speculative-decoding.md

Lines changed: 71 additions & 0 deletions
@@ -48,6 +48,8 @@ speculative_config = Eagle3DecodingConfig(
 llm = LLM(model, speculative_config=speculative_config)
 ```
 
+EAGLE 3 can be combined with the [Suffix Automaton enhancement](#suffix-automaton-sa-enhancement) for improved acceptance rates on repetitive content. See the SA section below for details.
+
 ### NGram
 
 The NGram method is an implementation of [this Prompt Lookup Decoding algorithm](https://github.com/apoorvumang/prompt-lookup-decoding).
@@ -88,6 +90,29 @@ speculative_config = MTPDecodingConfig(
 llm = LLM("/path/to/deepseek_model", speculative_config=speculative_config)
 ```
 
+MTP can be combined with the [Suffix Automaton enhancement](#suffix-automaton-sa-enhancement) for improved acceptance rates on repetitive content. See the SA section below for details.
+
+### PARD
+
+PARD (PARallel Draft) is a target-independent speculative decoding method that predicts all draft tokens in a single forward pass using mask tokens. Unlike MTP or EAGLE 3, which generate drafts one token at a time, PARD produces K draft tokens in parallel.
+
+Reference: [PARD: Parallel Drafting for Speculative Decoding](https://arxiv.org/pdf/2504.18583)
+
+* `max_draft_len`: Maximum draft candidate length.
+* `speculative_model`: Path or HuggingFace model ID for the PARD draft model.
+* `mask_token_id`: Token ID used as the mask token for parallel prediction. If not set, it is read from the draft model config.
+
+```python
+from tensorrt_llm.llmapi import PARDDecodingConfig
+
+speculative_config = PARDDecodingConfig(
+    max_draft_len=4, speculative_model="/path/to/pard_model")
+
+llm = LLM("/path/to/target_model", speculative_config=speculative_config)
+```
+
+PARD can be combined with the [Suffix Automaton enhancement](#suffix-automaton-sa-enhancement) for improved acceptance rates on repetitive content. See the SA section below for details.
+
 ### User-provided drafting
 A completely user-defined drafting method can be supplied with a `UserProvidedDecodingConfig` that includes
 * `max_draft_len`: Maximum draft candidate length.
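To make the parallel-drafting idea in the hunk above concrete, here is a minimal sketch of a PARD-style draft step. It is not the TensorRT-LLM implementation: the `pard_draft_step` name and the HuggingFace-style `draft_model(input_ids).logits` call are assumptions, and the exact mask-position convention depends on how the draft model was trained.

```python
import torch

def pard_draft_step(draft_model, prefix_ids: torch.Tensor,
                    mask_token_id: int, k: int) -> torch.Tensor:
    """Sketch: draft K tokens in one forward pass via mask tokens.

    With the usual causal-LM shift, logits at the last real token predict
    draft token 1 and logits at each of the K-1 mask positions predict
    the token after it, so one pass yields K draft tokens instead of K
    sequential decode steps.
    """
    # prefix_ids: [1, seq_len] tokens accepted so far.
    masks = torch.full((1, k - 1), mask_token_id,
                       dtype=prefix_ids.dtype, device=prefix_ids.device)
    input_ids = torch.cat([prefix_ids, masks], dim=1)
    logits = draft_model(input_ids).logits  # [1, seq_len + k - 1, vocab]
    return logits[0, -k:].argmax(dim=-1)    # greedy draft tokens, [k]
```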
@@ -103,6 +128,40 @@ speculative_config = UserProvidedDecodingConfig(
 llm = LLM("/path/to/target_model", speculative_config=speculative_config)
 ```
 
+## Suffix Automaton (SA) Enhancement
+
+The Suffix Automaton (SA) is a model-free, GPU-based pattern-matching draft enhancer. It finds suffix matches in previously generated tokens and proposes draft tokens when the match is long enough. SA is very accurate when it matches (exact pattern repetition), while neural methods are better for novel content; combining them gives the best of both worlds.
+
+SA can be combined with the following speculative decoding techniques:
+
+* **MTP** (`MTPDecodingConfig`)
+* **EAGLE 3** (`Eagle3DecodingConfig`)
+* **PARD** (`PARDDecodingConfig`)
+
+To enable SA combination, set `use_sa_spec=True` on the speculative config. The `sa_spec_threshold` parameter controls the minimum suffix match length required to override the neural draft (default: 4).
+
+```python
+from tensorrt_llm.llmapi import Eagle3DecodingConfig
+
+speculative_config = Eagle3DecodingConfig(
+    max_draft_len=4,
+    speculative_model="/path/to/eagle3_model",
+    use_sa_spec=True,
+    sa_spec_threshold=4)
+
+llm = LLM("/path/to/target_model", speculative_config=speculative_config)
+```
+
+SA can also be used as a standalone speculative decoding technique via `SADecodingConfig`:
+
+```python
+from tensorrt_llm.llmapi import SADecodingConfig
+
+speculative_config = SADecodingConfig(max_draft_len=4)
+
+llm = LLM("/path/to/target_model", speculative_config=speculative_config)
+```
+
 ## Usage with `trtllm-bench` and `trtllm-serve`
 
 ```{eval-rst}
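The following brute-force sketch illustrates the suffix-match query that SA answers and how `sa_spec_threshold` gates the override. `sa_propose` is a hypothetical name for illustration only; a real suffix automaton maintains the query incrementally (and, per the doc hunk above, on the GPU) rather than rescanning history.

```python
from typing import List, Optional

def sa_propose(tokens: List[int], k: int, threshold: int) -> Optional[List[int]]:
    """Propose draft tokens by matching the current suffix against history.

    Brute force for clarity (O(n^2)); a suffix automaton built over the
    generated tokens answers the same longest-suffix-match query in
    amortized O(1) per appended token.
    """
    n = len(tokens)
    best_len, best_end = 0, -1
    for end in range(n - 1):  # candidate earlier occurrence ends at `end`
        m = 0
        while m <= end and tokens[end - m] == tokens[n - 1 - m]:
            m += 1
        if m > best_len:
            best_len, best_end = m, end
    if best_len < threshold:
        return None  # match too short: keep the neural draft
    # Propose the tokens that followed the matched occurrence last time.
    return tokens[best_end + 1:best_end + 1 + k] or None
```

For example, with `tokens = [5, 1, 2, 3, 9, 1, 2, 3]` the current suffix `[1, 2, 3]` also ends at index 3, so with `threshold=3` the proposal is `tokens[4:4+k]`, i.e. `[9, 1, 2, 3][:k]`.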
@@ -117,6 +176,8 @@ Speculative decoding options must be specified via `--config config.yaml` for bo
 * `Eagle3`
 * `NGram`
 * `DraftTarget`
+* `PARD`
+* `SA`
 
 > Note: The PyTorch backend supports only `Eagle3`. `decoding_type: Eagle` is accepted as a backward-compatible alias for `Eagle3`, but EAGLE (v1/v2) draft checkpoints are incompatible.
@@ -138,6 +199,16 @@ speculative_config:
   speculative_model: /path/to/draft/model
 ```
 
+```yaml
+# SA combination: enable Suffix Automaton enhancement with any supported technique
+speculative_config:
+  decoding_type: Eagle3
+  max_draft_len: 4
+  speculative_model: /path/to/draft/model
+  use_sa_spec: true
+  sa_spec_threshold: 4
+```
+
 ```{note}
 The field name `speculative_model_dir` can also be used as an alias for `speculative_config.speculative_model`. For example:

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 20 additions & 14 deletions
@@ -3,7 +3,6 @@
 import functools
 import gc
 import inspect
-import itertools
 import math
 import os
 import weakref
@@ -3489,21 +3488,28 @@ def _prepare_inputs(
             raise NotImplementedError(
                 f"Unsupported cp_type {getattr(cp_type, 'name', cp_type)}.")
 
-        # Initialize SA state for new requests (MTP+SA path)
+        # Initialize SA state for new requests (MTP+SA, EAGLE3+SA, PARD+SA, etc.)
         use_sa_spec = (self.spec_config is not None
                        and getattr(self.spec_config, 'use_sa_spec', False))
-        if (use_sa_spec and spec_metadata is not None
-                and hasattr(spec_metadata, 'sa_manager')
-                and spec_metadata.sa_manager is not None
-                and self.mapping.is_last_pp_rank()):
-            sa_manager = spec_metadata.sa_manager
-            for request in itertools.chain(
-                    scheduled_requests.context_requests,
-                    scheduled_requests.generation_requests):
-                if request.py_request_id not in sa_manager._initialized_requests:
-                    sa_manager.add_request(request.py_request_id,
-                                           request.get_tokens(0))
-                    sa_manager._initialized_requests.add(request.py_request_id)
+        if use_sa_spec and resource_manager is not None and self.mapping.is_last_pp_rank(
+        ):
+            from tensorrt_llm._torch.speculative.suffix_automaton import \
+                SuffixAutomatonManager
+            spec_rm = resource_manager.get_resource_manager(
+                ResourceManagerType.SPEC_RESOURCE_MANAGER)
+            sa_manager = None
+            if spec_rm is not None:
+                if isinstance(spec_rm, SuffixAutomatonManager):
+                    sa_manager = spec_rm
+                else:
+                    sa_manager = getattr(spec_rm, 'sa_manager', None)
+            if sa_manager is not None:
+                for request in scheduled_requests.all_requests():
+                    if request.py_request_id not in sa_manager._initialized_requests:
+                        sa_manager.add_request(request.py_request_id,
+                                               request.get_tokens(0))
+                        sa_manager._initialized_requests.add(
+                            request.py_request_id)
 
         return self._prepare_tp_inputs(
             scheduled_requests, kv_cache_manager, attn_metadata, spec_metadata,
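The SA-manager lookup in this hunk has to handle two deployment shapes. Condensed into a standalone helper for readability (the logic mirrors the hunk; `resolve_sa_manager` itself is an illustrative name, not part of the change):

```python
def resolve_sa_manager(spec_rm, SuffixAutomatonManager):
    """Resolve the SA manager for both deployment shapes:

    - standalone SA: the spec resource manager *is* the SA manager;
    - combined modes (MTP/EAGLE3/PARD + SA): the technique's own resource
      manager carries it as an optional `.sa_manager` attribute.
    """
    if spec_rm is None:
        return None
    if isinstance(spec_rm, SuffixAutomatonManager):
        return spec_rm
    return getattr(spec_rm, 'sa_manager', None)  # may be absent -> None
```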

tensorrt_llm/_torch/speculative/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 from .mtp import MTPEagleWorker, MTPSampler, MTPSpecMetadata, MTPWorker
 from .ngram import NGramDrafter, NGramPoolManager
 from .pard import PARDSpecMetadata, PARDWorker
+from .sa_enhancer import SADraftEnhancer
 from .sa_worker import SASampler, SASpecMetadata, SAWorker
 from .save_hidden_state import (SaveHiddenStatesResourceManager,
                                 SaveHiddenStatesSpecMetadata)
@@ -31,6 +32,7 @@
     "NGramPoolManager",
     "PARDSpecMetadata",
     "PARDWorker",
+    "SADraftEnhancer",
     "SASampler",
     "SASpecMetadata",
     "SAWorker",

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 52 additions & 8 deletions
@@ -14,6 +14,7 @@
 from ..pyexecutor.scheduler import ScheduledRequests
 from .interface import SpecMetadata, SpecWorkerBase
 from .mtp import MTPSampler
+from .sa_enhancer import SADraftEnhancer
 from .spec_tree_manager import SpecTreeManager
 
 if TYPE_CHECKING:
@@ -27,14 +28,21 @@ class Eagle3ResourceManager(BaseResourceManager):
     and one for the draft model. Use this class to manage the hidden states.
     """
 
-    def __init__(self, config: "EagleDecodingConfig", dtype: torch.dtype,
-                 hidden_size: int, max_num_requests: int, max_seq_len: int,
-                 max_num_tokens: int):
+    def __init__(self,
+                 config: "EagleDecodingConfig",
+                 dtype: torch.dtype,
+                 hidden_size: int,
+                 max_num_requests: int,
+                 max_seq_len: int,
+                 max_num_tokens: int,
+                 sa_manager=None):
         self.dtype = dtype
         self.max_draft_len = config.max_draft_len
         self.hidden_size = hidden_size
         self.max_num_requests = max_num_requests
         self.max_seq_len = max_seq_len
+        # Optional SA manager for EAGLE3+SA mode
+        self.sa_manager = sa_manager
         # There could be dummy request for padding batch when using CUDA graph.
         # Reserve one more slot for the dummy request.
         slot_size = self.max_seq_len + 1
@@ -94,13 +102,18 @@ def free_resources(self, request: LlmRequest):
         self.seq_lens[slot_id] = 0
         self.start_indices[slot_id] = 0
         self.slot_manager.remove_slot(request.request_id)
+        if self.sa_manager is not None:
+            self.sa_manager.remove_request(request.request_id)
 
     def add_dummy_requests(self, request_ids: List[int]):
         for rid in request_ids:
             self.slot_manager.add_slot(rid)
+        if self.sa_manager is not None:
+            self.sa_manager.add_dummy_requests(request_ids)
 
     def shutdown(self):
-        pass
+        if self.sa_manager is not None:
+            self.sa_manager.shutdown()
 
     def get_max_resource_count(self) -> int:
         return self.max_num_requests
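Between this hunk and the `prepare` hook below, `Eagle3ResourceManager` delegates a fixed lifecycle to the optional SA manager. Collected as a typing `Protocol` for reference (illustrative, a sketch; the method names and arguments are exactly the calls made in this file):

```python
from typing import List, Protocol

class SAManagerLike(Protocol):
    """Surface the EAGLE3+SA integration expects of `sa_manager`."""

    def remove_request(self, request_id: int) -> None: ...       # free_resources
    def add_dummy_requests(self, request_ids: List[int]) -> None: ...  # CUDA-graph padding
    def shutdown(self) -> None: ...                               # teardown
    def prepare(self, request_ids: List[int], max_draft_len: int) -> None: ...  # per-step prep
```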
@@ -298,6 +311,8 @@ class Eagle3OneModelSpecMetadata(SpecMetadata):
     dtype: torch.dtype = torch.bfloat16
     # The index of the batch inputs
     batch_indices_cuda: Optional[torch.Tensor] = None
+    # Optional resource manager (used to access SA manager for EAGLE3+SA)
+    spec_resource_manager: Optional[Eagle3ResourceManager] = None
 
     def __post_init__(self):
         if self.layers_to_capture is None:
@@ -345,6 +360,12 @@ def prepare(self):
             non_blocking=True)
         self.num_tokens -= (self.num_generations) * self.max_draft_len
 
+        sa_manager = getattr(self.spec_resource_manager, 'sa_manager', None)
+        if sa_manager is not None:
+            gen_request_ids = self.request_ids[num_seqs - self.num_generations:]
+            if gen_request_ids:
+                sa_manager.prepare(gen_request_ids, self.max_draft_len)
+
     def maybe_capture_hidden_states(
         self,
         layer_id: int,
@@ -375,6 +396,9 @@ def __init__(self,
         super().__init__(use_separate_draft_kv_cache)
         self.spec_config = spec_config
         self.mapping = mapping
+        self.sa_enhancer: Optional[SADraftEnhancer] = None
+        if getattr(spec_config, 'use_sa_spec', False):
+            self.sa_enhancer = SADraftEnhancer(spec_config.sa_spec_threshold)
 
     @property
     def max_draft_len(self) -> int:
@@ -424,6 +448,19 @@ def forward(self,
         accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
             logits, attn_metadata, spec_metadata)
 
+        sa_manager = getattr(spec_metadata.spec_resource_manager, 'sa_manager',
+                             None)
+        if self.sa_enhancer is not None and sa_manager is not None:
+            self.sa_enhancer.extend_and_prepare(
+                sa_manager=sa_manager,
+                request_ids=spec_metadata.request_ids,
+                accepted_tokens=accepted_tokens,
+                num_accepted_tokens=num_accepted_tokens,
+                num_gens=num_gens,
+                num_contexts=num_contexts,
+                max_draft_len=self.max_draft_len,
+            )
+
         # Save the old attn_metadata and spec_metadata
         self._prepare_attn_metadata_for_spec_dec(attn_metadata)
 
@@ -528,6 +565,14 @@ def forward(self,
         }
         next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
 
+        # Override with SA draft tokens after all draft layers have run,
+        # so that draft layers never see SA tokens in their inputs.
+        if self.sa_enhancer is not None:
+            gen_draft_tokens = next_draft_tokens[num_contexts:]
+            gen_draft_tokens = self.sa_enhancer.maybe_override_all_draft_tokens(
+                gen_draft_tokens)
+            next_draft_tokens[num_contexts:] = gen_draft_tokens
+
         # restore attn_metadata to support cuda graph
         self._restore_attn_metadata_from_spec_dec(attn_metadata)
         # restore all_rank_num_tokens for attention DP
@@ -588,11 +633,10 @@ def draft_decoder(
             Draft token ids. Flattened.
         '''
 
-        # Note: using greedy for draft tokens is a bit easier to implement and
-        # faster. It doesn't affect the final output and seems to have a negligible
-        # impact on AR.
         d2t = getattr(draft_model.model, "d2t", None)
-        return self._draft_sampler_greedy(logits, d2t)
+        draft_tokens = self._draft_sampler_greedy(logits, d2t)
+
+        return draft_tokens
 
     def prepare_1st_drafter_inputs(
         self,
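`SADraftEnhancer` itself lives in the new `sa_enhancer` module, which this commit excerpt does not show. As a mental model for `maybe_override_all_draft_tokens`, consistent with the documented `sa_spec_threshold` semantics (the SA draft replaces the neural draft only when the suffix match is long enough), a hypothetical core might look like this; the function name, the `sa_draft` and `sa_match_len` inputs, and the whole-row override policy are assumptions for illustration:

```python
import torch

def override_draft_tokens(neural_draft: torch.Tensor,
                          sa_draft: torch.Tensor,
                          sa_match_len: torch.Tensor,
                          threshold: int) -> torch.Tensor:
    """Hypothetical core of the SA override step.

    neural_draft: [num_gens, max_draft_len] EAGLE3 draft tokens.
    sa_draft:     [num_gens, max_draft_len] SA-proposed tokens.
    sa_match_len: [num_gens] longest suffix-match length per request.
    Rows whose match meets the threshold take the SA draft wholesale;
    all other rows keep the neural draft unchanged.
    """
    use_sa = (sa_match_len >= threshold).unsqueeze(1)  # [num_gens, 1]
    return torch.where(use_sa, sa_draft, neural_draft)
```

Applying the override after all draft layers have run, as the hunk above notes, keeps the draft layers' inputs free of SA tokens, so CUDA-graph shapes and the drafting loop are unaffected by whether SA fires.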
