Skip to content

Commit 99bec00

Browse files
committed
fix bedrock (#887)
1 parent 0139eaa commit 99bec00

1 file changed

Lines changed: 11 additions & 8 deletions

File tree

  • services/api/src/owl/utils

services/api/src/owl/utils/lm.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -374,11 +374,14 @@ def batch(seq, n):
374374
for i in range(0, len(seq), n):
375375
yield seq[i : i + n]
376376

377-
def _inference_provider(self, provider: str) -> str:
377+
def _inference_provider(self, provider: str, owned_by: str) -> str:
378378
if provider == CloudProvider.ELLM:
379379
return CloudProvider.ELLM
380+
# this check for provider like azure/bedrock that provides other cloud providers model (openai/anthropic)
380381
if provider in ModelProvider:
381382
return ModelProvider(provider)
383+
elif owned_by in ModelProvider:
384+
return ModelProvider(owned_by)
382385
if provider in OnPremProvider:
383386
return OnPremProvider(provider)
384387
owned_by = self.config.owned_by or ""
@@ -598,7 +601,7 @@ async def _get_deployment(
598601
deployment=deployment,
599602
api_key=api_key,
600603
routing_id=routing_id,
601-
inference_provider=self._inference_provider(provider),
604+
inference_provider=self._inference_provider(provider, self.config.owned_by),
602605
is_reasoning_model=is_reasoning_model,
603606
)
604607
self.request.state.timing["external_call"] = perf_counter() - t0
@@ -1011,7 +1014,7 @@ async def _completion_stream(
10111014
response: AsyncGenerator[ModelResponseStream, None] = await acompletion(
10121015
timeout=self.config.timeout,
10131016
api_key=ctx.api_key,
1014-
base_url=ctx.deployment.api_base,
1017+
base_url=ctx.deployment.api_base or None,
10151018
model=ctx.routing_id,
10161019
messages=messages,
10171020
stream=True,
@@ -1043,7 +1046,7 @@ async def _completion(
10431046
response = await acompletion(
10441047
timeout=self.config.timeout,
10451048
api_key=ctx.api_key,
1046-
base_url=ctx.deployment.api_base,
1049+
base_url=ctx.deployment.api_base or None,
10471050
model=ctx.routing_id,
10481051
messages=messages,
10491052
stream=False,
@@ -1104,7 +1107,7 @@ async def image_generation(
11041107
response = await acompletion(
11051108
timeout=self.config.timeout,
11061109
api_key=ctx.api_key,
1107-
base_url=ctx.deployment.api_base,
1110+
base_url=ctx.deployment.api_base or None,
11081111
model=ctx.routing_id,
11091112
messages=messages_payload,
11101113
stream=False,
@@ -1192,7 +1195,7 @@ async def image_edit(
11921195
response = await acompletion(
11931196
timeout=self.config.timeout,
11941197
api_key=ctx.api_key,
1195-
base_url=ctx.deployment.api_base,
1198+
base_url=ctx.deployment.api_base or None,
11961199
model=ctx.routing_id,
11971200
messages=messages_payload,
11981201
stream=False,
@@ -1296,7 +1299,7 @@ async def embedding(
12961299
aembedding(
12971300
timeout=self.config.timeout,
12981301
api_key=ctx.api_key,
1299-
api_base=ctx.deployment.api_base,
1302+
api_base=ctx.deployment.api_base or None,
13001303
model=ctx.routing_id,
13011304
input=txt,
13021305
dimensions=dimensions,
@@ -1359,7 +1362,7 @@ async def reranking(
13591362
arerank(
13601363
timeout=self.config.timeout,
13611364
api_key=ctx.api_key,
1362-
api_base=ctx.deployment.api_base,
1365+
api_base=ctx.deployment.api_base or None,
13631366
model=ctx.routing_id,
13641367
query=query,
13651368
documents=docs,

0 commit comments

Comments
 (0)