forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllm_sampling.py
More file actions
248 lines (189 loc) · 7.64 KB
/
Copy pathllm_sampling.py
File metadata and controls
248 lines (189 loc) · 7.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
### :title Sampling Techniques Showcase
### :order 6
### :section Customization
"""
This example demonstrates various sampling techniques available in TensorRT-LLM.
It showcases different sampling parameters and their effects on text generation.
"""
import inspect
from typing import Optional

import click

from tensorrt_llm import LLM, SamplingParams
# Example prompts to demonstrate different sampling techniques.
# The code in this file falls back to the first entry (prompts[0]) when no
# custom prompt is supplied.
prompts = [
    "What is the future of artificial intelligence?",
    "Describe a beautiful sunset over the ocean.",
    "Write a short story about a robot discovering emotions.",
]
def demonstrate_greedy_decoding(prompt: str,
                                model: Optional[str] = None) -> None:
    """Generate one completion with ``temperature=0.0`` and print it.

    Args:
        prompt: Text handed to ``LLM.generate``.
        model: Model name or path passed to ``LLM``. When omitted (or empty),
            the TinyLlama chat checkpoint used throughout this example is
            loaded.
    """
    print("\n๐ฏ === GREEDY DECODING ===")
    print("Using temperature=0 for deterministic, focused output")

    sampling_params = SamplingParams(
        max_tokens=50,
        temperature=0.0,  # Greedy decoding
    )

    # The ``with`` block shuts the LLM down when the demo finishes (even on
    # error), so its resources are not left allocated while a later demo
    # loads another model.
    with LLM(model=model or "TinyLlama/TinyLlama-1.1B-Chat-v1.0") as llm:
        response = llm.generate(prompt, sampling_params)
        print(f"Prompt: {prompt}")
        print(f"Response: {response.outputs[0].text}")
def demonstrate_temperature_sampling(prompt: str,
                                     model: Optional[str] = None) -> None:
    """Print one completion of ``prompt`` per temperature in a fixed sweep.

    The sweep covers 0.3, 0.7, 1.0 and 1.5; every request is capped at 50 new
    tokens.

    Args:
        prompt: Text handed to ``LLM.generate``.
        model: Model name or path passed to ``LLM``. When omitted (or empty),
            the TinyLlama chat checkpoint used throughout this example is
            loaded.
    """
    print("\n๐ก๏ธ === TEMPERATURE SAMPLING ===")
    print(
        "Higher temperature = more creative/random, Lower temperature = more focused"
    )

    temperatures = [0.3, 0.7, 1.0, 1.5]

    # One LLM serves the whole sweep; the ``with`` block shuts it down
    # afterwards so its resources are released before another model is loaded.
    with LLM(model=model or "TinyLlama/TinyLlama-1.1B-Chat-v1.0") as llm:
        for temp in temperatures:
            sampling_params = SamplingParams(
                max_tokens=50,
                temperature=temp,
            )
            response = llm.generate(prompt, sampling_params)
            print(f"Temperature {temp}: {response.outputs[0].text}")
def demonstrate_top_k_sampling(prompt: str,
                               model: Optional[str] = None) -> None:
    """Print one completion of ``prompt`` for each ``top_k`` in 1, 5, 20, 50.

    Temperature is held at 0.8 so that ``top_k`` is the only setting that
    varies between requests.

    Args:
        prompt: Text handed to ``LLM.generate``.
        model: Model name or path passed to ``LLM``. When omitted (or empty),
            the TinyLlama chat checkpoint used throughout this example is
            loaded.
    """
    print("\n๐ === TOP-K SAMPLING ===")
    print("Only consider the top-k most likely tokens at each step")

    top_k_values = [1, 5, 20, 50]

    # One LLM serves the whole sweep; the ``with`` block shuts it down
    # afterwards so its resources are released before another model is loaded.
    with LLM(model=model or "TinyLlama/TinyLlama-1.1B-Chat-v1.0") as llm:
        for k in top_k_values:
            sampling_params = SamplingParams(
                max_tokens=50,
                temperature=0.8,  # Use moderate temperature
                top_k=k,
            )
            response = llm.generate(prompt, sampling_params)
            print(f"Top-k {k}: {response.outputs[0].text}")
