|
10 | 10 | LLMProvider.VERTEX: APIFlavor.VERTEX_GEMINI_GENERATE_CONTENT, |
11 | 11 | } |
12 | 12 |
|
| 13 | +_API_FLAVOR_TO_PROVIDER: dict[APIFlavor, LLMProvider] = { |
| 14 | + APIFlavor.OPENAI_RESPONSES: LLMProvider.OPENAI, |
| 15 | + APIFlavor.OPENAI_COMPLETIONS: LLMProvider.OPENAI, |
| 16 | + APIFlavor.AWS_BEDROCK_CONVERSE: LLMProvider.BEDROCK, |
| 17 | + APIFlavor.AWS_BEDROCK_INVOKE: LLMProvider.BEDROCK, |
| 18 | + APIFlavor.VERTEX_GEMINI_GENERATE_CONTENT: LLMProvider.VERTEX, |
| 19 | + APIFlavor.VERTEX_ANTHROPIC_CLAUDE: LLMProvider.VERTEX, |
| 20 | +} |
| 21 | + |
13 | 22 |
|
14 | 23 | def _fetch_discovery(agenthub_config: str) -> list[dict[str, Any]]: |
15 | 24 | """Fetch available models from LLM Gateway discovery endpoint.""" |
@@ -126,26 +135,56 @@ def _create_vertex_llm( |
126 | 135 | raise ValueError(f"Unknown api_flavor={api_flavor} for Vertex") |
127 | 136 |
|
128 | 137 |
|
129 | | -def _compute_api_flavor( |
| 138 | +def _resolve_vendor(api_flavor: APIFlavor) -> LLMProvider: |
| 139 | + return _API_FLAVOR_TO_PROVIDER[api_flavor] |
| 140 | + |
| 141 | + |
| 142 | +def _resolve_api_flavor(vendor: LLMProvider, model_name: str) -> APIFlavor: |
| 143 | + if vendor == LLMProvider.VERTEX and "claude" in model_name: |
| 144 | + return APIFlavor.VERTEX_ANTHROPIC_CLAUDE |
| 145 | + return _DEFAULT_API_FLAVOR[vendor] |
| 146 | + |
| 147 | + |
| 148 | +def _compute_vendor_and_api_flavor( |
130 | 149 | model: dict[str, Any], |
131 | | -) -> APIFlavor: |
| 150 | +) -> tuple[LLMProvider, APIFlavor]: |
132 | 151 | vendor = model.get("vendor") |
133 | 152 | api_flavor = model.get("apiFlavor") |
134 | 153 | model_name = model.get("modelName", "") |
135 | 154 |
|
136 | | - if api_flavor is None and vendor == LLMProvider.VERTEX and "claude" in model_name: |
137 | | - api_flavor = APIFlavor.VERTEX_ANTHROPIC_CLAUDE |
138 | | - |
139 | | - if api_flavor is None and vendor is not None: |
140 | | - api_flavor = _DEFAULT_API_FLAVOR[LLMProvider(vendor)] |
| 155 | + if api_flavor is None and vendor is None: |
| 156 | + raise ValueError( |
| 157 | + f"Neither vendor nor apiFlavor provided for model '{model_name}'. " |
| 158 | + "At least one must be present." |
| 159 | + ) |
141 | 160 |
|
142 | | - if api_flavor not in [p.value for p in APIFlavor]: |
| 161 | + if api_flavor is not None and api_flavor not in [p.value for p in APIFlavor]: |
143 | 162 | raise ValueError( |
144 | | - f"Unknown apiFlavor '{api_flavor}' for model '{model.get('modelName')}'. " |
| 163 | + f"Unknown apiFlavor '{api_flavor}' for model '{model_name}'. " |
145 | 164 | f"Supported apiFlavors: {[p.value for p in APIFlavor]}" |
146 | 165 | ) |
147 | 166 |
|
148 | | - return APIFlavor(api_flavor) |
| 167 | + if vendor is not None and vendor not in [p.value for p in LLMProvider]: |
| 168 | + raise ValueError( |
| 169 | + f"Unknown vendor '{vendor}' for model '{model_name}'. " |
| 170 | + f"Supported vendors: {[p.value for p in LLMProvider]}" |
| 171 | + ) |
| 172 | + |
| 173 | + resolved_vendor: LLMProvider |
| 174 | + resolved_api_flavor: APIFlavor |
| 175 | + |
| 176 | + if vendor is None and api_flavor is not None: |
| 177 | + resolved_api_flavor = APIFlavor(api_flavor) |
| 178 | + resolved_vendor = _resolve_vendor(resolved_api_flavor) |
| 179 | + elif api_flavor is None and vendor is not None: |
| 180 | + resolved_vendor = LLMProvider(vendor) |
| 181 | + resolved_api_flavor = _resolve_api_flavor(resolved_vendor, model_name) |
| 182 | + else: |
| 183 | + assert vendor is not None and api_flavor is not None |
| 184 | + resolved_vendor = LLMProvider(vendor) |
| 185 | + resolved_api_flavor = APIFlavor(api_flavor) |
| 186 | + |
| 187 | + return resolved_vendor, resolved_api_flavor |
149 | 188 |
|
150 | 189 |
|
151 | 190 | def _get_model_info( |
@@ -192,13 +231,7 @@ def get_chat_model( |
192 | 231 | """ |
193 | 232 | model_info = _get_model_info(model, agenthub_config, byo_connection_id) |
194 | 233 |
|
195 | | - vendor = model_info.get("vendor") |
196 | | - if vendor not in [p.value for p in LLMProvider]: |
197 | | - raise ValueError( |
198 | | - f"Unknown vendor '{vendor}' for model '{model}'. " |
199 | | - f"Supported vendors: {[p.value for p in LLMProvider]}" |
200 | | - ) |
201 | | - api_flavor = _compute_api_flavor(model_info) |
| 234 | + vendor, api_flavor = _compute_vendor_and_api_flavor(model_info) |
202 | 235 | model_name: str = model_info.get("modelName", model) |
203 | 236 |
|
204 | 237 | match LLMProvider(vendor): |
|
0 commit comments