 from ..util.console import logger
 
 
+def parse_model_name(model_spec: str) -> tuple[str, str]:
+    # Parses the model specification string.
+    # Example: "ollama::deepseek-r1:70b" -> ("ollama", "deepseek-r1:70b")
+    assert "::" in model_spec, "Model specification must contain '::' to separate api_type and model_name"
+    api_type, model_name = model_spec.split("::", 1)
+    return api_type.lower(), model_name
+
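A quick contract sketch for the parser above (illustrative calls, not part of this commit): splitting on the first "::" keeps colons inside model tags intact, and the api_type is lower-cased.

    assert parse_model_name("OpenAI::gpt-4o") == ("openai", "gpt-4o")
    assert parse_model_name("ollama::deepseek-r1:70b") == ("ollama", "deepseek-r1:70b")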
+
 @Singleton
 class Cost:
     def __init__(self):
@@ -79,6 +87,38 @@ def add(self, chatgpt_response, model_name):
 else:
     logger.debug("OLLAMA server is set.")
 
+# Set up vLLM client (using the OpenAI client with a custom base URL)
+try:
+    vllm_host = os.environ.get("VLLM_HOST", "")
+    if vllm_host != "":
+        # Get the API key from the environment variable if set
+        vllm_api_key = os.environ.get("VLLM_API_KEY", "")
+
+        vllm_client = AsyncOpenAI(
+            base_url=vllm_host,
+            api_key=vllm_api_key,
+            timeout=60.0,
+        )
+        vllm_client_sync = OpenAI(
+            base_url=vllm_host,
+            api_key=vllm_api_key,
+            timeout=60.0,
+        )
+        vllm_available = True
+        logger.debug(f"vLLM client is initialized with server at {vllm_host}")
+        if vllm_api_key != "":
+            logger.debug("vLLM API key is set")
+    else:
+        vllm_client = None
+        vllm_client_sync = None
+        vllm_available = False
+        logger.debug("VLLM_HOST environment variable is not set, vLLM will not work.")
+except Exception as e:
+    vllm_client = None
+    vllm_client_sync = None
+    vllm_available = False
+    logger.debug(f"Error setting up vLLM client: {e}")
+
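The block above runs at import time, so the variables must be exported before this module is imported. A minimal sketch, assuming a local vLLM server exposing its OpenAI-compatible route (the URL and key values are illustrative):

    import os
    os.environ["VLLM_HOST"] = "http://localhost:8000/v1"  # vLLM's OpenAI-compatible endpoint
    os.environ["VLLM_API_KEY"] = "dummy-key"              # optional; many vLLM deployments accept any key
    # import (or reload) this module afterwards so the setup above sees the variables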
 _semaphore = asyncio.Semaphore(Config.base_config["llm_max_concurrency"])
 
 
@@ -159,24 +199,62 @@ async def ollama(model_name: str, messages: list[dict[str, str]], format: Literal["", "json"] = "") -> str | None:
     return response["message"]["content"]
 
 
-async def llm(model_name: str, prompt: str, format: Literal["", "json"] = "") -> str | None:
-    if model_name.startswith("gpt"):
+@handle_openai_exceptions
+async def vllm(model_name: str, messages: list[dict[str, str]], format: Literal["", "json"] = "") -> str | None:
+    if not vllm_available:
+        raise ValueError("vLLM is not available. Set VLLM_HOST environment variable.")
+
+    response_format = {"type": "json_object"} if format == "json" else NOT_GIVEN
+
+    async with _semaphore:
+        try:
+            response = await vllm_client.chat.completions.create(
+                model=model_name,
+                messages=messages,
+                response_format=response_format,
+            )
+            # Note: we don't track cost for vLLM
+            logger.debug(f"vLLM API call completed successfully for model: {model_name}")
+        except Exception as e:
+            logger.error(f"Error making vLLM request: {e}")
+            raise
+
+    return response.choices[0].message.content
+
+
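A hypothetical call of the helper above (the model name is illustrative, not pinned by this commit):

    # e.g. from inside an async function:
    reply = await vllm(
        "meta-llama/Llama-3.1-8B-Instruct",  # whatever model the vLLM server is serving
        [{"role": "user", "content": "Reply with a JSON object."}],
        format="json",
    )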
+async def llm(model_spec: str, prompt: str, format: Literal["", "json"] = "") -> str | None:
+    api_type, model_name = parse_model_name(model_spec)
+
+    if api_type == "openai":
         return await chatgpt(model_name, [{"role": "user", "content": prompt}], format)
-    else:
+    elif api_type == "ollama":
         return await ollama(model_name, [{"role": "user", "content": prompt}], format)
+    elif api_type == "vllm":
+        return await vllm(model_name, [{"role": "user", "content": prompt}], format)
+    else:
+        raise ValueError(f"Unknown API type: {api_type}")
 
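A hypothetical end-to-end call of the new dispatcher (spec strings are illustrative; any "api_type::model_name" pair the backends accept would work):

    import asyncio

    async def demo() -> None:
        print(await llm("openai::gpt-4o", "Say hi in five words."))
        print(await llm("vllm::meta-llama/Llama-3.1-8B-Instruct", "Reply with a JSON object.", format="json"))

    asyncio.run(demo())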
-async def llm_embedding(model_name: str, prompt: str):
-    if model_name in ("text-embedding-3-small", "text-embedding-3-large"):
+
+async def llm_embedding(model_spec: str, prompt: str):
+    api_type, model_name = parse_model_name(model_spec)
+
+    if api_type == "openai" and model_name in ("text-embedding-3-small", "text-embedding-3-large"):
         return await gpt3_embedding(model_name, prompt)
     else:
-        raise ValueError(f"Model {model_name} is not supported.")
+        raise ValueError(f"Model {model_spec} is not supported for embeddings.")
 
 
-async def llm_with_message(model_name: str, messages: list[dict[str, str]], format: Literal["", "json"] = "") -> str | None:
-    if model_name.startswith("gpt"):
+async def llm_with_message(model_spec: str, messages: list[dict[str, str]], format: Literal["", "json"] = "") -> str | None:
+    api_type, model_name = parse_model_name(model_spec)
+
+    if api_type == "openai":
         return await chatgpt(model_name, messages, format)
-    else:
+    elif api_type == "ollama":
         return await ollama(model_name, messages, format)
+    elif api_type == "vllm":
+        return await vllm(model_name, messages, format)
+    else:
+        raise ValueError(f"Unknown API type: {api_type}")
 
 
 # sync versions
@@ -231,11 +309,17 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
-def llm_sync(model_name: str, prompt: str, format: Literal["", "json"] = "") -> str | None:
-    if model_name.startswith("gpt"):
+def llm_sync(model_spec: str, prompt: str, format: Literal["", "json"] = "") -> str | None:
+    api_type, model_name = parse_model_name(model_spec)
+
+    if api_type == "openai":
         return chatgpt_sync(model_name, [{"role": "user", "content": prompt}], format)
-    else:
+    elif api_type == "ollama":
         return ollama_sync(model_name, [{"role": "user", "content": prompt}], format)
+    elif api_type == "vllm":
+        return vllm_sync(model_name, [{"role": "user", "content": prompt}], format)
+    else:
+        raise ValueError(f"Unknown API type: {api_type}")
 
 
 @handle_openai_exceptions_sync
@@ -251,3 +335,35 @@ def chatgpt_sync(model_name: str, messages: list[dict[str, str]], format: Literal["", "json"] = "") -> str | None:
 def ollama_sync(model_name: str, prompt: str, format: Literal["", "json"] = "") -> str | None:
     response = ollama_client_sync.chat(model=model_name, messages=[{"role": "user", "content": prompt}], stream=False, format=format)
     return response["message"]["content"]
+
+
+@handle_openai_exceptions_sync
+def vllm_sync(model_name: str, messages: list[dict[str, str]], format: Literal["", "json"] = "") -> str | None:
+    """Make a synchronous request to a vLLM server using the OpenAI client library.
+
+    Args:
+        model_name: The model name to use on the vLLM server
+        messages: List of message dictionaries containing role and content
+        format: Optional format requested ("json" for JSON mode)
+
+    Returns:
+        The model's response text
+    """
+    if not vllm_available:
+        raise ValueError("vLLM is not available. Set VLLM_HOST environment variable.")
+
+    response_format = {"type": "json_object"} if format == "json" else NOT_GIVEN
+
+    try:
+        response = vllm_client_sync.chat.completions.create(
+            model=model_name,
+            messages=messages,
+            response_format=response_format,
+        )
+        # Note: we don't track cost for vLLM
+        logger.debug(f"vLLM API call completed successfully for model: {model_name}")
+    except Exception as e:
+        logger.error(f"Error making vLLM request: {e}")
+        raise
+
+    return response.choices[0].message.content
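A hypothetical call of the sync helper above (model name illustrative), useful where no event loop is running:

    reply = vllm_sync(
        "meta-llama/Llama-3.1-8B-Instruct",
        [{"role": "user", "content": "Return a JSON object with a 'greeting' key."}],
        format="json",
    )
    print(reply)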