|
1 | 1 | # Copyright (c) Microsoft Corporation. |
2 | 2 | # Licensed under the MIT License. |
3 | 3 |
|
4 | | -"""An common abstraction for a cached LLM inference setup. Currently supports OpenAI's gpt-4-turbo and other models.""" |
| 4 | +"""A common abstraction for a cached LLM inference setup. Currently supports OpenAI's gpt-4-turbo and other models.""" |
5 | 5 |
|
6 | 6 |
|
7 | 7 | import os |
8 | | -from openai import OpenAI |
| 8 | +import json |
| 9 | +import yaml |
9 | 10 | from groq import Groq |
10 | 11 | from pathlib import Path |
11 | | -import json |
| 12 | +from typing import Optional, List, Dict |
| 13 | +from dataclasses import dataclass |
| 14 | + |
| 15 | +from groq import Groq |
| 16 | +from openai import OpenAI, AzureOpenAI |
| 17 | +from azure.identity import get_bearer_token_provider, AzureCliCredential, ManagedIdentityCredential |
| 18 | + |
12 | 19 | from dotenv import load_dotenv |
13 | 20 |
|
14 | 21 | # Load environment variables from the .env file |
15 | 22 | load_dotenv() |
| 23 | +"""An common abstraction for a cached LLM inference setup. Currently supports OpenAI's gpt-4-turbo and other models.""" |
| 24 | + |
16 | 25 |
|
17 | 26 | CACHE_DIR = Path("./cache_dir") |
18 | 27 | CACHE_PATH = CACHE_DIR / "cache.json" |
| 28 | +GPT_MODEL = "gpt-4o" |
| 29 | + |
| 30 | + |
| 31 | +@dataclass |
| 32 | +class AzureConfig: |
| 33 | + azure_endpoint: str |
| 34 | + api_version: str |
19 | 35 |
|
20 | 36 |
|
21 | 37 | class Cache: |
@@ -53,20 +69,58 @@ def save_cache(self): |
53 | 69 | class GPTClient: |
54 | 70 | """Abstraction for OpenAI's GPT series model.""" |
55 | 71 |
|
56 | | - def __init__(self): |
| 72 | + def __init__(self, auth_type: str = "key", api_key: Optional[str] = None, azure_config_file: Optional[str] = None, use_cache: bool = True): |
57 | 73 | self.cache = Cache() |
| 74 | + self.client = self._setup_client(auth_type, api_key, azure_config_file) |
| 75 | + |
| 76 | + def _load_azure_config(self, yaml_file_path: str) -> AzureConfig: |
| 77 | + with open(yaml_file_path, "r") as file: |
| 78 | + azure_config_data = yaml.safe_load(file) |
| 79 | + return AzureConfig( |
| 80 | + azure_endpoint=azure_config_data.get("azure_endpoint"), |
| 81 | + api_version=azure_config_data.get("api_version"), |
| 82 | + ) |
| 83 | + |
| 84 | + def _setup_client(self, auth_type: str, api_key: Optional[str], azure_config_file: Optional[str]): |
| 85 | + azure_identity_opts = ["cli", "managed_identity"] |
| 86 | + if auth_type == "key": |
| 87 | + # TODO: support Azure OpenAI client. |
| 88 | + api_key = api_key or os.getenv("OPENAI_API_KEY") |
| 89 | + if not api_key: |
| 90 | + raise ValueError("API key must be provided or set in OPENAI_API_KEY environment variable") |
| 91 | + return OpenAI(api_key=api_key) |
| 92 | + elif auth_type in azure_identity_opts: |
| 93 | + if not azure_config_file: |
| 94 | + raise ValueError("Azure configuration file must be provided for access via managed identity.\n Check AIOpsLab/clients/configs/example_azure_config.yml for an example.") |
| 95 | + azure_config = self._load_azure_config(azure_config_file) |
| 96 | + if auth_type == "cli": |
| 97 | + credential = AzureCliCredential() |
| 98 | + elif auth_type == "managed_identity": |
| 99 | + client_id = os.getenv("AZURE_CLIENT_ID") |
| 100 | + if client_id is None: |
| 101 | + raise ValueError("Managed identity selected but AZURE_CLIENT_ID is not set.") |
| 102 | + credential = ManagedIdentityCredential(client_id=client_id) |
| 103 | + token_provider = get_bearer_token_provider( |
| 104 | + credential, "https://cognitiveservices.azure.com/.default" |
| 105 | + ) |
| 106 | + return AzureOpenAI( |
| 107 | + api_version=azure_config.api_version, |
| 108 | + azure_endpoint=azure_config.azure_endpoint, |
| 109 | + azure_ad_token_provider=token_provider |
| 110 | + ) |
| 111 | + else: |
| 112 | + raise ValueError("auth_type must be one of 'key', 'cli', or 'managed_identity'") |
58 | 113 |
|
59 | 114 | def inference(self, payload: list[dict[str, str]]) -> list[str]: |
60 | 115 | if self.cache is not None: |
61 | 116 | cache_result = self.cache.get_from_cache(payload) |
62 | 117 | if cache_result is not None: |
63 | 118 | return cache_result |
64 | 119 |
|
65 | | - client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) |
66 | 120 | try: |
67 | | - response = client.chat.completions.create( |
| 121 | + response = self.client.chat.completions.create( |
68 | 122 | messages=payload, # type: ignore |
69 | | - model="gpt-4-turbo-2024-04-09", |
| 123 | + model=GPT_MODEL, |
70 | 124 | max_tokens=1024, |
71 | 125 | temperature=0.5, |
72 | 126 | top_p=0.95, |
|
0 commit comments