@@ -180,271 +180,6 @@ def run(self, payload: list[dict[str, str]]) -> list[str]:
180180 return response
181181
182182
class QwenClient:
    """Abstraction for Qwen's model. Some Qwen models only support streaming output.

    ``inference`` talks to the DashScope OpenAI-compatible endpoint with a
    client it builds on every call; the client created by ``_setup_client``
    is stored on ``self.client`` but is not used by ``inference``.
    """

    def __init__(
        self,
        auth_type: str = "key",
        api_key: Optional[str] = None,
        azure_config_file: Optional[str] = None,
        use_cache: bool = True,
    ):
        """Create the client.

        Args:
            auth_type: One of ``"key"``, ``"cli"`` or ``"managed_identity"``.
            api_key: API key for ``"key"`` auth; falls back to the
                ``OPENAI_API_KEY`` environment variable.
            azure_config_file: YAML file with ``azure_endpoint`` and
                ``api_version``; required for the two Azure identity modes.
            use_cache: When false, no cache is created and every call to
                ``inference`` issues a request.

        Raises:
            ValueError: If the authentication settings are incomplete or
                ``auth_type`` is not recognised.
        """
        # Honour use_cache: the rest of the class already handles cache=None.
        self.cache = Cache() if use_cache else None
        self.client = self._setup_client(auth_type, api_key, azure_config_file)

    def _load_azure_config(self, yaml_file_path: str) -> "AzureConfig":
        """Read ``azure_endpoint`` and ``api_version`` from a YAML file."""
        with open(yaml_file_path, "r") as file:
            azure_config_data = yaml.safe_load(file)
        return AzureConfig(
            azure_endpoint=azure_config_data.get("azure_endpoint"),
            api_version=azure_config_data.get("api_version"),
        )

    def _setup_client(self, auth_type: str, api_key: Optional[str], azure_config_file: Optional[str]):
        """Build an ``OpenAI`` or ``AzureOpenAI`` client for ``auth_type``.

        Raises:
            ValueError: On a missing API key, missing Azure config file,
                missing ``AZURE_CLIENT_ID`` or an unknown ``auth_type``.
        """
        azure_identity_opts = ["cli", "managed_identity"]
        if auth_type == "key":
            # TODO: support Azure OpenAI client.
            api_key = api_key or os.getenv("OPENAI_API_KEY")
            if not api_key:
                raise ValueError("API key must be provided or set in OPENAI_API_KEY environment variable")
            return OpenAI(api_key=api_key)

        if auth_type not in azure_identity_opts:
            raise ValueError("auth_type must be one of 'key', 'cli', or 'managed_identity'")

        if not azure_config_file:
            raise ValueError("Azure configuration file must be provided for access via managed identity.\n Check AIOpsLab/clients/configs/example_azure_config.yml for an example.")
        azure_config = self._load_azure_config(azure_config_file)

        if auth_type == "cli":
            credential = AzureCliCredential()
        else:  # "managed_identity"
            client_id = os.getenv("AZURE_CLIENT_ID")
            if client_id is None:
                raise ValueError("Managed identity selected but AZURE_CLIENT_ID is not set.")
            credential = ManagedIdentityCredential(client_id=client_id)

        token_provider = get_bearer_token_provider(
            credential, "https://cognitiveservices.azure.com/.default"
        )
        return AzureOpenAI(
            api_version=azure_config.api_version,
            azure_endpoint=azure_config.azure_endpoint,
            azure_ad_token_provider=token_provider,
        )

    @staticmethod
    def _collect_answer(stream) -> str:
        """Join the answer text of a streamed chat completion.

        Chunks without choices are printed as usage information. Deltas
        carrying ``reasoning_content`` are skipped, as are deltas whose
        ``content`` is ``None`` or empty.
        """
        answer_parts = []
        for chunk in stream:
            if not chunk.choices:
                print("\n Usage:")
                print(chunk.usage)
                continue
            delta = chunk.choices[0].delta
            if getattr(delta, "reasoning_content", None) is not None:
                # Reasoning tokens are not part of the returned answer.
                continue
            # content may be None on some chunks; appending it used to raise
            # TypeError.
            if delta.content:
                answer_parts.append(delta.content)
        return "".join(answer_parts)

    def inference(self, payload: list[dict[str, str]]) -> list[str]:
        """Return the model answer for ``payload`` as a one-element list.

        A cached result is returned as-is when the cache has one. Otherwise a
        streaming request is sent and the answer chunks are concatenated.
        Any exception from the request is printed and re-raised.
        """
        if self.cache is not None:
            cache_result = self.cache.get_from_cache(payload)
            if cache_result is not None:
                return cache_result

        client = OpenAI(
            api_key=os.getenv("DASHSCOPE_API_KEY"),
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
        )
        try:
            # TODO: Add constraints for the input context length
            response = client.chat.completions.create(
                messages=payload,  # type: ignore
                model="qwq-32b",
                max_tokens=1024,
                n=1,
                timeout=60,
                stop=[],
                stream=True,
            )
        except Exception as e:
            print(f"Exception: {repr(e)} ")
            raise

        return [self._collect_answer(response)]

    def run(self, payload: list[dict[str, str]]) -> list[str]:
        """Run ``inference`` and, when a cache exists, store and save the result."""
        response = self.inference(payload)
        if self.cache is not None:
            self.cache.add_to_cache(payload, response)
            self.cache.save_cache()
        return response
