
Commit 1e92fbb: support vllm

1 parent: c3dbf8b
4 files changed: 136 additions & 18 deletions

File tree:
.env.example
config/okvqa.yaml
config/refcoco.yaml
hydra_vl4ai/agent/llm.py

.env.example

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,6 @@
 OPENAI_API_KEY=your-api-key
 OLLAMA_HOST=http://ollama.server:11434
+VLLM_HOST=http://vllm.server:8000
+VLLM_API_KEY=your-vllm-api-key
 # do not change this TORCH_HOME variable
 TORCH_HOME=./pretrained_models
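
One detail worth checking against the deployment: the OpenAI client used in llm.py below replaces its default base URL wholesale, and vLLM's OpenAI-compatible server exposes its routes under /v1, so VLLM_HOST will typically need that suffix (e.g. http://vllm.server:8000/v1). A minimal sanity check, assuming a server started with something like "vllm serve <model>" (the fallback values here are just the examples from this file):

import os

from openai import OpenAI

client = OpenAI(
    base_url=os.environ.get("VLLM_HOST", "http://vllm.server:8000/v1"),
    # vLLM accepts any placeholder key unless it was launched with --api-key.
    api_key=os.environ.get("VLLM_API_KEY", "EMPTY"),
)
print([m.id for m in client.models.list()])  # should list the served model(s)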

config/okvqa.yaml

Lines changed: 3 additions & 3 deletions
@@ -1,9 +1,9 @@
 dataset: okvqa
 task: vqa
 prompt: chatgpt-35/okvqa
-llm_model: "gpt-4o-mini" # "gpt-3.5-turbo-0630"
-llm_code_model: "gpt-4o-mini" # "gpt-3.5-turbo-0630"
-embedding_model: "text-embedding-3-small"
+llm_model: "openai::gpt-4o-mini" # "gpt-3.5-turbo-0630"
+llm_code_model: "openai::gpt-4o-mini" # "gpt-3.5-turbo-0630"
+embedding_model: "openai::text-embedding-3-small"
 
 debug: True
 executor_port: 31888

config/refcoco.yaml

Lines changed: 3 additions & 3 deletions
@@ -1,9 +1,9 @@
 dataset: refcoco
 task: grounding
 prompt: chatgpt-35/refcoco
-llm_model: "gpt-4o-mini" # "gpt-3.5-turbo-0630"
-llm_code_model: "gpt-4o-mini" # "gpt-3.5-turbo-0630"
-embedding_model: "text-embedding-3-small"
+llm_model: "openai::gpt-4o-mini" # "gpt-3.5-turbo-0630"
+llm_code_model: "openai::gpt-4o-mini" # "gpt-3.5-turbo-0630"
+embedding_model: "openai::text-embedding-3-small"
 
 debug: True
 executor_port: 31888
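
Worth flagging: this is a breaking change for existing user configs, since a bare "gpt-4o-mini" no longer satisfies the new api_type::model_name convention and will trip the assertion in parse_model_name (added in llm.py below). A quick hedged check for stale configs, assuming PyYAML (which these YAML configs imply) is installed:

import yaml

# Flag any model key still using the pre-commit bare-name format.
with open("config/okvqa.yaml") as f:
    cfg = yaml.safe_load(f)
for key in ("llm_model", "llm_code_model", "embedding_model"):
    assert "::" in cfg[key], f"{key} needs a provider prefix, e.g. openai::{cfg[key]}"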

hydra_vl4ai/agent/llm.py

Lines changed: 128 additions & 12 deletions
@@ -16,6 +16,14 @@
 from ..util.console import logger
 
 
+def parse_model_name(model_spec: str) -> tuple[str, str]:
+    # Used to parse the model specification string
+    # Example: "ollama::deepseek-r1:70b" -> ("ollama", "deepseek-r1:70b")
+    assert "::" in model_spec, "Model specification must contain '::' to separate api_type and model_name"
+    api_type, model_name = model_spec.split("::", 1)
+    return api_type.lower(), model_name
+
+
 @Singleton
 class Cost:
     def __init__(self):
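
The parser's contract, illustrated (hedged: only the deepseek tag actually appears in this commit, the other calls are examples):

from hydra_vl4ai.agent.llm import parse_model_name

parse_model_name("openai::gpt-4o-mini")      # -> ("openai", "gpt-4o-mini")
parse_model_name("OLLAMA::deepseek-r1:70b")  # -> ("ollama", "deepseek-r1:70b");
                                             #    split("::", 1) keeps the ":70b" tag
                                             #    intact, api_type is lower-cased
parse_model_name("gpt-4o-mini")              # AssertionError: bare names are rejected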
@@ -79,6 +87,38 @@ def add(self, chatgpt_response, model_name):
 else:
     logger.debug("OLLAMA server is set.")
 
+# Set up vLLM client (using OpenAI client with custom base URL)
+try:
+    vllm_host = os.environ.get("VLLM_HOST", "")
+    if vllm_host != "":
+        # Get API key from environment variable if set
+        vllm_api_key = os.environ.get("VLLM_API_KEY", "")
+
+        vllm_client = AsyncOpenAI(
+            base_url=vllm_host,
+            api_key=vllm_api_key,
+            timeout=60.0
+        )
+        vllm_client_sync = OpenAI(
+            base_url=vllm_host,
+            api_key=vllm_api_key,
+            timeout=60.0
+        )
+        vllm_available = True
+        logger.debug(f"vLLM client is initialized with server at {vllm_host}")
+        if vllm_api_key != "":
+            logger.debug("vLLM API key is set")
+    else:
+        vllm_client = None
+        vllm_client_sync = None
+        vllm_available = False
+        logger.debug("VLLM_HOST environment variable is not set, vLLM will not work.")
+except Exception as e:
+    vllm_client = None
+    vllm_client_sync = None
+    vllm_available = False
+    logger.debug(f"Error setting up vLLM client: {e}")
+
 _semaphore = asyncio.Semaphore(Config.base_config["llm_max_concurrency"])
 
 
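One caveat: like the existing OpenAI/Ollama setup above it, this client construction runs at module import time, so VLLM_HOST and VLLM_API_KEY must already be in the environment when hydra_vl4ai.agent.llm is first imported. A minimal sketch, assuming the repo's .env is loaded with python-dotenv (an assumption suggested by .env.example, not confirmed by this diff):

from dotenv import load_dotenv

load_dotenv()  # must run before the import below builds the clients
from hydra_vl4ai.agent import llm  # vllm_client / vllm_client_sync are created here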
@@ -159,24 +199,62 @@ async def ollama(model_name: str, messages: list[dict[str, str]], format: Literal
     return response["message"]["content"]
 
 
-async def llm(model_name: str, prompt: str, format: Literal["", "json"] = "") -> str | None:
-    if model_name.startswith("gpt"):
+@handle_openai_exceptions
+async def vllm(model_name: str, messages: list[dict[str, str]], format: Literal["", "json"] = "") -> str | None:
+    if not vllm_available:
+        raise ValueError("vLLM is not available. Set VLLM_HOST environment variable.")
+
+    response_format = {"type": "json_object"} if format == "json" else NOT_GIVEN
+
+    async with _semaphore:
+        try:
+            response = await vllm_client.chat.completions.create(
+                model=model_name,
+                messages=messages,
+                response_format=response_format
+            )
+            # Note: We don't track cost for vLLM
+            logger.debug(f"vLLM API call completed successfully for model: {model_name}")
+        except Exception as e:
+            logger.error(f"Error making vLLM request: {e}")
+            raise
+
+    return response.choices[0].message.content
+
+
+async def llm(model_spec: str, prompt: str, format: Literal["", "json"] = "") -> str | None:
+    api_type, model_name = parse_model_name(model_spec)
+
+    if api_type == "openai":
         return await chatgpt(model_name, [{"role": "user", "content": prompt}], format)
-    else:
+    elif api_type == "ollama":
         return await ollama(model_name, [{"role": "user", "content": prompt}], format)
+    elif api_type == "vllm":
+        return await vllm(model_name, [{"role": "user", "content": prompt}], format)
+    else:
+        raise ValueError(f"Unknown API type: {api_type}")
 
-async def llm_embedding(model_name: str, prompt: str):
-    if model_name in ("text-embedding-3-small", "text-embedding-3-large"):
+
+async def llm_embedding(model_spec: str, prompt: str):
+    api_type, model_name = parse_model_name(model_spec)
+
+    if api_type == "openai" and model_name in ("text-embedding-3-small", "text-embedding-3-large"):
         return await gpt3_embedding(model_name, prompt)
     else:
-        raise ValueError(f"Model {model_name} is not supported.")
+        raise ValueError(f"Model {model_spec} is not supported for embeddings.")
 
 
-async def llm_with_message(model_name: str, messages: list[dict[str, str]], format: Literal["", "json"] = "") -> str | None:
-    if model_name.startswith("gpt"):
+async def llm_with_message(model_spec: str, messages: list[dict[str, str]], format: Literal["", "json"] = "") -> str | None:
+    api_type, model_name = parse_model_name(model_spec)
+
+    if api_type == "openai":
         return await chatgpt(model_name, messages, format)
-    else:
+    elif api_type == "ollama":
         return await ollama(model_name, messages, format)
+    elif api_type == "vllm":
+        return await vllm(model_name, messages, format)
+    else:
+        raise ValueError(f"Unknown API type: {api_type}")
 
 
 # sync versions
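
A hedged usage sketch for the reworked async dispatchers (the vllm model name is a placeholder for whatever the server is actually serving):

import asyncio

from hydra_vl4ai.agent.llm import llm, llm_with_message

async def main():
    # Single-prompt helper: the provider prefix selects the backend.
    print(await llm("vllm::my-served-model", "Reply with one word."))
    # Full-message variant, same dispatch rules.
    print(await llm_with_message(
        "openai::gpt-4o-mini",
        [{"role": "user", "content": "Reply with one word."}],
    ))

asyncio.run(main())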
@@ -231,11 +309,17 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
-def llm_sync(model_name: str, prompt: str, format: Literal["", "json"] = "") -> str | None:
-    if model_name.startswith("gpt"):
+def llm_sync(model_spec: str, prompt: str, format: Literal["", "json"] = "") -> str | None:
+    api_type, model_name = parse_model_name(model_spec)
+
+    if api_type == "openai":
         return chatgpt_sync(model_name, [{"role": "user", "content": prompt}], format)
-    else:
+    elif api_type == "ollama":
         return ollama_sync(model_name, [{"role": "user", "content": prompt}], format)
+    elif api_type == "vllm":
+        return vllm_sync(model_name, [{"role": "user", "content": prompt}], format)
+    else:
+        raise ValueError(f"Unknown API type: {api_type}")
 
 
 @handle_openai_exceptions_sync
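
The sync dispatcher mirrors the async one; a minimal sketch with the same placeholder caveat:

from hydra_vl4ai.agent.llm import llm_sync

print(llm_sync("openai::gpt-4o-mini", "ping"))
print(llm_sync("vllm::my-served-model", "ping"))  # requires VLLM_HOST to be set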
@@ -251,3 +335,35 @@ def chatgpt_sync(model_name: str, messages: list[dict[str, str]], format: Litera
 def ollama_sync(model_name: str, prompt: str, format: Literal["", "json"] = "") -> str | None:
     response = ollama_client_sync.chat(model=model_name, messages=[{"role": "user", "content": prompt}], stream=False, format=format)
     return response["message"]["content"]
+
+
+@handle_openai_exceptions_sync
+def vllm_sync(model_name: str, messages: list[dict[str, str]], format: Literal["", "json"] = "") -> str | None:
+    """Make a synchronous request to vLLM server using the OpenAI client library.
+
+    Args:
+        model_name: The model name to use on the vLLM server
+        messages: List of message dictionaries containing role and content
+        format: Optional format requested ("json" for JSON mode)
+
+    Returns:
+        The model's response text
+    """
+    if not vllm_available:
+        raise ValueError("vLLM is not available. Set VLLM_HOST environment variable.")
+
+    response_format = {"type": "json_object"} if format == "json" else NOT_GIVEN
+
+    try:
+        response = vllm_client_sync.chat.completions.create(
+            model=model_name,
+            messages=messages,
+            response_format=response_format
+        )
+        # Note: We don't track cost for vLLM
+        logger.debug(f"vLLM API call completed successfully for model: {model_name}")
+    except Exception as e:
+        logger.error(f"Error making vLLM request: {e}")
+        raise
+
+    return response.choices[0].message.content
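
On the format="json" path: both vllm helpers forward response_format={"type": "json_object"}, which recent vLLM releases honor via guided decoding, but support varies by server version and model, so it is worth a smoke test against the deployed server. A hedged example:

from hydra_vl4ai.agent.llm import vllm_sync

out = vllm_sync(
    "my-served-model",  # placeholder model name
    [{"role": "user", "content": 'Answer as JSON, e.g. {"ok": true}.'}],
    format="json",
)
print(out)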
