Skip to content

Commit fab29b2

Browse files
feat: support self serve BYO (#521)
1 parent 4acba2b commit fab29b2

2 files changed

Lines changed: 372 additions & 17 deletions

File tree

src/uipath_langchain/chat/chat_model_factory.py

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,15 @@
1010
LLMProvider.VERTEX: APIFlavor.VERTEX_GEMINI_GENERATE_CONTENT,
1111
}
1212

13+
_API_FLAVOR_TO_PROVIDER: dict[APIFlavor, LLMProvider] = {
14+
APIFlavor.OPENAI_RESPONSES: LLMProvider.OPENAI,
15+
APIFlavor.OPENAI_COMPLETIONS: LLMProvider.OPENAI,
16+
APIFlavor.AWS_BEDROCK_CONVERSE: LLMProvider.BEDROCK,
17+
APIFlavor.AWS_BEDROCK_INVOKE: LLMProvider.BEDROCK,
18+
APIFlavor.VERTEX_GEMINI_GENERATE_CONTENT: LLMProvider.VERTEX,
19+
APIFlavor.VERTEX_ANTHROPIC_CLAUDE: LLMProvider.VERTEX,
20+
}
21+
1322

1423
def _fetch_discovery(agenthub_config: str) -> list[dict[str, Any]]:
1524
"""Fetch available models from LLM Gateway discovery endpoint."""
@@ -126,26 +135,56 @@ def _create_vertex_llm(
126135
raise ValueError(f"Unknown api_flavor={api_flavor} for Vertex")
127136

128137

129-
def _compute_api_flavor(
138+
def _resolve_vendor(api_flavor: APIFlavor) -> LLMProvider:
139+
return _API_FLAVOR_TO_PROVIDER[api_flavor]
140+
141+
142+
def _resolve_api_flavor(vendor: LLMProvider, model_name: str) -> APIFlavor:
143+
if vendor == LLMProvider.VERTEX and "claude" in model_name:
144+
return APIFlavor.VERTEX_ANTHROPIC_CLAUDE
145+
return _DEFAULT_API_FLAVOR[vendor]
146+
147+
148+
def _compute_vendor_and_api_flavor(
130149
model: dict[str, Any],
131-
) -> APIFlavor:
150+
) -> tuple[LLMProvider, APIFlavor]:
132151
vendor = model.get("vendor")
133152
api_flavor = model.get("apiFlavor")
134153
model_name = model.get("modelName", "")
135154

136-
if api_flavor is None and vendor == LLMProvider.VERTEX and "claude" in model_name:
137-
api_flavor = APIFlavor.VERTEX_ANTHROPIC_CLAUDE
138-
139-
if api_flavor is None and vendor is not None:
140-
api_flavor = _DEFAULT_API_FLAVOR[LLMProvider(vendor)]
155+
if api_flavor is None and vendor is None:
156+
raise ValueError(
157+
f"Neither vendor nor apiFlavor provided for model '{model_name}'. "
158+
"At least one must be present."
159+
)
141160

142-
if api_flavor not in [p.value for p in APIFlavor]:
161+
if api_flavor is not None and api_flavor not in [p.value for p in APIFlavor]:
143162
raise ValueError(
144-
f"Unknown apiFlavor '{api_flavor}' for model '{model.get('modelName')}'. "
163+
f"Unknown apiFlavor '{api_flavor}' for model '{model_name}'. "
145164
f"Supported apiFlavors: {[p.value for p in APIFlavor]}"
146165
)
147166

148-
return APIFlavor(api_flavor)
167+
if vendor is not None and vendor not in [p.value for p in LLMProvider]:
168+
raise ValueError(
169+
f"Unknown vendor '{vendor}' for model '{model_name}'. "
170+
f"Supported vendors: {[p.value for p in LLMProvider]}"
171+
)
172+
173+
resolved_vendor: LLMProvider
174+
resolved_api_flavor: APIFlavor
175+
176+
if vendor is None and api_flavor is not None:
177+
resolved_api_flavor = APIFlavor(api_flavor)
178+
resolved_vendor = _resolve_vendor(resolved_api_flavor)
179+
elif api_flavor is None and vendor is not None:
180+
resolved_vendor = LLMProvider(vendor)
181+
resolved_api_flavor = _resolve_api_flavor(resolved_vendor, model_name)
182+
else:
183+
assert vendor is not None and api_flavor is not None
184+
resolved_vendor = LLMProvider(vendor)
185+
resolved_api_flavor = APIFlavor(api_flavor)
186+
187+
return resolved_vendor, resolved_api_flavor
149188

150189

151190
def _get_model_info(
@@ -192,13 +231,7 @@ def get_chat_model(
192231
"""
193232
model_info = _get_model_info(model, agenthub_config, byo_connection_id)
194233

195-
vendor = model_info.get("vendor")
196-
if vendor not in [p.value for p in LLMProvider]:
197-
raise ValueError(
198-
f"Unknown vendor '{vendor}' for model '{model}'. "
199-
f"Supported vendors: {[p.value for p in LLMProvider]}"
200-
)
201-
api_flavor = _compute_api_flavor(model_info)
234+
vendor, api_flavor = _compute_vendor_and_api_flavor(model_info)
202235
model_name: str = model_info.get("modelName", model)
203236

204237
match LLMProvider(vendor):

0 commit comments

Comments
 (0)