|
3 | 3 | Since most providers (DeepSeek, Qwen, Kimi, GLM, Ollama, etc.) expose an |
4 | 4 | OpenAI-compatible endpoint, we just use the openai SDK directly. Switch |
5 | 5 | provider by changing OPENAI_BASE_URL + OPENAI_API_KEY. That's it. |
| 6 | +
|
| 7 | +For providers that are NOT OpenAI-compatible (AWS Bedrock, Google Vertex, |
| 8 | +etc.), use the LiteLLM backend which routes to 100+ providers through a |
| 9 | +single unified interface. Set CORECODER_PROVIDER=litellm. |
6 | 10 | """ |
7 | 11 |
|
8 | 12 | import json |
@@ -197,3 +201,127 @@ def _call_with_retry(self, params: dict, max_retries: int = 3): |
197 | 201 | time.sleep(2 ** attempt) |
198 | 202 | else: |
199 | 203 | raise |
| 204 | + |
| 205 | + |
| 206 | +class LiteLLM(LLM): |
| 207 | + """LLM backend via LiteLLM, supporting 100+ providers. |
| 208 | +
|
| 209 | + Use this when your target provider is NOT OpenAI-compatible |
| 210 | + (AWS Bedrock, Google Vertex, Cohere, etc.) or when you want |
| 211 | + a single interface to switch between any provider by changing |
| 212 | + the model string. |
| 213 | +
|
| 214 | + Set CORECODER_PROVIDER=litellm and use LiteLLM model strings |
| 215 | + like ``anthropic/claude-3-haiku``, ``bedrock/anthropic.claude-v2``, |
| 216 | + ``vertex_ai/gemini-pro``, etc. |
| 217 | + """ |
| 218 | + |
| 219 | + def __init__( |
| 220 | + self, |
| 221 | + model: str, |
| 222 | + api_key: str | None = None, |
| 223 | + base_url: str | None = None, |
| 224 | + **kwargs, |
| 225 | + ): |
| 226 | + # skip LLM.__init__ which creates an OpenAI client |
| 227 | + self.model = model |
| 228 | + self.api_key = api_key |
| 229 | + self.base_url = base_url |
| 230 | + self.extra = kwargs |
| 231 | + self.total_prompt_tokens = 0 |
| 232 | + self.total_completion_tokens = 0 |
| 233 | + |
| 234 | + def chat( |
| 235 | + self, |
| 236 | + messages: list[dict], |
| 237 | + tools: list[dict] | None = None, |
| 238 | + on_token=None, |
| 239 | + ) -> LLMResponse: |
| 240 | + """Send messages via litellm, stream back response, handle tool calls.""" |
| 241 | + params: dict = { |
| 242 | + "model": self.model, |
| 243 | + "messages": messages, |
| 244 | + "stream": True, |
| 245 | + **self.extra, |
| 246 | + } |
| 247 | + if tools: |
| 248 | + params["tools"] = tools |
| 249 | + |
| 250 | + stream = self._call_with_retry(params) |
| 251 | + |
| 252 | + content_parts: list[str] = [] |
| 253 | + tc_map: dict[int, dict] = {} |
| 254 | + prompt_tok = 0 |
| 255 | + completion_tok = 0 |
| 256 | + |
| 257 | + for chunk in stream: |
| 258 | + usage = getattr(chunk, "usage", None) |
| 259 | + if usage: |
| 260 | + prompt_tok = getattr(usage, "prompt_tokens", 0) or 0 |
| 261 | + completion_tok = getattr(usage, "completion_tokens", 0) or 0 |
| 262 | + |
| 263 | + if not getattr(chunk, "choices", None): |
| 264 | + continue |
| 265 | + delta = chunk.choices[0].delta |
| 266 | + |
| 267 | + if getattr(delta, "content", None): |
| 268 | + content_parts.append(delta.content) |
| 269 | + if on_token: |
| 270 | + on_token(delta.content) |
| 271 | + |
| 272 | + if getattr(delta, "tool_calls", None): |
| 273 | + for tc_delta in delta.tool_calls: |
| 274 | + idx = tc_delta.index |
| 275 | + if idx not in tc_map: |
| 276 | + tc_map[idx] = {"id": "", "name": "", "args": ""} |
| 277 | + if tc_delta.id: |
| 278 | + tc_map[idx]["id"] = tc_delta.id |
| 279 | + if tc_delta.function: |
| 280 | + if tc_delta.function.name: |
| 281 | + tc_map[idx]["name"] = tc_delta.function.name |
| 282 | + if tc_delta.function.arguments: |
| 283 | + tc_map[idx]["args"] += tc_delta.function.arguments |
| 284 | + |
| 285 | + parsed: list[ToolCall] = [] |
| 286 | + for idx in sorted(tc_map): |
| 287 | + raw = tc_map[idx] |
| 288 | + try: |
| 289 | + args = json.loads(raw["args"]) |
| 290 | + except (json.JSONDecodeError, KeyError): |
| 291 | + args = {} |
| 292 | + parsed.append(ToolCall(id=raw["id"], name=raw["name"], arguments=args)) |
| 293 | + |
| 294 | + self.total_prompt_tokens += prompt_tok |
| 295 | + self.total_completion_tokens += completion_tok |
| 296 | + |
| 297 | + return LLMResponse( |
| 298 | + content="".join(content_parts), |
| 299 | + tool_calls=parsed, |
| 300 | + prompt_tokens=prompt_tok, |
| 301 | + completion_tokens=completion_tok, |
| 302 | + ) |
| 303 | + |
| 304 | + def _call_with_retry(self, params: dict, max_retries: int = 3): |
| 305 | + """Retry on transient errors with exponential backoff via litellm.""" |
| 306 | + import litellm |
| 307 | + |
| 308 | + params["drop_params"] = True |
| 309 | + if self.api_key: |
| 310 | + params["api_key"] = self.api_key |
| 311 | + if self.base_url: |
| 312 | + params["api_base"] = self.base_url |
| 313 | + |
| 314 | + for attempt in range(max_retries): |
| 315 | + try: |
| 316 | + return litellm.completion(**params) |
| 317 | + except Exception as e: |
| 318 | + err = str(e).lower() |
| 319 | + is_transient = any( |
| 320 | + kw in err |
| 321 | + for kw in ["rate_limit", "timeout", "connection", "502", "503", "529"] |
| 322 | + ) |
| 323 | + is_server = any(kw in err for kw in ["500", "502", "503", "504"]) |
| 324 | + if (is_transient or is_server) and attempt < max_retries - 1: |
| 325 | + time.sleep(2 ** attempt) |
| 326 | + else: |
| 327 | + raise |
0 commit comments