forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllm_speculative_decoding.py
More file actions
94 lines (74 loc) · 2.56 KB
/
llm_speculative_decoding.py
File metadata and controls
94 lines (74 loc) · 2.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
### :title Speculative Decoding
### :order 5
### :section Customization
from typing import Optional
import click
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (Eagle3DecodingConfig, KvCacheConfig,
MTPDecodingConfig, NGramDecodingConfig)
# Example inputs shared by every run_* demo in this script; each demo
# generates a completion for each of these prompts in turn.
prompts = [
    "What is the capital of France?",
    "What is the future of AI?",
]
def run_MTP(model: Optional[str] = None):
    """Generate short completions using MTP speculative decoding.

    Builds an ``MTPDecodingConfig`` with relaxed acceptance for thinking
    enabled (``relaxed_topk=10``, ``relaxed_delta=0.01``), loads the model,
    and prints the text of the first output for each entry of ``prompts``
    (at most 10 new tokens each).

    Args:
        model: Model name or local model path. When falsy (``None`` or an
            empty string), ``"nvidia/DeepSeek-R1-FP4"`` is used.
    """
    mtp_cfg = MTPDecodingConfig(
        use_relaxed_acceptance_for_thinking=True,
        relaxed_topk=10,
        relaxed_delta=0.01,
    )
    # Point this at a local directory instead if the checkpoint is already
    # downloaded.
    target_model = model if model else "nvidia/DeepSeek-R1-FP4"
    engine = LLM(model=target_model, speculative_config=mtp_cfg)

    # One generate() call per prompt, printing each result as it arrives.
    for question in prompts:
        result = engine.generate(question, SamplingParams(max_tokens=10))
        print(result.outputs[0].text)
def run_Eagle3():
    """Generate short completions using EAGLE-3 speculative decoding.

    The target model ``meta-llama/Llama-3.1-8B-Instruct`` is paired with the
    draft model ``yuhuili/EAGLE3-LLaMA3.1-Instruct-8B`` (``max_draft_len=3``,
    ``eagle3_one_model=True``). The KV cache is configured with
    ``free_gpu_memory_fraction=0.8``. Prints the text of the first output
    for each entry of ``prompts`` (at most 10 new tokens each).
    """
    draft_cfg = Eagle3DecodingConfig(
        max_draft_len=3,
        speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
        eagle3_one_model=True,
    )
    # NOTE(review): per the KvCacheConfig field name, this presumably lets
    # the KV cache claim up to 80% of free GPU memory — confirm in the docs.
    cache_cfg = KvCacheConfig(free_gpu_memory_fraction=0.8)
    engine = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        speculative_config=draft_cfg,
        kv_cache_config=cache_cfg,
    )

    # Lazy generator: each prompt is generated only when the loop pulls it,
    # so generate/print still alternate one prompt at a time.
    results = (engine.generate(question, SamplingParams(max_tokens=10))
               for question in prompts)
    for result in results:
        print(result.outputs[0].text)
def run_ngram():
    """Generate short completions using n-gram speculative decoding.

    Loads ``meta-llama/Llama-3.1-8B-Instruct`` with an
    ``NGramDecodingConfig`` (``max_draft_len=3``,
    ``max_matching_ngram_size=3``, keep-all / use-oldest / public-pool
    flags set). The overlap scheduler is disabled. Prints the text of the
    first output for each entry of ``prompts`` (at most 10 new tokens each).
    """
    ngram_cfg = NGramDecodingConfig(
        max_draft_len=3,
        max_matching_ngram_size=3,
        is_keep_all=True,
        is_use_oldest=True,
        is_public_pool=True,
    )
    engine = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        speculative_config=ngram_cfg,
        # The n-gram drafter does not work together with the overlap
        # scheduler, so that scheduler is switched off.
        disable_overlap_scheduler=True,
    )

    for question in prompts:
        result = engine.generate(question, SamplingParams(max_tokens=10))
        print(result.outputs[0].text)
@click.command()
@click.argument("algo",
                type=click.Choice(["MTP", "EAGLE3", "NGRAM"],
                                  case_sensitive=False))
@click.option("--model",
              type=str,
              default=None,
              help="Path to the model or model name (MTP only).")
def main(algo: str, model: Optional[str] = None):
    """Run the speculative-decoding demo selected by ALGO.

    ALGO is one of MTP, EAGLE3 or NGRAM (matched case-insensitively).
    --model overrides the model for the MTP demo only; the EAGLE3 and NGRAM
    demos use fixed model names, so --model is ignored for them (with a
    warning on stderr).
    """
    # Some click releases hand back the lower-cased input when
    # case_sensitive=False, so normalise before dispatching.
    algo = algo.upper()

    # Previously --model was silently dropped for non-MTP algorithms; tell
    # the user instead of quietly running a different model than requested.
    if model is not None and algo != "MTP":
        click.echo(
            f"Warning: --model is only used by MTP; ignoring it for {algo}.",
            err=True)

    if algo == "MTP":
        run_MTP(model)
    elif algo == "EAGLE3":
        run_Eagle3()
    elif algo == "NGRAM":
        run_ngram()
    else:
        # Unreachable through the CLI (click.Choice validates ALGO), but kept
        # as a guard for direct callers of the underlying callback.
        raise ValueError(f"Invalid algorithm: {algo}")
if __name__ == "__main__":
    # Script entry point: main is a click command, so calling it with no
    # arguments makes click parse the command line and dispatch.
    main()