def demonstrate_top_p_sampling(prompt: str,
                               model: Optional[str] = None) -> None:
    """Print one completion of ``prompt`` for each ``top_p`` in a fixed sweep.

    The sweep covers 0.1, 0.5, 0.9 and 0.95 with temperature held at 0.8, so
    ``top_p`` is the only setting that varies between requests.

    Args:
        prompt: Text handed to ``LLM.generate``.
        model: Model name or path passed to ``LLM``. When omitted (or empty),
            the TinyLlama chat checkpoint used throughout this example is
            loaded.
    """
    print("\n๐ฏ === TOP-P (NUCLEUS) SAMPLING ===")
    print("Only consider tokens whose cumulative probability is within top-p")

    top_p_values = [0.1, 0.5, 0.9, 0.95]

    # One LLM serves the whole sweep; the ``with`` block shuts it down
    # afterwards so its resources are released before another model is loaded.
    with LLM(model=model or "TinyLlama/TinyLlama-1.1B-Chat-v1.0") as llm:
        for p in top_p_values:
            sampling_params = SamplingParams(
                max_tokens=50,
                temperature=0.8,  # Use moderate temperature
                top_p=p,
            )
            response = llm.generate(prompt, sampling_params)
            print(f"Top-p {p}: {response.outputs[0].text}")
def demonstrate_combined_sampling(prompt: str,
                                  model: Optional[str] = None) -> None:
    """Print one completion generated with ``top_k`` and ``top_p`` together.

    Uses ``top_k=40``, ``top_p=0.9`` and ``temperature=0.8`` in a single
    request capped at 50 new tokens.

    Args:
        prompt: Text handed to ``LLM.generate``.
        model: Model name or path passed to ``LLM``. When omitted (or empty),
            the TinyLlama chat checkpoint used throughout this example is
            loaded.
    """
    print("\n๐ === COMBINED TOP-K + TOP-P SAMPLING ===")
    print("Using both top-k and top-p together for balanced control")

    sampling_params = SamplingParams(
        max_tokens=50,
        temperature=0.8,
        top_k=40,  # Consider top 40 tokens
        top_p=0.9,  # Within 90% cumulative probability
    )

    # The ``with`` block shuts the LLM down when the demo finishes (even on
    # error), so its resources are not left allocated while a later demo
    # loads another model.
    with LLM(model=model or "TinyLlama/TinyLlama-1.1B-Chat-v1.0") as llm:
        response = llm.generate(prompt, sampling_params)
        print(f"Combined (k=40, p=0.9): {response.outputs[0].text}")
def demonstrate_multiple_sequences(prompt: str,
                                   model: Optional[str] = None) -> None:
    """Request three sampled completions of ``prompt`` and print each one.

    A single request is issued with ``n=3``; every entry of
    ``response.outputs`` is printed with a 1-based sequence number.

    Args:
        prompt: Text handed to ``LLM.generate``.
        model: Model name or path passed to ``LLM``. When omitted (or empty),
            the TinyLlama chat checkpoint used throughout this example is
            loaded.
    """
    print("\n๐ === MULTIPLE SEQUENCES ===")
    print("Generate multiple different responses for the same prompt")

    sampling_params = SamplingParams(
        max_tokens=40,
        temperature=0.8,
        top_k=50,
        top_p=0.95,
        n=3,  # Generate 3 different sequences
    )

    # The ``with`` block shuts the LLM down when the demo finishes (even on
    # error), so its resources are not left allocated while a later demo
    # loads another model.
    with LLM(model=model or "TinyLlama/TinyLlama-1.1B-Chat-v1.0") as llm:
        response = llm.generate(prompt, sampling_params)
        print(f"Prompt: {prompt}")
        for i, output in enumerate(response.outputs):
            print(f"Sequence {i+1}: {output.text}")
def demonstrate_beam_search(prompt: str, model: Optional[str] = None) -> None:
    """Generate with beam search (width 2) and print the first result.

    The same width is used for the LLM's ``max_beam_width`` and for the
    request's ``n``; only ``response.outputs[0]`` is printed.

    Args:
        prompt: Text handed to ``LLM.generate``.
        model: Model name or path passed to ``LLM``. When omitted (or empty),
            the TinyLlama chat checkpoint used throughout this example is
            loaded.
    """
    print("\n๐ฏ === BEAM SEARCH ===")

    beam_width = 2
    sampling_params = SamplingParams(
        max_tokens=50,
        use_beam_search=True,
        n=beam_width,
    )

    # The ``with`` block shuts the LLM down when the demo finishes (even on
    # error), so its resources are not left allocated while a later demo
    # loads another model.
    with LLM(model=model or "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
             max_beam_width=beam_width) as llm:
        response = llm.generate(prompt, sampling_params)
        print(f"Prompt: {prompt}")
        print(f"Response: {response.outputs[0].text}")
