-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathvllm_client_example.py
More file actions
252 lines (201 loc) · 8.44 KB
/
vllm_client_example.py
File metadata and controls
252 lines (201 loc) · 8.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
#!/usr/bin/env python3
"""
vLLM Client Example - Batched Inference

This script demonstrates how to make batched calls to a vLLM server using the
OpenAI-compatible API. The vLLM server (started via vllm_eval.py) exposes an
OpenAI-compatible endpoint that supports batched requests.

Usage:
    # Basic usage with default server
    python vllm_client_example.py

    # Custom server URL
    python vllm_client_example.py --base_url http://localhost:8000/v1

    # With custom prompts file
    python vllm_client_example.py --prompts_file my_prompts.txt
"""
import argparse
import asyncio
import time
from typing import Any
try:
from openai import AsyncOpenAI, OpenAI
except ImportError:
print("OpenAI package not installed. Install with: pip install openai")
import sys
sys.exit(1)
def create_sample_prompts() -> list[str]:
    """Return a fresh list of five short computer-science demo prompts."""
    samples = (
        "Explain what a binary search tree is in one sentence.",
        "Write a Python function to calculate factorial.",
        "What is the time complexity of quicksort?",
        "How do you reverse a linked list?",
        "Explain the difference between a stack and a queue.",
    )
    # Copy into a new list so every caller gets its own mutable sequence.
    return list(samples)
def synchronous_batched_calls(
    client: OpenAI, prompts: list[str], model: str = "default", max_tokens: int = 256, temperature: float = 0.7
) -> list[dict[str, Any]]:
    """
    Send the prompts to the server one at a time and collect the results.

    Each prompt is sent as a single-turn chat completion (one "user" message).
    Requests are issued sequentially: the next request starts only after the
    previous response has been received. Progress and a timing summary are
    printed to stdout.

    Args:
        client: OpenAI client instance
        prompts: List of prompt strings (an empty list is allowed)
        model: Model name passed through to the chat completions call
        max_tokens: Maximum tokens to generate per prompt
        temperature: Sampling temperature

    Returns:
        One dictionary per prompt, in input order, with the keys "prompt",
        "response" (message content of the first choice; may be None),
        "tokens_used" (total tokens, or None when the response carries no
        usage data) and "finish_reason".

    Raises:
        Exception: Any error raised by ``client.chat.completions.create``
            propagates unchanged; nothing is caught here.
    """
    print(f"\n{'=' * 80}")
    print("Synchronous Batched Calls")
    print(f"{'=' * 80}")
    print(f"Processing {len(prompts)} prompts...")

    responses: list[dict[str, Any]] = []
    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments the way time.time() can.
    start_time = time.perf_counter()

    for i, prompt in enumerate(prompts, 1):
        print(f"\nRequest {i}/{len(prompts)}: {prompt[:50]}...")

        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=temperature,
        )

        result = {
            "prompt": prompt,
            "response": response.choices[0].message.content,
            "tokens_used": response.usage.total_tokens if response.usage else None,
            "finish_reason": response.choices[0].finish_reason,
        }
        responses.append(result)

        # The message content is optional in the API and may be None; fall
        # back to an empty string so the preview slice cannot raise TypeError.
        preview = result["response"] or ""
        print(f"Response: {preview[:100]}...")
        print(f"Tokens: {result['tokens_used']}, Finish reason: {result['finish_reason']}")

    elapsed = time.perf_counter() - start_time
    # Guard the average: an empty prompt list would otherwise divide by zero.
    per_request = elapsed / len(prompts) if prompts else 0.0

    print(f"\n{'=' * 80}")
    print(f"Completed {len(prompts)} requests in {elapsed:.2f}s ({per_request:.2f}s per request)")
    print(f"{'=' * 80}")

    return responses
async def async_batched_calls(
    client: AsyncOpenAI, prompts: list[str], model: str = "default", max_tokens: int = 256, temperature: float = 0.7
) -> list[dict[str, Any]]:
    """
    Send all prompts to the server concurrently and collect the results.

    One chat completion coroutine is created per prompt and all of them are
    awaited together with ``asyncio.gather``, so every request is in flight at
    the same time. Progress lines from different requests may interleave on
    stdout. A timing summary is printed at the end.

    Args:
        client: Async OpenAI client instance
        prompts: List of prompt strings (an empty list is allowed)
        model: Model name passed through to the chat completions call
        max_tokens: Maximum tokens to generate per prompt
        temperature: Sampling temperature

    Returns:
        One dictionary per prompt, in input order, with the keys "prompt",
        "response" (message content of the first choice; may be None),
        "tokens_used" (total tokens, or None when the response carries no
        usage data) and "finish_reason".

    Raises:
        Exception: Any error raised by a request propagates out of
            ``asyncio.gather``; nothing is caught here.
    """
    print(f"\n{'=' * 80}")
    print("Asynchronous Batched Calls")
    print(f"{'=' * 80}")
    print(f"Processing {len(prompts)} prompts concurrently...")

    async def process_prompt(prompt: str, index: int) -> dict[str, Any]:
        """Send one prompt and return its result dictionary."""
        print(f"\nRequest {index + 1}/{len(prompts)}: {prompt[:50]}...")

        response = await client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=temperature,
        )

        result = {
            "prompt": prompt,
            "response": response.choices[0].message.content,
            "tokens_used": response.usage.total_tokens if response.usage else None,
            "finish_reason": response.choices[0].finish_reason,
        }

        # The message content is optional in the API and may be None; fall
        # back to an empty string so the preview slice cannot raise TypeError.
        preview = result["response"] or ""
        print(f"Response {index + 1}: {preview[:100]}...")
        print(f"Tokens: {result['tokens_used']}, Finish reason: {result['finish_reason']}")

        return result

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments the way time.time() can.
    start_time = time.perf_counter()

    # Create one coroutine per prompt, then wait for all of them; gather
    # returns the results in the same order as the inputs.
    tasks = [process_prompt(prompt, i) for i, prompt in enumerate(prompts)]
    responses = await asyncio.gather(*tasks)

    elapsed = time.perf_counter() - start_time
    # Guard the average: an empty prompt list would otherwise divide by zero.
    per_request = elapsed / len(prompts) if prompts else 0.0

    print(f"\n{'=' * 80}")
    print(f"Completed {len(prompts)} requests in {elapsed:.2f}s ({per_request:.2f}s per request)")
    print(f"{'=' * 80}")

    return list(responses)
