diff --git a/pyproject.toml b/pyproject.toml index 0a2a49ca8..7e6999b1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,10 @@ dependencies = [ "prometheus_client>=0.21.0", "cloudevents>=1.12.0", ] + +[project.optional-dependencies] +litellm = ["litellm>=1.65,<1.85"] + [dependency-groups] dev = [ "pytest>=8.2.2", diff --git a/src/config.py b/src/config.py index 1703b5cd2..50786c52f 100644 --- a/src/config.py +++ b/src/config.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) -ModelTransport = Literal["anthropic", "openai", "gemini"] +ModelTransport = Literal["anthropic", "openai", "gemini", "litellm"] EmbeddingTransport = Literal["openai", "gemini"] diff --git a/src/llm/backends/__init__.py b/src/llm/backends/__init__.py index dfba81eca..82e469181 100644 --- a/src/llm/backends/__init__.py +++ b/src/llm/backends/__init__.py @@ -1,9 +1,11 @@ from .anthropic import AnthropicBackend from .gemini import GeminiBackend +from .litellm import LiteLLMBackend from .openai import OpenAIBackend __all__ = [ "AnthropicBackend", "GeminiBackend", + "LiteLLMBackend", "OpenAIBackend", ] diff --git a/src/llm/backends/litellm.py b/src/llm/backends/litellm.py new file mode 100644 index 000000000..f35acf468 --- /dev/null +++ b/src/llm/backends/litellm.py @@ -0,0 +1,223 @@ +"""LiteLLM provider backend. + +Routes to 100+ LLM providers via a unified interface using provider-prefixed +model names (e.g. ``anthropic/claude-sonnet-4-6``, ``gemini/gemini-2.5-flash``). + +Install: ``pip install litellm`` +""" + +from __future__ import annotations + +import json +import logging +from collections.abc import AsyncIterator +from typing import Any + +from pydantic import BaseModel + +from src.exceptions import ValidationException +from src.llm.backend import CompletionResult, StreamChunk, ToolCallResult + +logger = logging.getLogger(__name__) + + +class LiteLLMBackend: + """Provider backend wrapping litellm.acompletion.""" + + def __init__(self, api_key: str | None = None, api_base: str | None = None) -> None: + self._api_key = api_key + self._api_base = api_base + + def _base_kwargs(self) -> dict[str, Any]: + kwargs: dict[str, Any] = {"drop_params": True} + if self._api_key: + kwargs["api_key"] = self._api_key + if self._api_base: + kwargs["api_base"] = self._api_base + return kwargs + + @staticmethod + def _import_litellm() -> Any: + try: + import litellm + except ModuleNotFoundError as exc: + raise ValidationException( + "LiteLLM transport requires optional dependency 'litellm'. " + "Install with: pip install honcho[litellm]" + ) from exc + return litellm + + async def complete( + self, + *, + model: str, + messages: list[dict[str, Any]], + max_tokens: int, + temperature: float | None = None, + stop: list[str] | None = None, + tools: list[dict[str, Any]] | None = None, + tool_choice: str | dict[str, Any] | None = None, + response_format: type[BaseModel] | dict[str, Any] | None = None, + thinking_budget_tokens: int | None = None, + thinking_effort: str | None = None, + max_output_tokens: int | None = None, + extra_params: dict[str, Any] | None = None, + ) -> CompletionResult: + litellm = self._import_litellm() + + params = self._build_params( + model=model, + messages=messages, + max_tokens=max_output_tokens or max_tokens, + temperature=temperature, + stop=stop, + tools=tools, + tool_choice=tool_choice, + response_format=response_format, + thinking_effort=thinking_effort, + extra_params=extra_params, + ) + + response = await litellm.acompletion(**params) + return self._normalize_response(response) + + async def stream( + self, + *, + model: str, + messages: list[dict[str, Any]], + max_tokens: int, + temperature: float | None = None, + stop: list[str] | None = None, + tools: list[dict[str, Any]] | None = None, + tool_choice: str | dict[str, Any] | None = None, + response_format: type[BaseModel] | dict[str, Any] | None = None, + thinking_budget_tokens: int | None = None, + thinking_effort: str | None = None, + max_output_tokens: int | None = None, + extra_params: dict[str, Any] | None = None, + ) -> AsyncIterator[StreamChunk]: + litellm = self._import_litellm() + + params = self._build_params( + model=model, + messages=messages, + max_tokens=max_output_tokens or max_tokens, + temperature=temperature, + stop=stop, + tools=tools, + tool_choice=tool_choice, + response_format=response_format, + thinking_effort=thinking_effort, + extra_params=extra_params, + ) + params["stream"] = True + + response_stream = await litellm.acompletion(**params) + finish_reason: str | None = None + async for chunk in response_stream: + if chunk.choices and chunk.choices[0].delta.content: + yield StreamChunk(content=chunk.choices[0].delta.content) + if chunk.choices and chunk.choices[0].finish_reason: + finish_reason = chunk.choices[0].finish_reason + usage = getattr(chunk, "usage", None) + if usage: + yield StreamChunk( + is_done=True, + finish_reason=finish_reason, + output_tokens=getattr(usage, "completion_tokens", None), + ) + return + + if finish_reason: + yield StreamChunk(is_done=True, finish_reason=finish_reason) + + def _build_params( + self, + *, + model: str, + messages: list[dict[str, Any]], + max_tokens: int, + temperature: float | None, + stop: list[str] | None, + tools: list[dict[str, Any]] | None, + tool_choice: str | dict[str, Any] | None, + response_format: type[BaseModel] | dict[str, Any] | None, + thinking_effort: str | None, + extra_params: dict[str, Any] | None, + ) -> dict[str, Any]: + params: dict[str, Any] = { + "model": model, + "messages": messages, + "max_tokens": max_tokens, + **self._base_kwargs(), + } + if temperature is not None: + params["temperature"] = temperature + if stop: + params["stop"] = stop + if tools: + params["tools"] = self._convert_tools(tools) + if tool_choice is not None: + params["tool_choice"] = tool_choice + if response_format is not None: + if isinstance(response_format, type) and issubclass( + response_format, BaseModel + ): + params["response_format"] = response_format + else: + params["response_format"] = response_format + if thinking_effort: + params["reasoning_effort"] = thinking_effort + if extra_params: + for key in ("top_p", "frequency_penalty", "presence_penalty", "seed"): + if key in extra_params: + params[key] = extra_params[key] + return params + + @staticmethod + def _normalize_response(response: Any) -> CompletionResult: + usage = getattr(response, "usage", None) + message = response.choices[0].message + finish_reason = response.choices[0].finish_reason + + tool_calls: list[ToolCallResult] = [] + for tc in getattr(message, "tool_calls", None) or []: + tool_input: dict[str, Any] = {} + if tc.function.arguments: + try: + tool_input = json.loads(tc.function.arguments) + except (json.JSONDecodeError, TypeError): + logger.warning( + "Malformed tool arguments for %s (id=%s)", + tc.function.name, + tc.id, + ) + tool_calls.append( + ToolCallResult(id=tc.id, name=tc.function.name, input=tool_input) + ) + + return CompletionResult( + content=getattr(message, "content", "") or "", + input_tokens=getattr(usage, "prompt_tokens", 0) if usage else 0, + output_tokens=getattr(usage, "completion_tokens", 0) if usage else 0, + finish_reason=finish_reason or "stop", + tool_calls=tool_calls, + raw_response=response, + ) + + @staticmethod + def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: + if not tools or tools[0].get("type") == "function": + return tools + return [ + { + "type": "function", + "function": { + "name": tool["name"], + "description": tool["description"], + "parameters": tool["input_schema"], + }, + } + for tool in tools + ] diff --git a/src/llm/registry.py b/src/llm/registry.py index 73cf60c8b..71347f8ab 100644 --- a/src/llm/registry.py +++ b/src/llm/registry.py @@ -22,6 +22,7 @@ from .backend import ProviderBackend from .backends.anthropic import AnthropicBackend from .backends.gemini import GeminiBackend +from .backends.litellm import LiteLLMBackend from .backends.openai import OpenAIBackend from .credentials import default_transport_api_key from .history_adapters import ( @@ -108,7 +109,7 @@ def get_gemini_override_client( def client_for_model_config( provider: ModelTransport, model_config: ModelConfig, -) -> ProviderClient: +) -> ProviderClient | None: """Resolve the provider client for a ModelConfig. Fast path: no overrides → reuse the module-level default client from @@ -131,12 +132,14 @@ def client_for_model_config( return get_openai_override_client(base_url, api_key) if provider == "gemini": return get_gemini_override_client(base_url, api_key) + if provider == "litellm": + return None # LiteLLMBackend manages its own credentials assert_never(provider) def backend_for_provider( provider: ModelTransport, - client: ProviderClient, + client: ProviderClient | None, ) -> ProviderBackend: """Wrap a raw provider SDK client in the matching ProviderBackend adapter.""" if provider == "anthropic": @@ -145,6 +148,8 @@ def backend_for_provider( return OpenAIBackend(client) if provider == "gemini": return GeminiBackend(client) + if provider == "litellm": + return LiteLLMBackend() assert_never(provider) @@ -154,11 +159,11 @@ def history_adapter_for_provider(provider: ModelTransport) -> HistoryAdapter: return AnthropicHistoryAdapter() if provider == "gemini": return GeminiHistoryAdapter() - return OpenAIHistoryAdapter() + return OpenAIHistoryAdapter() # litellm uses OpenAI message format def get_backend(config: ModelConfig) -> ProviderBackend: - """High-level one-shot backend factory: ModelConfig → ProviderBackend. + """High-level one-shot backend factory: ModelConfig -> ProviderBackend. Delegates client resolution to ``client_for_model_config``, which owns the CLIENTS fast-path and the missing-API-key validation. Both the @@ -166,6 +171,8 @@ def get_backend(config: ModelConfig) -> ProviderBackend: (via this function) now construct clients through the same helper, so validation behavior stays consistent. """ + if config.transport == "litellm": + return LiteLLMBackend(api_key=config.api_key, api_base=config.base_url) client = client_for_model_config(config.transport, config) return backend_for_provider(config.transport, client) diff --git a/src/llm/runtime.py b/src/llm/runtime.py index 2c29f397a..20f6d6963 100644 --- a/src/llm/runtime.py +++ b/src/llm/runtime.py @@ -71,7 +71,7 @@ class AttemptPlan: provider: ModelTransport model: str - client: ProviderClient + client: ProviderClient | None thinking_budget_tokens: int | None reasoning_effort: ReasoningEffortType selected_config: ModelConfig diff --git a/tests/test_litellm_backend.py b/tests/test_litellm_backend.py new file mode 100644 index 000000000..9d45170d6 --- /dev/null +++ b/tests/test_litellm_backend.py @@ -0,0 +1,160 @@ +"""Tests for the LiteLLM provider backend.""" + +from __future__ import annotations + +import json +import sys +import types +from typing import Any +from unittest import mock + +import pytest + +from src.llm.backend import CompletionResult, ToolCallResult + + +def _install_litellm_stub(): + fake = types.ModuleType("litellm") + fake.acompletion = mock.AsyncMock(name="litellm.acompletion") + sys.modules["litellm"] = fake + return fake + + +@pytest.fixture(autouse=True) +def litellm_stub(): + fake = _install_litellm_stub() + yield fake + sys.modules.pop("litellm", None) + + +def _mock_response(content: str = "Hello!", tool_calls: Any = None): + from types import SimpleNamespace + + msg = SimpleNamespace(content=content, tool_calls=tool_calls) + usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5) + return SimpleNamespace( + choices=[SimpleNamespace(message=msg, finish_reason="stop")], + usage=usage, + ) + + +@pytest.mark.asyncio +async def test_complete_calls_acompletion(litellm_stub): + litellm_stub.acompletion.return_value = _mock_response("test reply") + + from src.llm.backends.litellm import LiteLLMBackend + + backend = LiteLLMBackend(api_key="sk-test") + result = await backend.complete( + model="anthropic/claude-haiku-4-5", + messages=[{"role": "user", "content": "Hi"}], + max_tokens=100, + ) + + litellm_stub.acompletion.assert_called_once() + kwargs = litellm_stub.acompletion.call_args.kwargs + assert kwargs["model"] == "anthropic/claude-haiku-4-5" + assert kwargs["api_key"] == "sk-test" + assert kwargs["drop_params"] is True + assert isinstance(result, CompletionResult) + assert result.content == "test reply" + assert result.input_tokens == 10 + assert result.output_tokens == 5 + + +@pytest.mark.asyncio +async def test_complete_omits_blank_credentials(litellm_stub): + litellm_stub.acompletion.return_value = _mock_response() + + from src.llm.backends.litellm import LiteLLMBackend + + backend = LiteLLMBackend() + await backend.complete( + model="openai/gpt-4o", + messages=[{"role": "user", "content": "Hi"}], + max_tokens=100, + ) + + kwargs = litellm_stub.acompletion.call_args.kwargs + assert "api_key" not in kwargs + assert "api_base" not in kwargs + + +@pytest.mark.asyncio +async def test_complete_forwards_tools(litellm_stub): + litellm_stub.acompletion.return_value = _mock_response() + + from src.llm.backends.litellm import LiteLLMBackend + + tools = [ + { + "type": "function", + "function": { + "name": "search", + "description": "Search the web", + "parameters": {"type": "object", "properties": {}}, + }, + } + ] + backend = LiteLLMBackend(api_key="k") + await backend.complete( + model="openai/gpt-4o", + messages=[{"role": "user", "content": "Find info"}], + max_tokens=100, + tools=tools, + tool_choice="auto", + ) + + kwargs = litellm_stub.acompletion.call_args.kwargs + assert kwargs["tool_choice"] == "auto" + assert len(kwargs["tools"]) == 1 + + +@pytest.mark.asyncio +async def test_complete_parses_tool_calls(litellm_stub): + from types import SimpleNamespace + + tc = SimpleNamespace( + id="call_1", + function=SimpleNamespace(name="search", arguments=json.dumps({"q": "test"})), + ) + litellm_stub.acompletion.return_value = _mock_response("", tool_calls=[tc]) + + from src.llm.backends.litellm import LiteLLMBackend + + backend = LiteLLMBackend(api_key="k") + result = await backend.complete( + model="openai/gpt-4o", + messages=[{"role": "user", "content": "Hi"}], + max_tokens=100, + ) + + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].name == "search" + assert result.tool_calls[0].input == {"q": "test"} + assert isinstance(result.tool_calls[0], ToolCallResult) + + +@pytest.mark.asyncio +async def test_complete_forwards_temperature(litellm_stub): + litellm_stub.acompletion.return_value = _mock_response() + + from src.llm.backends.litellm import LiteLLMBackend + + backend = LiteLLMBackend(api_key="k") + await backend.complete( + model="openai/gpt-4o", + messages=[{"role": "user", "content": "Hi"}], + max_tokens=100, + temperature=0.7, + ) + + kwargs = litellm_stub.acompletion.call_args.kwargs + assert kwargs["temperature"] == 0.7 + + +def test_model_transport_includes_litellm(): + from src.config import ModelTransport + from typing import get_args + + assert "litellm" in get_args(ModelTransport)