Skip to content

Commit 4ee769b

Browse files
authored
feat: add LiteLLM as AI gateway provider (#10)
1 parent 3f888d6 commit 4ee769b

5 files changed

Lines changed: 379 additions & 2 deletions

File tree

corecoder/cli.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from prompt_toolkit.key_binding import KeyBindings
1313

1414
from .agent import Agent
15-
from .llm import LLM
15+
from .llm import LLM, LiteLLM
1616
from .config import Config
1717
from .session import save_session, load_session, list_sessions
1818
from . import __version__
@@ -62,7 +62,8 @@ def main():
6262
)
6363
sys.exit(1)
6464

65-
llm = LLM(
65+
llm_cls = LiteLLM if config.provider == "litellm" else LLM
66+
llm = llm_cls(
6667
model=config.model,
6768
api_key=config.api_key,
6869
base_url=config.base_url,

corecoder/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class Config:
3333
max_tokens: int = 4096
3434
temperature: float = 0.0
3535
max_context_tokens: int = 128_000
36+
provider: str = "openai"
3637

3738
@classmethod
3839
def from_env(cls) -> "Config":
@@ -52,4 +53,5 @@ def from_env(cls) -> "Config":
5253
max_tokens=int(os.getenv("CORECODER_MAX_TOKENS", "4096")),
5354
temperature=float(os.getenv("CORECODER_TEMPERATURE", "0")),
5455
max_context_tokens=int(os.getenv("CORECODER_MAX_CONTEXT", "128000")),
56+
provider=os.getenv("CORECODER_PROVIDER", "openai"),
5557
)

corecoder/llm.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
Since most providers (DeepSeek, Qwen, Kimi, GLM, Ollama, etc.) expose an
44
OpenAI-compatible endpoint, we just use the openai SDK directly. Switch
55
provider by changing OPENAI_BASE_URL + OPENAI_API_KEY. That's it.
6+
7+
For providers that are NOT OpenAI-compatible (AWS Bedrock, Google Vertex,
8+
etc.), use the LiteLLM backend which routes to 100+ providers through a
9+
single unified interface. Set CORECODER_PROVIDER=litellm.
610
"""
711

812
import json
@@ -197,3 +201,127 @@ def _call_with_retry(self, params: dict, max_retries: int = 3):
197201
time.sleep(2 ** attempt)
198202
else:
199203
raise
204+
205+
206+
class LiteLLM(LLM):
207+
"""LLM backend via LiteLLM, supporting 100+ providers.
208+
209+
Use this when your target provider is NOT OpenAI-compatible
210+
(AWS Bedrock, Google Vertex, Cohere, etc.) or when you want
211+
a single interface to switch between any provider by changing
212+
the model string.
213+
214+
Set CORECODER_PROVIDER=litellm and use LiteLLM model strings
215+
like ``anthropic/claude-3-haiku``, ``bedrock/anthropic.claude-v2``,
216+
``vertex_ai/gemini-pro``, etc.
217+
"""
218+
219+
def __init__(
220+
self,
221+
model: str,
222+
api_key: str | None = None,
223+
base_url: str | None = None,
224+
**kwargs,
225+
):
226+
# skip LLM.__init__ which creates an OpenAI client
227+
self.model = model
228+
self.api_key = api_key
229+
self.base_url = base_url
230+
self.extra = kwargs
231+
self.total_prompt_tokens = 0
232+
self.total_completion_tokens = 0
233+
234+
def chat(
235+
self,
236+
messages: list[dict],
237+
tools: list[dict] | None = None,
238+
on_token=None,
239+
) -> LLMResponse:
240+
"""Send messages via litellm, stream back response, handle tool calls."""
241+
params: dict = {
242+
"model": self.model,
243+
"messages": messages,
244+
"stream": True,
245+
**self.extra,
246+
}
247+
if tools:
248+
params["tools"] = tools
249+
250+
stream = self._call_with_retry(params)
251+
252+
content_parts: list[str] = []
253+
tc_map: dict[int, dict] = {}
254+
prompt_tok = 0
255+
completion_tok = 0
256+
257+
for chunk in stream:
258+
usage = getattr(chunk, "usage", None)
259+
if usage:
260+
prompt_tok = getattr(usage, "prompt_tokens", 0) or 0
261+
completion_tok = getattr(usage, "completion_tokens", 0) or 0
262+
263+
if not getattr(chunk, "choices", None):
264+
continue
265+
delta = chunk.choices[0].delta
266+
267+
if getattr(delta, "content", None):
268+
content_parts.append(delta.content)
269+
if on_token:
270+
on_token(delta.content)
271+
272+
if getattr(delta, "tool_calls", None):
273+
for tc_delta in delta.tool_calls:
274+
idx = tc_delta.index
275+
if idx not in tc_map:
276+
tc_map[idx] = {"id": "", "name": "", "args": ""}
277+
if tc_delta.id:
278+
tc_map[idx]["id"] = tc_delta.id
279+
if tc_delta.function:
280+
if tc_delta.function.name:
281+
tc_map[idx]["name"] = tc_delta.function.name
282+
if tc_delta.function.arguments:
283+
tc_map[idx]["args"] += tc_delta.function.arguments
284+
285+
parsed: list[ToolCall] = []
286+
for idx in sorted(tc_map):
287+
raw = tc_map[idx]
288+
try:
289+
args = json.loads(raw["args"])
290+
except (json.JSONDecodeError, KeyError):
291+
args = {}
292+
parsed.append(ToolCall(id=raw["id"], name=raw["name"], arguments=args))
293+
294+
self.total_prompt_tokens += prompt_tok
295+
self.total_completion_tokens += completion_tok
296+
297+
return LLMResponse(
298+
content="".join(content_parts),
299+
tool_calls=parsed,
300+
prompt_tokens=prompt_tok,
301+
completion_tokens=completion_tok,
302+
)
303+
304+
def _call_with_retry(self, params: dict, max_retries: int = 3):
305+
"""Retry on transient errors with exponential backoff via litellm."""
306+
import litellm
307+
308+
params["drop_params"] = True
309+
if self.api_key:
310+
params["api_key"] = self.api_key
311+
if self.base_url:
312+
params["api_base"] = self.base_url
313+
314+
for attempt in range(max_retries):
315+
try:
316+
return litellm.completion(**params)
317+
except Exception as e:
318+
err = str(e).lower()
319+
is_transient = any(
320+
kw in err
321+
for kw in ["rate_limit", "timeout", "connection", "502", "503", "529"]
322+
)
323+
is_server = any(kw in err for kw in ["500", "502", "503", "504"])
324+
if (is_transient or is_server) and attempt < max_retries - 1:
325+
time.sleep(2 ** attempt)
326+
else:
327+
raise

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ dependencies = [
3939
packages = ["corecoder"]
4040

4141
[project.optional-dependencies]
42+
litellm = ["litellm>=1.60.0,<2.0.0"]
4243
dev = ["pytest>=7.0"]
4344

4445
[project.scripts]

0 commit comments

Comments
 (0)