Commit d4a1b6e

Merge pull request #1286 from MervinPraison/claude/issue-1280-20260407-2043
feat: add missing LLM and retrieval provider protocols
2 parents 01a20e3 + 5990367 commit d4a1b6e

3 files changed: 359 additions & 1 deletion

src/praisonai-agents/praisonaiagents/llm/__init__.py

Lines changed: 43 additions & 1 deletion
@@ -98,6 +98,38 @@ def __getattr__(name):
         from .llm import TokenUsage
         _lazy_cache[name] = TokenUsage
         return TokenUsage
+    elif name == "LLMProviderProtocol":
+        from .protocols import LLMProviderProtocol
+        _lazy_cache[name] = LLMProviderProtocol
+        return LLMProviderProtocol
+    elif name == "ModelCapabilitiesProtocol":
+        from .protocols import ModelCapabilitiesProtocol
+        _lazy_cache[name] = ModelCapabilitiesProtocol
+        return ModelCapabilitiesProtocol
+    elif name == "LLMRateLimiterProtocol":
+        from .protocols import LLMRateLimiterProtocol
+        _lazy_cache[name] = LLMRateLimiterProtocol
+        return LLMRateLimiterProtocol
+    elif name == "LLMFailoverProtocol":
+        from .protocols import LLMFailoverProtocol
+        _lazy_cache[name] = LLMFailoverProtocol
+        return LLMFailoverProtocol
+    elif name == "LLMProviderError":
+        from .protocols import LLMProviderError
+        _lazy_cache[name] = LLMProviderError
+        return LLMProviderError
+    elif name == "RateLimitError":
+        from .protocols import RateLimitError
+        _lazy_cache[name] = RateLimitError
+        return RateLimitError
+    elif name == "ModelNotAvailableError":
+        from .protocols import ModelNotAvailableError
+        _lazy_cache[name] = ModelNotAvailableError
+        return ModelNotAvailableError
+    elif name == "ContextLengthExceededError":
+        from .protocols import ContextLengthExceededError
+        _lazy_cache[name] = ContextLengthExceededError
+        return ContextLengthExceededError

     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

@@ -122,5 +154,15 @@ def __getattr__(name):
     "TaskComplexity",
     "create_routing_agent",
     "RateLimiter",
-    "TokenUsage"
+    "TokenUsage",
+    # Protocols
+    "LLMProviderProtocol",
+    "ModelCapabilitiesProtocol",
+    "LLMRateLimiterProtocol",
+    "LLMFailoverProtocol",
+    # Exceptions
+    "LLMProviderError",
+    "RateLimitError",
+    "ModelNotAvailableError",
+    "ContextLengthExceededError"
 ]
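With these `__init__.py` changes, downstream code can import the new names straight from `praisonaiagents.llm`; the `__getattr__` branches defer the actual `from .protocols import ...` until first access and cache the result in `_lazy_cache`. A minimal consumer-side sketch, assuming the package as changed in this PR is installed (the `generate` helper is hypothetical, not part of the commit):

from typing import Any, Dict

from praisonaiagents.llm import LLMProviderProtocol

def generate(provider: LLMProviderProtocol, prompt: str) -> Dict[str, Any]:
    """Hypothetical helper typed against the protocol, not a concrete backend."""
    # Any object that structurally matches the protocol works here:
    # litellm-backed, openai-backed, or a local stub.
    return provider.chat([{"role": "user", "content": prompt}])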
src/praisonai-agents/praisonaiagents/llm/protocols.py

Lines changed: 255 additions & 0 deletions
@@ -0,0 +1,255 @@
+"""
+LLM Provider Protocols for PraisonAI Agents.
+
+Defines minimal protocol interfaces for LLM providers to enable
+extensibility without vendor lock-in to litellm or any specific provider.
+
+No heavy imports - only stdlib and typing.
+"""
+
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Protocol, Union, runtime_checkable
+
+
+@runtime_checkable
+class LLMProviderProtocol(Protocol):
+    """
+    Protocol defining the interface that LLM providers must implement.
+
+    This enables switching between different LLM backends (litellm, openai,
+    anthropic, local models, etc.) without modifying core agent code.
+
+    Implementations must provide both sync and async variants for
+    production flexibility.
+    """
+
+    model: str
+
+    def chat(
+        self,
+        messages: List[Dict[str, Any]],
+        *,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        temperature: float = 0.0,
+        max_tokens: Optional[int] = None,
+        stream: bool = False,
+        **kwargs: Any,
+    ) -> Union[Dict[str, Any], Iterator[Dict[str, Any]]]:
+        """
+        Generate chat completion.
+
+        Args:
+            messages: List of message dicts with 'role' and 'content'
+            tools: Optional list of tool schemas for function calling
+            temperature: Sampling temperature (0.0-2.0)
+            max_tokens: Maximum tokens to generate
+            stream: Whether to return streaming iterator
+            **kwargs: Provider-specific options
+
+        Returns:
+            Single response dict or streaming iterator of response chunks
+        """
+        ...
+
+    async def achat(
+        self,
+        messages: List[Dict[str, Any]],
+        *,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        temperature: float = 0.0,
+        max_tokens: Optional[int] = None,
+        stream: bool = False,
+        **kwargs: Any,
+    ) -> Union[Dict[str, Any], AsyncIterator[Dict[str, Any]]]:
+        """
+        Async version of chat.
+
+        Args:
+            messages: List of message dicts with 'role' and 'content'
+            tools: Optional list of tool schemas for function calling
+            temperature: Sampling temperature (0.0-2.0)
+            max_tokens: Maximum tokens to generate
+            stream: Whether to return async streaming iterator
+            **kwargs: Provider-specific options
+
+        Returns:
+            Single response dict or async streaming iterator of response chunks
+        """
+        ...
+
+    def get_token_count(self, text: str) -> int:
+        """
+        Count tokens in text for the current model.
+
+        Args:
+            text: Text to count tokens for
+
+        Returns:
+            Number of tokens
+        """
+        ...
+
+    def get_context_length(self) -> int:
+        """
+        Get maximum context length for the current model.
+
+        Returns:
+            Maximum context length in tokens
+        """
+        ...
+
+
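Because the protocol is `runtime_checkable`, a provider only has to supply these members structurally; no inheritance is required. A minimal sketch under that assumption (`EchoProvider` is a made-up stand-in, not shipped with the package):

from typing import Any, Dict, List, Optional

from praisonaiagents.llm import LLMProviderProtocol

class EchoProvider:
    """Hypothetical provider that satisfies LLMProviderProtocol by shape alone."""

    def __init__(self, model: str = "echo-1"):
        self.model = model

    def chat(self, messages: List[Dict[str, Any]], *, tools: Optional[List[Dict[str, Any]]] = None,
             temperature: float = 0.0, max_tokens: Optional[int] = None,
             stream: bool = False, **kwargs: Any) -> Dict[str, Any]:
        # Echo the last user message back as the assistant reply.
        content = messages[-1]["content"] if messages else ""
        return {"role": "assistant", "content": content}

    async def achat(self, messages: List[Dict[str, Any]], *, tools: Optional[List[Dict[str, Any]]] = None,
                    temperature: float = 0.0, max_tokens: Optional[int] = None,
                    stream: bool = False, **kwargs: Any) -> Dict[str, Any]:
        # Reuse the sync path; a real provider would await its client here.
        return self.chat(messages, tools=tools, temperature=temperature,
                         max_tokens=max_tokens, stream=stream, **kwargs)

    def get_token_count(self, text: str) -> int:
        # Crude whitespace count standing in for a real tokenizer.
        return len(text.split())

    def get_context_length(self) -> int:
        return 8192

# isinstance() on a runtime_checkable Protocol checks member presence only,
# not signatures.
assert isinstance(EchoProvider(), LLMProviderProtocol)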
+@runtime_checkable
+class ModelCapabilitiesProtocol(Protocol):
+    """
+    Protocol for querying model capabilities.
+
+    Enables runtime capability detection for different models
+    without hardcoded capability tables.
+    """
+
+    def supports_streaming(self, model: str) -> bool:
+        """Check if model supports streaming responses."""
+        ...
+
+    def supports_function_calling(self, model: str) -> bool:
+        """Check if model supports function/tool calling."""
+        ...
+
+    def supports_structured_output(self, model: str) -> bool:
+        """Check if model supports structured JSON output."""
+        ...
+
+    def supports_vision(self, model: str) -> bool:
+        """Check if model supports vision/image inputs."""
+        ...
+
+    def get_max_tokens(self, model: str) -> Optional[int]:
+        """Get maximum context length for model."""
+        ...
+
+
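One possible implementation is a small lookup table owned by a specific provider integration, so that core agent code never needs its own hardcoded table. A sketch with made-up model names and capability values:

from typing import Optional

class StaticCapabilities:
    """Hypothetical capability source satisfying ModelCapabilitiesProtocol."""

    # Illustrative table; a real implementation might query the backend instead.
    _TABLE = {
        "example-small": {"streaming": True, "tools": False, "json": True, "vision": False, "max_tokens": 8192},
        "example-large": {"streaming": True, "tools": True, "json": True, "vision": True, "max_tokens": 128000},
    }

    def supports_streaming(self, model: str) -> bool:
        return self._TABLE.get(model, {}).get("streaming", False)

    def supports_function_calling(self, model: str) -> bool:
        return self._TABLE.get(model, {}).get("tools", False)

    def supports_structured_output(self, model: str) -> bool:
        return self._TABLE.get(model, {}).get("json", False)

    def supports_vision(self, model: str) -> bool:
        return self._TABLE.get(model, {}).get("vision", False)

    def get_max_tokens(self, model: str) -> Optional[int]:
        return self._TABLE.get(model, {}).get("max_tokens")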
+@runtime_checkable
+class LLMRateLimiterProtocol(Protocol):
+    """
+    Protocol for LLM request rate limiting.
+
+    Enables different rate limiting strategies without
+    coupling to specific implementations.
+    """
+
+    async def acquire(self, tokens: int = 1) -> None:
+        """
+        Acquire permission to make request.
+
+        Args:
+            tokens: Number of tokens in the request
+
+        Raises:
+            RateLimitError: If rate limit would be exceeded
+        """
+        ...
+
+    def can_proceed(self, tokens: int = 1) -> bool:
+        """
+        Check if request can proceed without blocking.
+
+        Args:
+            tokens: Number of tokens in the request
+
+        Returns:
+            True if request can proceed immediately
+        """
+        ...
+
+    def get_wait_time(self, tokens: int = 1) -> float:
+        """
+        Get seconds to wait before request can proceed.
+
+        Args:
+            tokens: Number of tokens in the request
+
+        Returns:
+            Seconds to wait (0.0 if can proceed immediately)
+        """
+        ...
+
+
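A token-bucket limiter is one concrete way to back all three methods. This sketch chooses to wait inside `acquire` rather than raise; an implementation could instead raise `RateLimitError` once a hard budget is exhausted, as the docstring above allows:

import asyncio
import time

class TokenBucketLimiter:
    """Hypothetical limiter satisfying LLMRateLimiterProtocol."""

    def __init__(self, tokens_per_second: float, capacity: int):
        self.rate = tokens_per_second
        self.capacity = capacity
        self._available = float(capacity)
        self._last = time.monotonic()

    def _refill(self) -> None:
        # Top up the bucket based on elapsed time, capped at capacity.
        now = time.monotonic()
        self._available = min(self.capacity, self._available + (now - self._last) * self.rate)
        self._last = now

    def can_proceed(self, tokens: int = 1) -> bool:
        self._refill()
        return self._available >= tokens

    def get_wait_time(self, tokens: int = 1) -> float:
        self._refill()
        missing = tokens - self._available
        return max(0.0, missing / self.rate)

    async def acquire(self, tokens: int = 1) -> None:
        # Sleep until enough budget has accumulated, then spend it.
        wait = self.get_wait_time(tokens)
        if wait > 0:
            await asyncio.sleep(wait)
            self._refill()
        self._available -= tokens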
+@runtime_checkable
+class LLMFailoverProtocol(Protocol):
+    """
+    Protocol for LLM failover strategies.
+
+    Enables automatic fallback to alternative models
+    when primary models fail or are unavailable.
+    """
+
+    def get_fallback_model(self, failed_model: str, error: Exception) -> Optional[str]:
+        """
+        Get fallback model for failed request.
+
+        Args:
+            failed_model: Model that failed
+            error: Exception that occurred
+
+        Returns:
+            Alternative model name or None if no fallback available
+        """
+        ...
+
+    def should_retry(self, error: Exception) -> bool:
+        """
+        Check if error indicates retryable failure.
+
+        Args:
+            error: Exception that occurred
+
+        Returns:
+            True if should retry with fallback
+        """
+        ...
+
+    def get_retry_delay(self, attempt: int) -> float:
+        """
+        Get delay before retry attempt.
+
+        Args:
+            attempt: Retry attempt number (1-based)
+
+        Returns:
+            Seconds to wait before retry
+        """
+        ...
+
+
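A simple strategy that fits this protocol is an ordered fallback chain with exponential backoff. The model names and the choice of which errors count as retryable below are illustrative only (the exception classes are the ones defined later in this file):

from typing import List, Optional

from praisonaiagents.llm import ModelNotAvailableError, RateLimitError

class FallbackChain:
    """Hypothetical failover strategy satisfying LLMFailoverProtocol."""

    def __init__(self, chain: List[str], base_delay: float = 1.0):
        self.chain = chain          # e.g. ["primary-model", "backup-model"]
        self.base_delay = base_delay

    def get_fallback_model(self, failed_model: str, error: Exception) -> Optional[str]:
        # Return the next model in the chain, or None when exhausted.
        try:
            idx = self.chain.index(failed_model)
        except ValueError:
            return self.chain[0] if self.chain else None
        return self.chain[idx + 1] if idx + 1 < len(self.chain) else None

    def should_retry(self, error: Exception) -> bool:
        # Retry on rate limits and unavailable models; give up otherwise.
        return isinstance(error, (RateLimitError, ModelNotAvailableError))

    def get_retry_delay(self, attempt: int) -> float:
        # Exponential backoff: base, 2x base, 4x base, ...
        return self.base_delay * (2 ** (attempt - 1))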
+class LLMProviderError(Exception):
+    """Base exception for LLM provider errors."""
+
+    def __init__(self, message: str, provider: Optional[str] = None, model: Optional[str] = None):
+        self.provider = provider
+        self.model = model
+        super().__init__(message)
+
+
+class RateLimitError(LLMProviderError):
+    """Raised when rate limit is exceeded."""
+
+    def __init__(self, message: Optional[str] = None, retry_after: Optional[float] = None, provider: Optional[str] = None, model: Optional[str] = None):
+        self.retry_after = retry_after
+        super().__init__(message or "Rate limit exceeded", provider=provider, model=model)
+
+
+class ModelNotAvailableError(LLMProviderError):
+    """Raised when requested model is not available."""
+
+    def __init__(self, model: str, provider: Optional[str] = None):
+        message = f"Model '{model}' is not available{f' from provider {provider}' if provider else ''}"
+        super().__init__(message, provider=provider, model=model)
+
+
+class ContextLengthExceededError(LLMProviderError):
+    """Raised when input exceeds model's context length."""
+
+    def __init__(self, tokens: int, max_tokens: int, provider: Optional[str] = None, model: Optional[str] = None):
+        self.tokens = tokens
+        self.max_tokens = max_tokens
+        super().__init__(f"Input length ({tokens} tokens) exceeds model limit ({max_tokens} tokens)", provider=provider, model=model)
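The exception hierarchy gives callers enough structure to separate transient failures from permanent ones, for example honouring `retry_after` on rate limits while treating context overflows as fatal. A hypothetical retry helper, assuming a provider object that implements the chat protocol above (this helper is not part of the commit):

import time

from praisonaiagents.llm import ContextLengthExceededError, RateLimitError

def chat_with_retry(provider, messages, max_attempts: int = 3):
    """Hypothetical helper showing how the exception classes might be consumed."""
    for attempt in range(1, max_attempts + 1):
        try:
            return provider.chat(messages)
        except ContextLengthExceededError:
            # Permanent failure: the prompt itself exceeds the model's window.
            raise
        except RateLimitError as exc:
            if attempt == max_attempts:
                raise
            # Prefer the provider-supplied hint; fall back to a fixed pause.
            time.sleep(exc.retry_after or 1.0)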

src/praisonai-agents/praisonaiagents/rag/protocols.py

Lines changed: 61 additions & 0 deletions
@@ -118,6 +118,67 @@ def stream(self, question: str, **kwargs) -> Iterator[str]:
         ...


+@runtime_checkable
+class RetrievalStrategyProtocol(Protocol):
+    """
+    Protocol for pluggable retrieval strategies.
+
+    Enables custom retrieval algorithms without modifying core RAG pipeline.
+    Implementations can provide different retrieval approaches (semantic,
+    keyword, hybrid, graph-based, etc.) through a unified interface.
+    """
+
+    name: str
+
+    def retrieve(
+        self,
+        query: str,
+        knowledge_store: Any,  # KnowledgeStoreProtocol
+        *,
+        limit: int = 10,
+        filters: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> List[Dict[str, Any]]:
+        """
+        Retrieve relevant documents for query.
+
+        Args:
+            query: Search query string
+            knowledge_store: Knowledge store to search
+            limit: Maximum number of results
+            filters: Optional metadata filters
+            **kwargs: Strategy-specific options
+
+        Returns:
+            List of retrieved documents with metadata
+        """
+        ...
+
+    async def aretrieve(
+        self,
+        query: str,
+        knowledge_store: Any,  # KnowledgeStoreProtocol
+        *,
+        limit: int = 10,
+        filters: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> List[Dict[str, Any]]:
+        """
+        Async version of retrieve.
+
+        Args:
+            query: Search query string
+            knowledge_store: Knowledge store to search
+            limit: Maximum number of results
+            filters: Optional metadata filters
+            **kwargs: Strategy-specific options
+
+        Returns:
+            List of retrieved documents with metadata
+        """
+        ...
+
+
 @runtime_checkable
 class GraphHookProtocol(Protocol):
     """