276-
277-
class vLLMClient:
    """Abstraction for local LLM models.

    Requests go to an OpenAI-compatible server at
    ``http://localhost:8000/v1`` using the placeholder API key ``"EMPTY"``.
    """

    def __init__(self,
                 model="Qwen/Qwen2.5-Coder-3B-Instruct",
                 repetition_penalty=1.0,
                 temperature=1.0,
                 top_p=0.95,
                 max_tokens=1024):
        """Store the model name and sampling settings used by ``inference``."""
        self.cache = Cache()
        self.model = model
        self.repetition_penalty = repetition_penalty
        self.temperature = temperature
        self.top_p = top_p
        self.max_tokens = max_tokens

    def inference(self, payload: list[dict[str, str]]) -> list[str]:
        """Return the completion texts for ``payload``.

        A cached result is returned as-is when the cache has one. Any
        exception from the request is printed and re-raised.
        """
        if self.cache is not None:
            cache_result = self.cache.get_from_cache(payload)
            if cache_result is not None:
                return cache_result

        # repetition_penalty is not a standard OpenAI field, so it is sent
        # through extra_body. It is only attached when it differs from the
        # neutral value 1.0, which keeps default requests unchanged.
        extra_kwargs = {}
        if self.repetition_penalty != 1.0:
            extra_kwargs["extra_body"] = {
                "repetition_penalty": self.repetition_penalty,
            }

        client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
        try:
            response = client.chat.completions.create(
                messages=payload,  # type: ignore
                model=self.model,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
                frequency_penalty=0.0,
                presence_penalty=0.0,
                n=1,
                timeout=60,
                stop=[],
                **extra_kwargs,
            )
        except Exception as e:
            print(f"Exception: {repr(e)} ")
            raise

        return [c.message.content for c in response.choices]  # type: ignore

    def run(self, payload: list[dict[str, str]]) -> list[str]:
        """Run ``inference`` and, when a cache exists, store and save the result."""
        response = self.inference(payload)
        if self.cache is not None:
            self.cache.add_to_cache(payload, response)
            self.cache.save_cache()
        return response
