Commit 2bd4db3

add qwen3 support
1 parent 1e92fbb commit 2bd4db3

2 files changed: 45 additions & 43 deletions

.env.example (3 additions & 2 deletions)
@@ -1,6 +1,7 @@
 OPENAI_API_KEY=your-api-key
-OLLAMA_HOST=http://ollama.server:11434
-VLLM_HOST=http://vllm.server:8000
+AZURE_OPENAI_URL=
+OLLAMA_HOST=
+VLLM_HOST=
 VLLM_API_KEY=your-vllm-api-key
 # do not change this TORCH_HOME variable
 TORCH_HOME=./pretrained_models
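
The hosts above are now blank by default, matching how llm.py treats each backend as optional: an unset or empty host simply disables that backend rather than failing. A minimal sketch of that pattern, assuming the stock ollama and openai Python packages; build_optional_clients is a hypothetical helper, since llm.py actually builds its clients at import time.

import os

import ollama
from openai import OpenAI


def build_optional_clients() -> dict:
    clients = {}
    # Empty OLLAMA_HOST: Ollama-backed models are unavailable, nothing else breaks.
    if os.environ.get("OLLAMA_HOST"):
        clients["ollama"] = ollama.Client(host=os.environ["OLLAMA_HOST"])
    # vLLM exposes an OpenAI-compatible API, so the OpenAI client is reused with a custom base_url.
    if os.environ.get("VLLM_HOST"):
        clients["vllm"] = OpenAI(
            base_url=os.environ["VLLM_HOST"],
            api_key=os.environ.get("VLLM_API_KEY", "none"),
        )
    return clients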

hydra_vl4ai/agent/llm.py (42 additions & 41 deletions)
@@ -2,7 +2,7 @@
 import os
 import time
 from functools import wraps
-import os
+import re
 from typing import Literal
 import httpx
 import numpy as np
@@ -85,7 +85,7 @@ def add(self, chatgpt_response, model_name):
     ollama_client_sync = None
     logger.debug("OLLAMA server is not available, Llama will not work.")
 else:
-    logger.debug("OLLAMA server is set.")
+    logger.debug(f"OLLAMA server is set on {os.environ['OLLAMA_HOST']}")

 # Set up vLLM client (using OpenAI client with custom base URL)
 try:
@@ -130,7 +130,9 @@ async def wrapper(*args, **kwargs):
         for _ in range(max_trial):
             try:
                 logger.debug(f"Call OpenAI API. {args, kwargs}")
-                return await func(*args, **kwargs)
+                response = await func(*args, **kwargs)
+                logger.debug(f"Response: {response}")
+                return response
             except openai.APITimeoutError as e:
                 logger.error(f"OpenAI API Timeout: {e}")
                 pass
@@ -160,7 +162,9 @@ def handle_ollama_exceptions(func):
     async def wrapper(*args, **kwargs):
         for _ in range(max_trial):
             try:
-                return await func(*args, **kwargs)
+                response = await func(*args, **kwargs)
+                logger.debug(f"Response: {response}")
+                return response
             except httpx.ConnectError:
                 pass
             except httpx.ConnectTimeout:
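
Both retry wrappers now capture the awaited result and log it before returning, instead of returning func's result directly. A generic sketch of that retry-and-log shape, assuming a standard-library logger; max_trial, the decorator name, and the exception type are stand-ins for what the module actually defines.

import logging
from functools import wraps

logger = logging.getLogger(__name__)
max_trial = 3  # stand-in for the module-level retry count


def retry_and_log(func):
    @wraps(func)
    async def wrapper(*args, **kwargs):
        for _ in range(max_trial):
            try:
                response = await func(*args, **kwargs)
                logger.debug(f"Response: {response}")  # the new step: log the response before returning
                return response
            except ConnectionError:
                # swallow and retry, as the real wrappers do for backend-specific errors
                pass
    return wrapper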
@@ -194,9 +198,20 @@ async def gpt3_embedding(model_name: str, prompt: str):

 @handle_ollama_exceptions
 async def ollama(model_name: str, messages: list[dict[str, str]], format: Literal["", "json"] = "") -> str:
+    # set non-thinking mode
+    if "qwen3" in model_name.lower():
+        messages[-1]["content"] += "/no_think"
+
     async with _semaphore:
         response = await ollama_client.chat(model=model_name, messages=messages, stream=False, format=format)
-    return response["message"]["content"]
+    response_content = response["message"]["content"]
+
+    if "qwen3" in model_name.lower():
+        # remove the <think> part
+        pattern = r"<think>\s*?</think>"
+        response_content = re.sub(pattern, "", response_content, flags=re.DOTALL)
+        response_content = response_content.lstrip("\n")
+    return response_content


 @handle_openai_exceptions
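
The qwen3-specific handling above does two things: it appends Qwen3's "/no_think" soft switch to the last user message so the model skips extended thinking, and it strips the empty <think>...</think> block the model still emits in that mode. A standalone illustration of the stripping step; the raw response string is fabricated.

import re

raw = "<think>\n\n</think>\n\nThe red mug is on the left shelf."  # fabricated qwen3-style response
pattern = r"<think>\s*?</think>"
cleaned = re.sub(pattern, "", raw, flags=re.DOTALL).lstrip("\n")
print(cleaned)  # -> The red mug is on the left shelf.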
@@ -207,17 +222,11 @@ async def vllm(model_name: str, messages: list[dict[str, str]], format: Literal[
     response_format = {"type": "json_object"} if format == "json" else NOT_GIVEN

     async with _semaphore:
-        try:
-            response = await vllm_client.chat.completions.create(
-                model=model_name,
-                messages=messages,
-                response_format=response_format
-            )
-            # Note: We don't track cost for vLLM
-            logger.debug(f"vLLM API call completed successfully for model: {model_name}")
-        except Exception as e:
-            logger.error(f"Error making vLLM request: {e}")
-            raise
+        response = await vllm_client.chat.completions.create(
+            model=model_name,
+            messages=messages,
+            response_format=response_format
+        )

     return response.choices[0].message.content
@@ -325,45 +334,37 @@ def llm_sync(model_spec: str, prompt: str, format: Literal["", "json"] = "") ->
 @handle_openai_exceptions_sync
 def chatgpt_sync(model_name: str, messages: list[dict[str, str]], format: Literal["", "json"] = "") -> str | None:
     response_format = {"type": "json_object"} if format == "json" else NOT_GIVEN
-    response = openai_client_sync.chat.completions.create(model=model_name, messages=messages, response_format=response_format)
+    response = openai_client_sync.chat.completions.create(model=model_name,
+                                                          messages=messages, response_format=response_format)
     Cost.add(response, model_name)
     logger.debug(f"Cost: {Cost.cost:.4f}, input: {Cost.input_tokens}, output: {Cost.output_tokens}")
     return response.choices[0].message.content


 @handle_openai_exceptions_sync
-def ollama_sync(model_name: str, prompt: str, format: Literal["", "json"] = "") -> str | None:
-    response = ollama_client_sync.chat(model=model_name, messages=[{"role": "user", "content": prompt}], stream=False, format=format)
-    return response["message"]["content"]
+def ollama_sync(model_name: str, messages: list[dict[str, str]], format: Literal["", "json"] = "") -> str | None:
+    response = ollama_client_sync.chat(model=model_name, messages=messages, stream=False, format=format)
+    response_content = response["message"]["content"]
+
+    if "qwen3" in model_name.lower():
+        # remove the <think> part
+        pattern = r"<think>\s*?</think>"
+        response_content = re.sub(pattern, "", response_content, flags=re.DOTALL)
+        response_content = response_content.lstrip("\n")
+    return response_content


 @handle_openai_exceptions_sync
 def vllm_sync(model_name: str, messages: list[dict[str, str]], format: Literal["", "json"] = "") -> str | None:
-    """Make a synchronous request to vLLM server using the OpenAI client library.
-
-    Args:
-        model_name: The model name to use on the vLLM server
-        messages: List of message dictionaries containing role and content
-        format: Optional format requested ("json" for JSON mode)
-
-    Returns:
-        The model's response text
-    """
     if not vllm_available:
         raise ValueError("vLLM is not available. Set VLLM_HOST environment variable.")

     response_format = {"type": "json_object"} if format == "json" else NOT_GIVEN

-    try:
-        response = vllm_client_sync.chat.completions.create(
-            model=model_name,
-            messages=messages,
-            response_format=response_format
-        )
-        # Note: We don't track cost for vLLM
-        logger.debug(f"vLLM API call completed successfully for model: {model_name}")
-    except Exception as e:
-        logger.error(f"Error making vLLM request: {e}")
-        raise
+    response = vllm_client_sync.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        response_format=response_format
+    )

     return response.choices[0].message.content
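
One caller-facing consequence of the last hunk: ollama_sync now takes a messages list rather than a bare prompt string, mirroring the async ollama function, so existing call sites need a small update. A hedged sketch of the new call shape; the import path follows the file location shown above, and the model tag is illustrative.

from hydra_vl4ai.agent.llm import ollama_sync

# Before this commit a plain prompt string was accepted:
#     ollama_sync("qwen3:8b", "Describe the scene.")
# After it, callers build the messages list themselves:
result = ollama_sync(
    "qwen3:8b",  # illustrative model tag
    [{"role": "user", "content": "Describe the scene."}],
)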
