1+ """
2+ LLM Provider Protocols for PraisonAI Agents.
3+
4+ Defines minimal protocol interfaces for LLM providers to enable
5+ extensibility without vendor lock-in to litellm or any specific provider.
6+
7+ No heavy imports - only stdlib and typing.
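
Because the protocols are decorated with runtime_checkable, implementations
can be verified with isinstance() without inheriting from anything. A minimal,
hypothetical usage sketch (my_provider stands in for any concrete backend):

    if isinstance(my_provider, LLMProviderProtocol):
        response = my_provider.chat([{"role": "user", "content": "Hello"}])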
8+ """
9+
10+ from typing import Any , AsyncIterator , Dict , Iterator , List , Optional , Protocol , Union , runtime_checkable
11+
12+
13+ @runtime_checkable
14+ class LLMProviderProtocol (Protocol ):
15+ """
16+ Protocol defining the interface that LLM providers must implement.
17+
18+ This enables switching between different LLM backends (litellm, openai,
19+ anthropic, local models, etc.) without modifying core agent code.
20+
21+ Implementations must provide both sync and async variants for
22+ production flexibility.
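
    Example (a minimal, hypothetical sketch; EchoProvider is illustrative
    only, not a real backend, and simply echoes the last user message):

        class EchoProvider:
            model = "echo"

            def chat(self, messages, *, tools=None, temperature=0.0,
                     max_tokens=None, stream=False, **kwargs):
                return {"role": "assistant", "content": messages[-1]["content"]}

            async def achat(self, messages, *, tools=None, temperature=0.0,
                            max_tokens=None, stream=False, **kwargs):
                return self.chat(messages, tools=tools, temperature=temperature,
                                 max_tokens=max_tokens, stream=stream, **kwargs)

            def get_token_count(self, text):
                return len(text.split())

            def get_context_length(self):
                return 4096

        assert isinstance(EchoProvider(), LLMProviderProtocol)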
23+ """
24+
25+ model : str
26+
27+ def chat (
28+ self ,
29+ messages : List [Dict [str , Any ]],
30+ * ,
31+ tools : Optional [List [Dict [str , Any ]]] = None ,
32+ temperature : float = 0.0 ,
33+ max_tokens : Optional [int ] = None ,
34+ stream : bool = False ,
35+ ** kwargs : Any ,
36+ ) -> Union [Dict [str , Any ], Iterator [Dict [str , Any ]]]:
37+ """
38+ Generate chat completion.
39+
40+ Args:
41+ messages: List of message dicts with 'role' and 'content'
42+ tools: Optional list of tool schemas for function calling
43+ temperature: Sampling temperature (0.0-2.0)
44+ max_tokens: Maximum tokens to generate
45+ stream: Whether to return streaming iterator
46+ **kwargs: Provider-specific options
47+
48+ Returns:
49+ Single response dict or streaming iterator of response chunks
50+ """
51+ ...
52+
53+ async def achat (
54+ self ,
55+ messages : List [Dict [str , Any ]],
56+ * ,
57+ tools : Optional [List [Dict [str , Any ]]] = None ,
58+ temperature : float = 0.0 ,
59+ max_tokens : Optional [int ] = None ,
60+ stream : bool = False ,
61+ ** kwargs : Any ,
62+ ) -> Union [Dict [str , Any ], AsyncIterator [Dict [str , Any ]]]:
63+ """
64+ Async version of chat.
65+
66+ Args:
67+ messages: List of message dicts with 'role' and 'content'
68+ tools: Optional list of tool schemas for function calling
69+ temperature: Sampling temperature (0.0-2.0)
70+ max_tokens: Maximum tokens to generate
71+ stream: Whether to return async streaming iterator
72+ **kwargs: Provider-specific options
73+
74+ Returns:
75+ Single response dict or async streaming iterator of response chunks
76+ """
77+ ...
78+
79+ def get_token_count (self , text : str ) -> int :
80+ """
81+ Count tokens in text for the current model.
82+
83+ Args:
84+ text: Text to count tokens for
85+
86+ Returns:
87+ Number of tokens
88+ """
89+ ...
90+
91+ def get_context_length (self ) -> int :
92+ """
93+ Get maximum context length for the current model.
94+
95+ Returns:
96+ Maximum context length in tokens
97+ """
98+ ...
99+
100+
101+ @runtime_checkable
102+ class ModelCapabilitiesProtocol (Protocol ):
103+ """
104+ Protocol for querying model capabilities.
105+
106+ Enables runtime capability detection for different models
107+ without hardcoded capability tables.
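
    Example (a hypothetical sketch; the prefix checks and the 128000-token
    figure are illustrative placeholders, not authoritative capability data):

        class PrefixCapabilities:
            def supports_streaming(self, model): return True
            def supports_function_calling(self, model): return not model.startswith("local/")
            def supports_structured_output(self, model): return not model.startswith("local/")
            def supports_vision(self, model): return "vision" in model or model.startswith("gpt-4o")
            def get_max_tokens(self, model): return 128000 if model.startswith("gpt-4o") else None

        assert isinstance(PrefixCapabilities(), ModelCapabilitiesProtocol)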
108+ """
109+
110+ def supports_streaming (self , model : str ) -> bool :
111+ """Check if model supports streaming responses."""
112+ ...
113+
114+ def supports_function_calling (self , model : str ) -> bool :
115+ """Check if model supports function/tool calling."""
116+ ...
117+
118+ def supports_structured_output (self , model : str ) -> bool :
119+ """Check if model supports structured JSON output."""
120+ ...
121+
122+ def supports_vision (self , model : str ) -> bool :
123+ """Check if model supports vision/image inputs."""
124+ ...
125+
126+ def get_max_tokens (self , model : str ) -> Optional [int ]:
127+ """Get maximum context length for model."""
128+ ...
129+
130+
131+ @runtime_checkable
132+ class LLMRateLimiterProtocol (Protocol ):
133+ """
134+ Protocol for LLM request rate limiting.
135+
136+ Enables different rate limiting strategies without
137+ coupling to specific implementations.
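
    Example (a rough sketch of a fixed 60-second window limiter; the limit is
    an illustrative default, and this variant waits in acquire() rather than
    raising RateLimitError):

        import asyncio
        import time

        class WindowLimiter:
            def __init__(self, tokens_per_minute=90000):
                self.limit = tokens_per_minute
                self.window_start = time.monotonic()
                self.used = 0

            def _refresh(self):
                if time.monotonic() - self.window_start >= 60:
                    self.window_start = time.monotonic()
                    self.used = 0

            def can_proceed(self, tokens=1):
                self._refresh()
                return self.used + tokens <= self.limit

            def get_wait_time(self, tokens=1):
                if self.can_proceed(tokens):
                    return 0.0
                return 60.0 - (time.monotonic() - self.window_start)

            async def acquire(self, tokens=1):
                await asyncio.sleep(self.get_wait_time(tokens))
                self._refresh()
                self.used += tokens

        assert isinstance(WindowLimiter(), LLMRateLimiterProtocol)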
138+ """
139+
140+ async def acquire (self , tokens : int = 1 ) -> None :
141+ """
142+ Acquire permission to make request.
143+
144+ Args:
145+ tokens: Number of tokens in the request
146+
147+ Raises:
148+ RateLimitError: If rate limit would be exceeded
149+ """
150+ ...
151+
152+ def can_proceed (self , tokens : int = 1 ) -> bool :
153+ """
154+ Check if request can proceed without blocking.
155+
156+ Args:
157+ tokens: Number of tokens in the request
158+
159+ Returns:
160+ True if request can proceed immediately
161+ """
162+ ...
163+
164+ def get_wait_time (self , tokens : int = 1 ) -> float :
165+ """
166+ Get seconds to wait before request can proceed.
167+
168+ Args:
169+ tokens: Number of tokens in the request
170+
171+ Returns:
172+ Seconds to wait (0.0 if can proceed immediately)
173+ """
174+ ...
175+
176+
177+ @runtime_checkable
178+ class LLMFailoverProtocol (Protocol ):
179+ """
180+ Protocol for LLM failover strategies.
181+
182+ Enables automatic fallback to alternative models
183+ when primary models fail or are unavailable.
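
    Example (a hypothetical sketch; the model names and backoff constants are
    illustrative only):

        class ChainFailover:
            FALLBACKS = {"gpt-4o": "gpt-4o-mini", "gpt-4o-mini": "gpt-3.5-turbo"}

            def get_fallback_model(self, failed_model, error):
                return self.FALLBACKS.get(failed_model)

            def should_retry(self, error):
                return isinstance(error, (RateLimitError, ModelNotAvailableError))

            def get_retry_delay(self, attempt):
                # exponential backoff, capped at 30 seconds
                return min(2.0 ** (attempt - 1), 30.0)

        assert isinstance(ChainFailover(), LLMFailoverProtocol)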
184+ """
185+
186+ def get_fallback_model (self , failed_model : str , error : Exception ) -> Optional [str ]:
187+ """
188+ Get fallback model for failed request.
189+
190+ Args:
191+ failed_model: Model that failed
192+ error: Exception that occurred
193+
194+ Returns:
195+ Alternative model name or None if no fallback available
196+ """
197+ ...
198+
199+ def should_retry (self , error : Exception ) -> bool :
200+ """
201+ Check if error indicates retryable failure.
202+
203+ Args:
204+ error: Exception that occurred
205+
206+ Returns:
207+ True if should retry with fallback
208+ """
209+ ...
210+
211+ def get_retry_delay (self , attempt : int ) -> float :
212+ """
213+ Get delay before retry attempt.
214+
215+ Args:
216+ attempt: Retry attempt number (1-based)
217+
218+ Returns:
219+ Seconds to wait before retry
220+ """
221+ ...
222+
223+
class LLMProviderError(Exception):
    """Base exception for LLM provider errors."""

    def __init__(self, message: str, provider: Optional[str] = None, model: Optional[str] = None):
        self.provider = provider
        self.model = model
        super().__init__(message)


class RateLimitError(LLMProviderError):
    """Raised when rate limit is exceeded."""

    def __init__(self, message: Optional[str] = None, retry_after: Optional[float] = None, provider: Optional[str] = None, model: Optional[str] = None):
        self.retry_after = retry_after
        super().__init__(message or "Rate limit exceeded", provider=provider, model=model)


class ModelNotAvailableError(LLMProviderError):
    """Raised when requested model is not available."""

    def __init__(self, model: str, provider: Optional[str] = None):
        message = f"Model '{model}' is not available{f' from provider {provider}' if provider else ''}"
        super().__init__(message, provider=provider, model=model)


class ContextLengthExceededError(LLMProviderError):
    """Raised when input exceeds model's context length."""

    def __init__(self, tokens: int, max_tokens: int, provider: Optional[str] = None, model: Optional[str] = None):
        self.tokens = tokens
        self.max_tokens = max_tokens
        super().__init__(f"Input length ({tokens} tokens) exceeds model limit ({max_tokens} tokens)", provider=provider, model=model)