forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllm_runtime.py
More file actions
144 lines (109 loc) · 4.42 KB
/
Copy pathllm_runtime.py
File metadata and controls
144 lines (109 loc) · 4.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
### :title Runtime Configuration Examples
### :order 6
### :section Customization
'''
This script demonstrates various runtime configuration options in TensorRT-LLM,
including KV cache management and CUDA graph optimizations.
**KV Cache Configuration:**
The KV cache (key-value cache) stores attention keys and values during inference,
which is crucial for efficient autoregressive generation. Proper KV cache configuration helps with:
1. **Memory Management**: Control how much GPU memory is reserved for the key-value cache through
`free_gpu_memory_fraction`, the fraction of the GPU memory still free (after the model weights are loaded) that the cache may claim.
2. **Block Reuse Optimization**: Enable `enable_block_reuse` to optimize memory usage
for shared prefixes across multiple requests, improving throughput for common prompts.
3. **Performance Tuning**: Configure cache block sizes and total capacity to match
your workload characteristics (batch size, sequence length, and request patterns).
Please refer to the `KvCacheConfig` API reference for more details.
**CUDA Graph Configuration:**
CUDA graphs help reduce kernel launch overhead and improve GPU utilization by capturing
and replaying GPU operations. Benefits include:
- Reduced kernel launch overhead for repeated operations
- Better GPU utilization through optimized execution
- Improved throughput for inference workloads
Please refer to the `CudaGraphConfig` API reference for more details.
**How to Run:**
Run all examples:
```bash
python llm_runtime.py
```
Run specific example:
```bash
python llm_runtime.py --example kv_cache
python llm_runtime.py --example cuda_graph
```
'''
import argparse
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig
def example_cuda_graph_config():
    """
    Example demonstrating CUDA graph configuration for performance optimization.

    CUDA graphs help with:
    - Reduced kernel launch overhead
    - Better GPU utilization
    - Improved throughput for repeated operations

    Graphs are configured for batch sizes 1, 2 and 4 with padding enabled
    (see the `CudaGraphConfig` API reference for the exact padding semantics).
    The LLM is closed before the function returns so its GPU resources are
    released for whatever runs next.
    """
    print("\n=== CUDA Graph Configuration Example ===")

    cuda_graph_config = CudaGraphConfig(
        batch_sizes=[1, 2, 4],
        enable_padding=True,
    )

    prompts = [
        "Hello, my name is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(max_tokens=50, temperature=0.8, top_p=0.95)

    # Use the LLM as a context manager so the engine is shut down when the
    # example finishes instead of relying on garbage collection. This matters
    # when `--example all` builds several LLMs back to back, each asking for
    # half of the free GPU memory for its KV cache.
    with LLM(
            model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            cuda_graph_config=cuda_graph_config,  # Enable CUDA graphs
            max_batch_size=4,
            max_seq_len=512,
            kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.5),
    ) as llm:
        # A small batch of three prompts; with padding enabled this can
        # presumably reuse a captured graph (confirm in CudaGraphConfig docs).
        outputs = llm.generate(prompts, sampling_params)

        for output in outputs:
            print(f"Prompt: {output.prompt}")
            print(f"Generated: {output.outputs[0].text}")
            print()
def example_kv_cache_config():
    """Example demonstrating KV cache configuration for memory management and performance.

    Builds an LLM whose KV cache may use half of the free GPU memory and has
    block reuse enabled, generates completions for a few prompts with the
    default sampling parameters, and prints a preview of each answer. The
    preview is at most 100 characters, followed by "..." only when the answer
    was cut. The LLM is closed before the function returns.
    """
    print("\n=== KV Cache Configuration Example ===")
    print("\n1. KV Cache Configuration:")

    kv_cache_config = KvCacheConfig(
        # free_gpu_memory_fraction: the fraction of free GPU memory to allocate to the KV cache
        free_gpu_memory_fraction=0.5,
        # enable_block_reuse: reuse cached KV blocks across requests that share a prompt prefix
        enable_block_reuse=True)

    prompts = [
        "Hello, my name is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Maximum number of characters of each answer to print.
    preview_len = 100

    # Context manager ensures the engine is shut down before the next example
    # (e.g. under `--example all`) allocates its own KV cache.
    with LLM(
            model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            max_batch_size=8,
            max_seq_len=1024,
            kv_cache_config=kv_cache_config,
    ) as llm_advanced:
        outputs = llm_advanced.generate(prompts)

        for i, output in enumerate(outputs, start=1):
            text = output.outputs[0].text
            # Only signal truncation when the answer was actually shortened.
            if len(text) > preview_len:
                preview = text[:preview_len] + "..."
            else:
                preview = text
            print(f"Query {i}: {output.prompt}")
            print(f"Answer: {preview}")
            print()
def main():
    """
    Parse the command line and run the selected runtime configuration example(s).

    ``--example`` accepts ``kv_cache``, ``cuda_graph`` or ``all`` (the default).
    With ``all``, the KV cache example runs first and the CUDA graph example second.
    """
    parser = argparse.ArgumentParser(
        description="Runtime Configuration Examples")
    parser.add_argument(
        "--example",
        type=str,
        choices=["kv_cache", "cuda_graph", "all"],
        default="all",
        help="Which example to run",
    )
    selected = parser.parse_args().example

    # "all" selects every example; the order below is the execution order.
    if selected in ("kv_cache", "all"):
        example_kv_cache_config()
    if selected in ("cuda_graph", "all"):
        example_cuda_graph_config()
# Run the selected example(s) only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()