@@ -374,11 +374,14 @@ def batch(seq, n):
374374 for i in range (0 , len (seq ), n ):
375375 yield seq [i : i + n ]
376376
377- def _inference_provider (self , provider : str ) -> str :
377+ def _inference_provider (self , provider : str , owned_by : str ) -> str :
378378 if provider == CloudProvider .ELLM :
379379 return CloudProvider .ELLM
380+ # These checks handle providers such as azure/bedrock that serve models owned by other cloud providers (openai/anthropic).
380381 if provider in ModelProvider :
381382 return ModelProvider (provider )
383+ elif owned_by in ModelProvider :
384+ return ModelProvider (owned_by )
382385 if provider in OnPremProvider :
383386 return OnPremProvider (provider )
384387 owned_by = self .config .owned_by or ""
@@ -598,7 +601,7 @@ async def _get_deployment(
598601 deployment = deployment ,
599602 api_key = api_key ,
600603 routing_id = routing_id ,
601- inference_provider = self ._inference_provider (provider ),
604+ inference_provider = self ._inference_provider (provider , self . config . owned_by ),
602605 is_reasoning_model = is_reasoning_model ,
603606 )
604607 self .request .state .timing ["external_call" ] = perf_counter () - t0
@@ -1011,7 +1014,7 @@ async def _completion_stream(
10111014 response : AsyncGenerator [ModelResponseStream , None ] = await acompletion (
10121015 timeout = self .config .timeout ,
10131016 api_key = ctx .api_key ,
1014- base_url = ctx .deployment .api_base ,
1017+ base_url = ctx .deployment .api_base or None ,
10151018 model = ctx .routing_id ,
10161019 messages = messages ,
10171020 stream = True ,
@@ -1043,7 +1046,7 @@ async def _completion(
10431046 response = await acompletion (
10441047 timeout = self .config .timeout ,
10451048 api_key = ctx .api_key ,
1046- base_url = ctx .deployment .api_base ,
1049+ base_url = ctx .deployment .api_base or None ,
10471050 model = ctx .routing_id ,
10481051 messages = messages ,
10491052 stream = False ,
@@ -1104,7 +1107,7 @@ async def image_generation(
11041107 response = await acompletion (
11051108 timeout = self .config .timeout ,
11061109 api_key = ctx .api_key ,
1107- base_url = ctx .deployment .api_base ,
1110+ base_url = ctx .deployment .api_base or None ,
11081111 model = ctx .routing_id ,
11091112 messages = messages_payload ,
11101113 stream = False ,
@@ -1192,7 +1195,7 @@ async def image_edit(
11921195 response = await acompletion (
11931196 timeout = self .config .timeout ,
11941197 api_key = ctx .api_key ,
1195- base_url = ctx .deployment .api_base ,
1198+ base_url = ctx .deployment .api_base or None ,
11961199 model = ctx .routing_id ,
11971200 messages = messages_payload ,
11981201 stream = False ,
@@ -1296,7 +1299,7 @@ async def embedding(
12961299 aembedding (
12971300 timeout = self .config .timeout ,
12981301 api_key = ctx .api_key ,
1299- api_base = ctx .deployment .api_base ,
1302+ api_base = ctx .deployment .api_base or None ,
13001303 model = ctx .routing_id ,
13011304 input = txt ,
13021305 dimensions = dimensions ,
@@ -1359,7 +1362,7 @@ async def reranking(
13591362 arerank (
13601363 timeout = self .config .timeout ,
13611364 api_key = ctx .api_key ,
1362- api_base = ctx .deployment .api_base ,
1365+ api_base = ctx .deployment .api_base or None ,
13631366 model = ctx .routing_id ,
13641367 query = query ,
13651368 documents = docs ,
0 commit comments