def load_prompts_from_file(file_path: str) -> list[str]:
    """
    Load prompts from a text file (one prompt per line).

    Leading and trailing whitespace is removed from every line, and lines that
    are empty after stripping are skipped.

    Args:
        file_path: Path to the prompts file.

    Returns:
        The non-blank lines of the file, stripped, in file order.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly instead of relying on the platform default encoding.
    # "utf-8-sig" reads plain UTF-8 and also drops a leading byte-order mark,
    # which would otherwise end up glued to the first prompt.
    with open(file_path, encoding="utf-8-sig") as f:
        return [prompt for line in f if (prompt := line.strip())]
def main() -> None:
    """
    Parse command-line options and run the sync and/or async demo.

    Prompts come from ``--prompts_file`` when given, otherwise from
    ``create_sample_prompts()``. Depending on ``--mode`` the prompts are sent
    with ``synchronous_batched_calls``, ``async_batched_calls``, or both (sync
    first). The process exits with status 1 when the prompts file cannot be
    read, when there are no prompts to send, or when either run raises.
    """
    # At module scope ``sys`` is only imported on the ImportError path, so it
    # is imported once here rather than separately in every error branch.
    import sys

    parser = argparse.ArgumentParser(
        description="Example client for batched vLLM inference", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--base_url", type=str, default="http://localhost:8000/v1", help="vLLM server base URL (OpenAI-compatible)"
    )
    parser.add_argument("--model", type=str, default="default", help="Model name to use (default works for most cases)")
    parser.add_argument("--max_tokens", type=int, default=256, help="Maximum tokens to generate per prompt")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--prompts_file", type=str, default=None, help="Path to file with prompts (one per line)")
    parser.add_argument(
        "--mode",
        type=str,
        choices=["sync", "async", "both"],
        default="both",
        help="Mode: sync (sequential), async (concurrent), or both",
    )

    args = parser.parse_args()

    # Load prompts
    if args.prompts_file:
        print(f"Loading prompts from {args.prompts_file}...")
        try:
            prompts = load_prompts_from_file(args.prompts_file)
        except OSError as e:
            # Report a missing/unreadable file plainly instead of a traceback.
            print(f"\nError reading prompts file: {e}")
            sys.exit(1)
    else:
        print("Using sample prompts...")
        prompts = create_sample_prompts()

    # Stop early when there is nothing to send: otherwise a failure caused by
    # the empty list would be reported below as a server connection problem.
    if not prompts:
        print("\nNo prompts to process (is the prompts file empty?).")
        sys.exit(1)

    print(f"\nConnecting to vLLM server at: {args.base_url}")
    print(f"Model: {args.model}")
    print(f"Prompts to process: {len(prompts)}")

    # Synchronous batched calls
    if args.mode in ["sync", "both"]:
        client = OpenAI(base_url=args.base_url, api_key="EMPTY")  # vLLM doesn't require API key

        try:
            synchronous_batched_calls(
                client=client,
                prompts=prompts,
                model=args.model,
                max_tokens=args.max_tokens,
                temperature=args.temperature,
            )
        except Exception as e:
            # Top-level boundary of a demo script: report and exit non-zero.
            print(f"\nError in synchronous calls: {e}")
            print("Make sure vLLM server is running!")
            sys.exit(1)

    # Asynchronous batched calls
    if args.mode in ["async", "both"]:
        async_client = AsyncOpenAI(base_url=args.base_url, api_key="EMPTY")  # vLLM doesn't require API key

        try:
            asyncio.run(
                async_batched_calls(
                    client=async_client,
                    prompts=prompts,
                    model=args.model,
                    max_tokens=args.max_tokens,
                    temperature=args.temperature,
                )
            )
        except Exception as e:
            # Top-level boundary of a demo script: report and exit non-zero.
            print(f"\nError in asynchronous calls: {e}")
            print("Make sure vLLM server is running!")
            sys.exit(1)

    print(f"\n{'=' * 80}")
    print("Example completed successfully!")
    print(f"{'=' * 80}")
    print("\nTips:")
    print("- Async mode is faster for batched requests due to concurrent execution")
    print("- vLLM uses continuous batching to efficiently handle concurrent requests")
    print("- Adjust --max_tokens and --temperature based on your use case")
    print("- For production, consider adding retry logic and better error handling")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()