diff --git a/plugins/litellm/README.md b/plugins/litellm/README.md
new file mode 100644
index 00000000..ec4896d0
--- /dev/null
+++ b/plugins/litellm/README.md
@@ -0,0 +1,30 @@
+# vision-agents-plugins-litellm
+
+LiteLLM plugin for [Vision Agents](https://github.com/GetStream/Vision-Agents), enabling access to 100+ LLM providers through a single unified interface.
+
+## Installation
+
+```bash
+pip install vision-agents-plugins-litellm
+```
+
+## Usage
+
+```python
+from vision_agents.plugins.litellm import LiteLLMChatCompletions
+
+# Use any litellm model string
+llm = LiteLLMChatCompletions(model="anthropic/claude-sonnet-4-20250514")
+llm = LiteLLMChatCompletions(model="azure/gpt-4o", api_key="...")
+llm = LiteLLMChatCompletions(model="bedrock/anthropic.claude-3-haiku")
+```
+
+LiteLLM reads provider API keys from environment variables automatically (`ANTHROPIC_API_KEY`, `OPENAI_API_KEY`, etc.).
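+
+For example, to use an Anthropic model it is usually enough to export the key before starting your agent (the key value below is a placeholder):
+
+```bash
+export ANTHROPIC_API_KEY="sk-ant-..."
+```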
+
+See https://docs.litellm.ai/docs/providers for all supported models.
diff --git a/plugins/litellm/py.typed b/plugins/litellm/py.typed
new file mode 100644
index 00000000..e69de29b
diff --git a/plugins/litellm/pyproject.toml b/plugins/litellm/pyproject.toml
new file mode 100644
index 00000000..42beccea
--- /dev/null
+++ b/plugins/litellm/pyproject.toml
@@ -0,0 +1,27 @@
+[build-system]
+requires = ["hatchling", "hatch-vcs"]
+build-backend = "hatchling.build"
+
+[project]
+name = "vision-agents-plugins-litellm"
+dynamic = ["version"]
+description = "LiteLLM plugin for Vision Agents - access 100+ LLM providers"
+readme = "README.md"
+requires-python = ">=3.10"
+license = "MIT"
+dependencies = [
+    "vision-agents",
+    "litellm>=1.60.0,<2.0.0",
+]
+
+[project.urls]
+Documentation = "https://visionagents.ai/"
+Website = "https://visionagents.ai/"
+Source = "https://github.com/GetStream/Vision-Agents"
+
+[tool.hatch.version]
+source = "vcs"
+raw-options = { root = "..", search_parent_directories = true, fallback_version = "0.0.0" }
+
+[tool.hatch.build.targets.wheel]
+packages = ["vision_agents"]
diff --git a/plugins/litellm/tests/test_litellm_llm.py b/plugins/litellm/tests/test_litellm_llm.py
new file mode 100644
index 00000000..469050e7
--- /dev/null
+++ b/plugins/litellm/tests/test_litellm_llm.py
@@ -0,0 +1,99 @@
+"""Tests for LiteLLM plugin."""
+
+import ast
+from pathlib import Path
+
+import pytest
+
+PLUGIN_PATH = (
+    Path(__file__).resolve().parents[1]
+    / "vision_agents"
+    / "plugins"
+    / "litellm"
+    / "litellm_llm.py"
+)
+
+
+class TestLiteLLMPluginStructure:
+    def _parse(self):
+        return ast.parse(PLUGIN_PATH.read_text())
+
+    def test_file_exists(self):
+        assert PLUGIN_PATH.exists()
+
+    def test_has_litellm_chat_completions_class(self):
+        tree = self._parse()
+        classes = [n.name for n in ast.walk(tree) if isinstance(n, ast.ClassDef)]
+        assert "LiteLLMChatCompletions" in classes
+
+    def test_inherits_llm(self):
+        tree = self._parse()
+        for node in ast.walk(tree):
+            if isinstance(node, ast.ClassDef) and node.name == "LiteLLMChatCompletions":
+                base_names = [b.id for b in node.bases if isinstance(b, ast.Name)]
+                assert "LLM" in base_names
+                return
+        pytest.fail("LiteLLMChatCompletions not found")
+
+    def test_has_simple_response(self):
+        tree = self._parse()
+        for node in ast.walk(tree):
+            if isinstance(node, ast.ClassDef) and node.name == "LiteLLMChatCompletions":
+                methods = [
+                    n.name for n in node.body if isinstance(n, ast.AsyncFunctionDef)
+                ]
+                assert "simple_response" in methods
+                assert "create_response" in methods
+                return
+        pytest.fail("LiteLLMChatCompletions not found")
+
+    def test_has_streaming_handler(self):
+        src = PLUGIN_PATH.read_text()
+        assert "_handle_streaming" in src
+        assert "_handle_non_streaming" in src
+
+    def test_uses_drop_params_true(self):
+        src = PLUGIN_PATH.read_text()
+        assert '"drop_params": True' in src
+
+    def test_uses_litellm_acompletion(self):
+        src = PLUGIN_PATH.read_text()
+        assert "litellm.acompletion" in src
+
+    def test_emits_events(self):
+        src = PLUGIN_PATH.read_text()
+        assert "LLMRequestStartedEvent" in src
+        assert "LLMResponseChunkEvent" in src
+        assert "LLMResponseCompletedEvent" in src
+
+    def test_plugin_name(self):
+        src = PLUGIN_PATH.read_text()
+        assert 'PLUGIN_NAME = "litellm"' in src
+
+    def test_converts_tools_to_provider_format(self):
+        src = PLUGIN_PATH.read_text()
+        assert "_convert_tools_to_provider_format" in src
+
+    def test_extracts_tool_calls(self):
+        src = PLUGIN_PATH.read_text()
+        assert "_extract_tool_calls_from_response" in src
+
+
+class TestPluginPackage:
+    def test_pyproject_exists(self):
+        pyproject = Path(__file__).resolve().parents[1] / "pyproject.toml"
+        assert pyproject.exists()
+
+    def test_litellm_in_dependencies(self):
+        pyproject = (Path(__file__).resolve().parents[1] / "pyproject.toml").read_text()
+        assert "litellm" in pyproject
+
+    def test_init_exports_class(self):
+        init = (
+            Path(__file__).resolve().parents[1]
+            / "vision_agents"
+            / "plugins"
+            / "litellm"
+            / "__init__.py"
+        ).read_text()
+        assert "LiteLLMChatCompletions" in init
diff --git a/plugins/litellm/vision_agents/plugins/litellm/__init__.py b/plugins/litellm/vision_agents/plugins/litellm/__init__.py
new file mode 100644
index 00000000..49bae6d1
--- /dev/null
+++ b/plugins/litellm/vision_agents/plugins/litellm/__init__.py
@@ -0,0 +1,14 @@
+"""LiteLLM plugin for Vision Agents.
+
+Routes to 100+ LLM providers (OpenAI, Anthropic, Google, Azure, Bedrock,
+Ollama, etc.) via the litellm SDK. No proxy server needed.
+
+Model strings use the provider/model format, e.g.
+anthropic/claude-sonnet-4-20250514, azure/gpt-4o, openai/gpt-4o.
+
+See https://docs.litellm.ai/docs/providers for all supported models.
+"""
+
+from .litellm_llm import LiteLLMChatCompletions
+
+__all__ = ["LiteLLMChatCompletions"]
diff --git a/plugins/litellm/vision_agents/plugins/litellm/litellm_llm.py b/plugins/litellm/vision_agents/plugins/litellm/litellm_llm.py
new file mode 100644
index 00000000..00a0291d
--- /dev/null
+++ b/plugins/litellm/vision_agents/plugins/litellm/litellm_llm.py
@@ -0,0 +1,274 @@
+"""LiteLLM Chat Completions LLM plugin.
+
+Provides a text-in/text-out LLM backed by the litellm SDK,
+supporting 100+ providers through a single interface.
+"""
+
+import json
+import logging
+import time
+
+import litellm
+from litellm.types.utils import ModelResponse
+from vision_agents.core.llm.events import (
+    LLMRequestStartedEvent,
+    LLMResponseChunkEvent,
+    LLMResponseCompletedEvent,
+)
+from vision_agents.core.llm.llm import LLM, LLMResponseEvent
+from vision_agents.core.llm.llm_types import NormalizedToolCallItem, ToolSchema
+
+logger = logging.getLogger(__name__)
+
+PLUGIN_NAME = "litellm"
+
+
+class LiteLLMChatCompletions(LLM):
+    """LLM plugin that routes to 100+ providers via the litellm SDK.
+
+    Examples:
+
+        from vision_agents.plugins.litellm import LiteLLMChatCompletions
+
+        llm = LiteLLMChatCompletions(model="anthropic/claude-sonnet-4-20250514")
+        llm = LiteLLMChatCompletions(model="azure/gpt-4o", api_key="...")
+    """
+
+    def __init__(
+        self,
+        model: str = "openai/gpt-4o",
+        api_key: str | None = None,
+        tools_max_rounds: int = 3,
+    ):
+        super().__init__()
+        self.model = model
+        self._api_key = api_key
+        self._tools_max_rounds = max(tools_max_rounds, 1)
+        self._pending_tool_calls: dict[int, dict[str, str]] = {}
+
+    async def simple_response(
+        self,
+        text: str,
+        participant: object | None = None,
+    ) -> LLMResponseEvent:
+        """Request an LLM response for the given text.
+
+        Args:
+            text: The text to respond to.
+            participant: Participant info. When None the message is added
+                to the conversation as a plain user message. When provided
+                the message is assumed to already be in the conversation
+                (e.g. added by the STT pipeline).
+        """
+        if self._conversation is None:
+            logger.warning(
+                'Cannot request a response from "%s" '
+                "- conversation not initialized yet.",
+                self.model,
+            )
+            return LLMResponseEvent(original=None, text="")
+
+        if participant is None:
+            await self._conversation.send_message(
+                role="user", user_id="user", content=text
+            )
+
+        messages = await self._build_model_request()
+        return await self.create_response(messages=messages)
+
+    async def create_response(
+        self,
+        messages: list[dict[str, object]] | None = None,
+        stream: bool = True,
+        **kwargs: object,
+    ) -> LLMResponseEvent:
+        if messages is None:
+            messages = await self._build_model_request()
+
+        tools_param = None
+        tools_spec = self.get_available_functions()
+        if tools_spec:
+            tools_param = self._convert_tools_to_provider_format(tools_spec)
+
+        return await self._create_response_internal(
+            messages=messages,
+            tools=tools_param,
+            stream=stream,
+            **kwargs,
+        )
+
+    async def _create_response_internal(
+        self,
+        messages: list[dict[str, object]],
+        tools: list[dict[str, object]] | None = None,
+        stream: bool = True,
+        **kwargs: object,
+    ) -> LLMResponseEvent:
+        model = str(kwargs.get("model", self.model))
+
+        params: dict[str, object] = {
+            "model": model,
+            "messages": messages,
+            "stream": stream,
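+            # drop_params makes litellm drop request parameters the target
+            # provider does not support instead of raising an error.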
+            "drop_params": True,
+        }
+        if self._api_key:
+            params["api_key"] = self._api_key
+        if tools:
+            params["tools"] = tools
+
+        self.events.send(
+            LLMRequestStartedEvent(
+                plugin_name=PLUGIN_NAME,
+                model=model,
+                streaming=stream,
+            )
+        )
+
+        request_start = time.perf_counter()
+
+        try:
+            if stream:
+                return await self._handle_streaming(params, request_start)
+            return await self._handle_non_streaming(params, request_start)
+        except litellm.exceptions.AuthenticationError as exc:
+            logger.exception("LiteLLM auth error for model %s", model)
+            return LLMResponseEvent(original=None, text="", exception=exc)
+        except litellm.exceptions.RateLimitError as exc:
+            logger.exception("LiteLLM rate limit for model %s", model)
+            return LLMResponseEvent(original=None, text="", exception=exc)
+        except litellm.exceptions.APIConnectionError as exc:
+            logger.exception("LiteLLM connection error for model %s", model)
+            return LLMResponseEvent(original=None, text="", exception=exc)
+        except litellm.exceptions.APIError as exc:
+            logger.exception("LiteLLM API error for model %s", model)
+            return LLMResponseEvent(original=None, text="", exception=exc)
+
+    async def _handle_streaming(
+        self, params: dict[str, object], request_start: float
+    ) -> LLMResponseEvent:
+        response = await litellm.acompletion(**params)
+
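+        # With stream=True, acompletion returns an async iterator of chunks;
+        # text content and tool-call fragments arrive as separate deltas.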
+        content_parts: list[str] = []
+        first_token_time: float | None = None
+        self._pending_tool_calls = {}
+
+        async for chunk in response:
+            if not chunk.choices:
+                continue
+            delta = chunk.choices[0].delta
+
+            if delta.content:
+                if first_token_time is None:
+                    first_token_time = time.perf_counter()
+                content_parts.append(delta.content)
+                self.events.send(
+                    LLMResponseChunkEvent(
+                        plugin_name=PLUGIN_NAME,
+                        text=delta.content,
+                    )
+                )
+
+            if delta.tool_calls:
+                for tc in delta.tool_calls:
+                    idx = tc.index if tc.index is not None else 0
+                    entry = self._pending_tool_calls.setdefault(
+                        idx, {"id": "", "name": "", "arguments": ""}
+                    )
+                    if tc.id:
+                        entry["id"] = tc.id
+                    if tc.function and tc.function.name:
+                        entry["name"] = tc.function.name
+                    if tc.function and tc.function.arguments:
+                        entry["arguments"] += tc.function.arguments
+
+        full_text = "".join(content_parts)
+        total_time = time.perf_counter() - request_start
+        ttft = (first_token_time - request_start) if first_token_time else total_time
+
+        self.events.send(
+            LLMResponseCompletedEvent(
+                plugin_name=PLUGIN_NAME,
+                text=full_text,
+                ttft=ttft,
+                duration=total_time,
+            )
+        )
+
+        return LLMResponseEvent(original=None, text=full_text)
+
+    async def _handle_non_streaming(
+        self, params: dict[str, object], request_start: float
+    ) -> LLMResponseEvent:
+        response = await litellm.acompletion(**params)
+
+        if not response.choices:
+            logger.warning("LiteLLM returned empty choices")
+            return LLMResponseEvent(original=response, text="")
+
+        content = response.choices[0].message.content or ""
+        total_time = time.perf_counter() - request_start
+
+        self.events.send(
+            LLMResponseCompletedEvent(
+                plugin_name=PLUGIN_NAME,
+                text=content,
+                ttft=total_time,
+                duration=total_time,
+            )
+        )
+
+        return LLMResponseEvent(original=response, text=content)
+
+    async def _build_model_request(self) -> list[dict[str, object]]:
+        messages: list[dict[str, object]] = []
+        if self._instructions:
+            messages.append({"role": "system", "content": self._instructions})
+        if self._conversation:
+            messages.extend(await self._conversation.get_messages())
+        return messages
+
+    def _convert_tools_to_provider_format(
+        self, tools: list[ToolSchema]
+    ) -> list[dict[str, object]]:
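+        # litellm accepts OpenAI-style function tool schemas for every
+        # provider and handles provider-specific translation internally.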
"plugins/litellm", ] exclude = [ "**/__pycache__",