326-
327-
class OpenRouterClient:
    """Abstraction for OpenRouter API with support for multiple models."""

    def __init__(self, model="anthropic/claude-3.5-sonnet"):
        """Create the cache and remember which OpenRouter model to query."""
        self.cache = Cache()
        self.model = model

    def inference(self, payload: list[dict[str, str]]) -> list[str]:
        """Return the completion texts for ``payload``.

        A cached result is returned as-is when the cache has one. Otherwise a
        client is built from ``OPENROUTER_API_KEY`` and the request is sent
        to ``https://openrouter.ai/api/v1``. Any exception from the request
        is printed and re-raised.
        """
        if self.cache is not None:
            cache_result = self.cache.get_from_cache(payload)
            if cache_result is not None:
                return cache_result

        client = OpenAI(
            api_key=os.getenv("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
        )
        try:
            # Use the client built above: this class never sets self.client,
            # so the previous self.client lookup raised AttributeError.
            response = client.chat.completions.create(
                messages=payload,  # type: ignore
                model=self.model,
                max_tokens=1024,
                temperature=0.5,
                top_p=0.95,
                frequency_penalty=0.0,
                presence_penalty=0.0,
                n=1,
                timeout=60,
                stop=[],
            )
        except Exception as e:
            print(f"Exception: {repr(e)} ")
            raise

        return [c.message.content for c in response.choices]  # type: ignore

    def run(self, payload: list[dict[str, str]]) -> list[str]:
        """Run ``inference`` and, when a cache exists, store and save the result."""
        response = self.inference(payload)
        if self.cache is not None:
            self.cache.add_to_cache(payload, response)
            self.cache.save_cache()
        return response
370-
371-
class LLaMAClient:
    """Abstraction for Meta's LLaMA-3 model, queried through the Groq client."""

    def __init__(self):
        """Create the cache used by ``inference`` and ``run``."""
        self.cache = Cache()

    def inference(self, payload: list[dict[str, str]]) -> list[str]:
        """Return the completion texts for ``payload``.

        A cached result short-circuits the request. Otherwise a Groq client
        is built from ``GROQ_API_KEY`` and queried; any exception from the
        request is printed and re-raised.
        """
        cached = None if self.cache is None else self.cache.get_from_cache(payload)
        if cached is not None:
            return cached

        groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
        request = {
            "messages": payload,
            "model": "llama-3.1-8b-instant",
            "max_tokens": 1024,
            "temperature": 0.5,
            "top_p": 0.95,
            "frequency_penalty": 0.0,
            "presence_penalty": 0.0,
            "n": 1,
            "timeout": 60,
            "stop": [],
        }
        try:
            completion = groq_client.chat.completions.create(**request)
        except Exception as err:
            print(f"Exception: {err!r} ")
            raise err

        return [choice.message.content for choice in completion.choices]  # type: ignore

    def run(self, payload: list[dict[str, str]]) -> list[str]:
        """Run ``inference``; with a cache present, record and save the result."""
        result = self.inference(payload)
        store = self.cache
        if store is None:
            return result
        store.add_to_cache(payload, result)
        store.save_cache()
        return result
410-
411-
class DeepSeekClient:
    """Abstraction for DeepSeek model."""

    def __init__(self):
        """Create the cache used by ``inference`` and ``run``."""
        self.cache = Cache()

    def inference(self, payload: list[dict[str, str]]) -> list[str]:
        """Return the completion texts for ``payload``.

        A cached result short-circuits the request. Otherwise a client is
        built from ``DEEPSEEK_API_KEY`` against ``https://api.deepseek.com``
        and the ``deepseek-reasoner`` model is queried; any exception from
        the request is printed and re-raised.
        """
        cached = None if self.cache is None else self.cache.get_from_cache(payload)
        if cached is not None:
            return cached

        api_key = os.getenv("DEEPSEEK_API_KEY")
        deepseek = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
        request = {
            "messages": payload,  # type: ignore
            "model": "deepseek-reasoner",
            "max_tokens": 1024,
            "stop": [],
        }
        try:
            completion = deepseek.chat.completions.create(**request)
        except Exception as err:
            print(f"Exception: {err!r} ")
            raise err

        return [choice.message.content for choice in completion.choices]  # type: ignore

    def run(self, payload: list[dict[str, str]]) -> list[str]:
        """Run ``inference``; with a cache present, record and save the result."""
        result = self.inference(payload)
        store = self.cache
        if store is None:
            return result
        store.add_to_cache(payload, result)
        store.save_cache()
        return result
446-
447-
448183class QwenClient :
449184 """Abstraction for Qwen's model. Some Qwen models only support streaming output."""
450185
0 commit comments