99import concurrent .futures
1010from typing import Text , Tuple , Dict , List , Optional , Any
1111import openai
12- from openai import OpenAI , NotFoundError , AuthenticationError
12+ from openai import OpenAI , AzureOpenAI , NotFoundError , AuthenticationError
1313import numpy as np
1414import pandas as pd
1515
@@ -87,7 +87,7 @@ def create_engine(self, connection_args: Dict) -> None:
8787 if api_key is not None :
8888 org = connection_args .get ('api_organization' )
8989 api_base = connection_args .get ('api_base' ) or os .environ .get ('OPENAI_API_BASE' , OPENAI_API_BASE )
90- client = self ._get_client (api_key = api_key , base_url = api_base , org = org )
90+ client = self ._get_client (api_key = api_key , base_url = api_base , org = org , args = connection_args )
9191 OpenAIHandler ._check_client_connection (client )
9292
9393 @staticmethod
@@ -188,7 +188,9 @@ def create_validation(target: Text, args: Dict = None, **kwargs: Any) -> None:
188188 "temperature" ,
189189 "openai_api_key" ,
190190 "api_organization" ,
191- "api_base"
191+ "api_base" ,
192+ "api_version" ,
193+ "provider" ,
192194 }
193195 )
194196
@@ -204,7 +206,7 @@ def create_validation(target: Text, args: Dict = None, **kwargs: Any) -> None:
204206 api_key = get_api_key ('openai' , args , engine_storage = engine_storage )
205207 api_base = args .get ('api_base' ) or connection_args .get ('api_base' ) or os .environ .get ('OPENAI_API_BASE' , OPENAI_API_BASE )
206208 org = args .get ('api_organization' )
207- client = OpenAIHandler ._get_client (api_key = api_key , base_url = api_base , org = org )
209+ client = OpenAIHandler ._get_client (api_key = api_key , base_url = api_base , org = org , args = args )
208210 OpenAIHandler ._check_client_connection (client )
209211
210212 def create (self , target , args : Dict = None , ** kwargs : Any ) -> None :
@@ -228,7 +230,8 @@ def create(self, target, args: Dict = None, **kwargs: Any) -> None:
228230 api_key = get_api_key (self .api_key_name , args , self .engine_storage )
229231 connection_args = self .engine_storage .get_connection_args ()
230232 api_base = args .get ('api_base' ) or connection_args .get ('api_base' ) or os .environ .get ('OPENAI_API_BASE' ) or self .api_base
231- available_models = get_available_models (api_key , api_base )
233+ client = self ._get_client (api_key = api_key , base_url = api_base , org = args .get ('api_organization' ), args = args )
234+ available_models = get_available_models (client )
232235
233236 if not args .get ('mode' ):
234237 args ['mode' ] = self .default_mode
@@ -810,6 +813,7 @@ def _tidy(comp: List[openai.types.image.Image]) -> List[Text]:
810813 api_key = api_key ,
811814 base_url = args .get ('api_base' ),
812815 org = args .pop ('api_organization' ) if 'api_organization' in args else None ,
816+ args = args
813817 )
814818
815819 try :
@@ -891,7 +895,8 @@ def describe(self, attribute: Optional[Text] = None) -> pd.DataFrame:
891895 client = self ._get_client (
892896 api_key = api_key ,
893897 base_url = args .get ('api_base' ),
894- org = args .get ('api_organization' )
898+ org = args .get ('api_organization' ),
899+ args = args ,
895900 )
896901 meta = client .models .retrieve (model_name )
897902 except Exception as e :
@@ -935,7 +940,7 @@ def finetune(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = Non
935940
936941 api_base = using_args .get ('api_base' , os .environ .get ('OPENAI_API_BASE' , OPENAI_API_BASE ))
937942 org = using_args .get ('api_organization' )
938- client = self ._get_client (api_key = api_key , base_url = api_base , org = org )
943+ client = self ._get_client (api_key = api_key , base_url = api_base , org = org , args = args )
939944
940945 args = {** using_args , ** args }
941946 prev_model_name = self .base_model_storage .json_get ('args' ).get ('model_name' , '' )
@@ -1173,7 +1178,7 @@ def _check_ft_status(job_id: Text) -> openai.types.fine_tuning.FineTuningJob:
11731178 return ft_stats , result_file_id
11741179
11751180 @staticmethod
1176- def _get_client (api_key : Text , base_url : Text , org : Optional [Text ] = None ) -> OpenAI :
1181+ def _get_client (api_key : Text , base_url : Text , org : Optional [Text ] = None , args : dict = None ) -> OpenAI :
11771182 """
11781183 Get an OpenAI client with the given API key, base URL, and organization.
11791184
@@ -1185,4 +1190,11 @@ def _get_client(api_key: Text, base_url: Text, org: Optional[Text] = None) -> Op
11851190 Returns:
11861191 openai.OpenAI: OpenAI client.
11871192 """
1193+ if args is not None and args .get ('provider' ) == 'azure' :
1194+ return AzureOpenAI (
1195+ api_key = api_key ,
1196+ azure_endpoint = base_url ,
1197+ api_version = args .get ('api_version' ),
1198+ organization = org
1199+ )
11881200 return OpenAI (api_key = api_key , base_url = base_url , organization = org )
0 commit comments