Skip to content

Commit 767663d

Browse files
authored
Merge branch 'mindsdb:main' into main
2 parents 337d918 + 5b4af6e commit 767663d

8 files changed

Lines changed: 517 additions & 472 deletions

File tree

docs/mint.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@
522522
"pages": [
523523
"mindsdb_sql/functions/custom_functions",
524524
"mindsdb_sql/functions/llm_function",
525-
"mindsdb_sql/functions/to_markdown"
525+
"mindsdb_sql/functions/to_markdown_function"
526526
]
527527
},
528528
{

mindsdb/integrations/handlers/openai_handler/helpers.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import math
55

66
import openai
7-
from openai import OpenAI
87

98
import tiktoken
109

@@ -181,17 +180,16 @@ def count_tokens(messages: List[Dict], encoder: tiktoken.core.Encoding, model_na
181180
)
182181

183182

184-
def get_available_models(api_key: Text, api_base: Text) -> List[Text]:
183+
def get_available_models(client) -> List[Text]:
185184
"""
186185
Returns a list of available openai models for the given API key.
187186
188187
Args:
189-
api_key (Text): OpenAI API key
190-
api_base (Text): OpenAI API base URL
188+
client: openai sdk client
191189
192190
Returns:
193191
List[Text]: List of available models
194192
"""
195-
res = OpenAI(api_key=api_key, base_url=api_base).models.list()
193+
res = client.models.list()
196194

197195
return [models.id for models in res.data]

mindsdb/integrations/handlers/openai_handler/openai_handler.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import concurrent.futures
1010
from typing import Text, Tuple, Dict, List, Optional, Any
1111
import openai
12-
from openai import OpenAI, NotFoundError, AuthenticationError
12+
from openai import OpenAI, AzureOpenAI, NotFoundError, AuthenticationError
1313
import numpy as np
1414
import pandas as pd
1515

@@ -87,7 +87,7 @@ def create_engine(self, connection_args: Dict) -> None:
8787
if api_key is not None:
8888
org = connection_args.get('api_organization')
8989
api_base = connection_args.get('api_base') or os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE)
90-
client = self._get_client(api_key=api_key, base_url=api_base, org=org)
90+
client = self._get_client(api_key=api_key, base_url=api_base, org=org, args=connection_args)
9191
OpenAIHandler._check_client_connection(client)
9292

9393
@staticmethod
@@ -188,7 +188,9 @@ def create_validation(target: Text, args: Dict = None, **kwargs: Any) -> None:
188188
"temperature",
189189
"openai_api_key",
190190
"api_organization",
191-
"api_base"
191+
"api_base",
192+
"api_version",
193+
"provider",
192194
}
193195
)
194196

@@ -204,7 +206,7 @@ def create_validation(target: Text, args: Dict = None, **kwargs: Any) -> None:
204206
api_key = get_api_key('openai', args, engine_storage=engine_storage)
205207
api_base = args.get('api_base') or connection_args.get('api_base') or os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE)
206208
org = args.get('api_organization')
207-
client = OpenAIHandler._get_client(api_key=api_key, base_url=api_base, org=org)
209+
client = OpenAIHandler._get_client(api_key=api_key, base_url=api_base, org=org, args=args)
208210
OpenAIHandler._check_client_connection(client)
209211

210212
def create(self, target, args: Dict = None, **kwargs: Any) -> None:
@@ -228,7 +230,8 @@ def create(self, target, args: Dict = None, **kwargs: Any) -> None:
228230
api_key = get_api_key(self.api_key_name, args, self.engine_storage)
229231
connection_args = self.engine_storage.get_connection_args()
230232
api_base = args.get('api_base') or connection_args.get('api_base') or os.environ.get('OPENAI_API_BASE') or self.api_base
231-
available_models = get_available_models(api_key, api_base)
233+
client = self._get_client(api_key=api_key, base_url=api_base, org=args.get('api_organization'), args=args)
234+
available_models = get_available_models(client)
232235

233236
if not args.get('mode'):
234237
args['mode'] = self.default_mode
@@ -810,6 +813,7 @@ def _tidy(comp: List[openai.types.image.Image]) -> List[Text]:
810813
api_key=api_key,
811814
base_url=args.get('api_base'),
812815
org=args.pop('api_organization') if 'api_organization' in args else None,
816+
args=args
813817
)
814818

815819
try:
@@ -891,7 +895,8 @@ def describe(self, attribute: Optional[Text] = None) -> pd.DataFrame:
891895
client = self._get_client(
892896
api_key=api_key,
893897
base_url=args.get('api_base'),
894-
org=args.get('api_organization')
898+
org=args.get('api_organization'),
899+
args=args,
895900
)
896901
meta = client.models.retrieve(model_name)
897902
except Exception as e:
@@ -935,7 +940,7 @@ def finetune(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = Non
935940

936941
api_base = using_args.get('api_base', os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE))
937942
org = using_args.get('api_organization')
938-
client = self._get_client(api_key=api_key, base_url=api_base, org=org)
943+
client = self._get_client(api_key=api_key, base_url=api_base, org=org, args=args)
939944

940945
args = {**using_args, **args}
941946
prev_model_name = self.base_model_storage.json_get('args').get('model_name', '')
@@ -1173,7 +1178,7 @@ def _check_ft_status(job_id: Text) -> openai.types.fine_tuning.FineTuningJob:
11731178
return ft_stats, result_file_id
11741179

11751180
@staticmethod
1176-
def _get_client(api_key: Text, base_url: Text, org: Optional[Text] = None) -> OpenAI:
1181+
def _get_client(api_key: Text, base_url: Text, org: Optional[Text] = None, args: dict = None) -> OpenAI:
11771182
"""
11781183
Get an OpenAI client with the given API key, base URL, and organization.
11791184
@@ -1185,4 +1190,11 @@ def _get_client(api_key: Text, base_url: Text, org: Optional[Text] = None) -> Op
11851190
Returns:
11861191
openai.OpenAI: OpenAI client.
11871192
"""
1193+
if args is not None and args.get('provider') == 'azure':
1194+
return AzureOpenAI(
1195+
api_key=api_key,
1196+
azure_endpoint=base_url,
1197+
api_version=args.get('api_version'),
1198+
organization=org
1199+
)
11881200
return OpenAI(api_key=api_key, base_url=base_url, organization=org)

0 commit comments

Comments
 (0)