def demonstrate_with_logprobs(prompt: str,
                              model: Optional[str] = None) -> None:
    """Generate a short completion and print it with its ``logprobs``.

    The request sets ``logprobs=True`` and is capped at 20 new tokens; the
    text and the ``logprobs`` attribute of the first output are printed.

    Args:
        prompt: Text handed to ``LLM.generate``.
        model: Model name or path passed to ``LLM``. When omitted (or empty),
            the TinyLlama chat checkpoint used throughout this example is
            loaded.
    """
    print("\n๐ === GENERATION WITH LOG PROBABILITIES ===")
    print("Get probability information for generated tokens")

    sampling_params = SamplingParams(
        max_tokens=20,
        temperature=0.7,
        top_k=50,
        logprobs=True,  # Return log probabilities
    )

    # The ``with`` block shuts the LLM down when the demo finishes (even on
    # error), so its resources are not left allocated while a later demo
    # loads another model.
    with LLM(model=model or "TinyLlama/TinyLlama-1.1B-Chat-v1.0") as llm:
        response = llm.generate(prompt, sampling_params)
        output = response.outputs[0]
        print(f"Prompt: {prompt}")
        print(f"Generated: {output.text}")
        print(f"Logprobs: {output.logprobs}")
# Single source of truth for the individual demonstrations: the keys are the
# ``--demo`` choices and the insertion order is the order used by
# ``run_all_demonstrations``.
_DEMOS = {
    "greedy": demonstrate_greedy_decoding,
    "temperature": demonstrate_temperature_sampling,
    "top_k": demonstrate_top_k_sampling,
    "top_p": demonstrate_top_p_sampling,
    "combined": demonstrate_combined_sampling,
    "multiple": demonstrate_multiple_sequences,
    "beam": demonstrate_beam_search,
    "logprobs": demonstrate_with_logprobs,
}


def _run_demo(demo_fn, prompt: str, model: Optional[str] = None) -> None:
    """Call one demonstration, forwarding ``model`` when the demo accepts it.

    A demonstration is only required to accept the prompt. A model override
    is passed along to demonstrations that declare a ``model`` parameter; for
    any other demonstration a warning is written to stderr instead of
    silently dropping the user's choice.
    """
    if model is None:
        demo_fn(prompt)
    elif "model" in inspect.signature(demo_fn).parameters:
        demo_fn(prompt, model=model)
    else:
        click.echo(
            f"Warning: {demo_fn.__name__} does not take a model override; "
            "using its built-in default model.",
            err=True,
        )
        demo_fn(prompt)


def run_all_demonstrations(model_path: Optional[str] = None,
                           prompt: Optional[str] = None) -> None:
    """Run every sampling demonstration in sequence.

    Args:
        model_path: Optional model name or path, forwarded to each
            demonstration that accepts a model override.
        prompt: Prompt used for all demonstrations. Defaults to the first
            entry of ``prompts``.
    """
    print("๐ TensorRT LLM Sampling Techniques Showcase")
    print("=" * 50)

    # Use the first example prompt unless the caller supplied one.
    demo_prompt = prompt or prompts[0]

    for demo_fn in _DEMOS.values():
        _run_demo(demo_fn, demo_prompt, model_path)

    print("\n๐ All sampling demonstrations completed!")


# The accepted ``--demo`` values are derived from ``_DEMOS`` so that every
# choice click lets through has a function to run.
@click.command()
@click.option("--model",
              type=str,
              default=None,
              help="Path to the model or model name")
@click.option("--demo",
              type=click.Choice([*_DEMOS, "all"]),
              default="all",
              help="Which demonstration to run")
@click.option("--prompt", type=str, default=None, help="Custom prompt to use")
def main(model: Optional[str], demo: str, prompt: Optional[str]):
    """
    Showcase various sampling techniques in TensorRT-LLM.

    Examples:
        python llm_sampling.py --demo all
        python llm_sampling.py --demo temperature --prompt "Tell me a joke"
        python llm_sampling.py --demo beam --model path/to/your/model
    """
    if demo == "all":
        # Forward the custom prompt too, so --prompt applies to every demo.
        run_all_demonstrations(model, prompt)
        return

    _run_demo(_DEMOS[demo], prompt or prompts[0], model)
# Invoke the click command only when this file is executed as a script, so
# importing the module does not start any demonstration.
if __name__ == "__main__":
    main()