Skip to content

Commit 4fdc51c

Browse files
committed
Add provider registry system with custom provider support
Implements extensible provider registry allowing users to register custom providers (e.g., Cloudflare, Ollama) alongside built-in providers. Updates `providers` CLI command to show registered/builtin/custom providers with `list` and `info` subcommands. Enhances `createProvider()` to accept provider instances, spec objects, or strings. Updates `hasApiKey()` to support provider aliases and custom env vars. Exports registry functions (`register
1 parent ce5026f commit 4fdc51c

11 files changed

Lines changed: 2258 additions & 66 deletions

File tree

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
"""
2+
Custom Provider Registry Example (Python)
3+
4+
This example demonstrates how to register and use custom LLM providers
5+
with the PraisonAI Python wrapper.
6+
7+
Note: The Python provider registry is for custom provider extensions.
8+
Built-in providers (OpenAI, Anthropic, Google) are handled by LiteLLM
9+
in praisonaiagents automatically.
10+
"""
11+
12+
import sys
13+
sys.path.insert(0, '../../src/praisonai')
14+
15+
from praisonai.llm import (
16+
LLMProviderRegistry,
17+
register_llm_provider,
18+
unregister_llm_provider,
19+
has_llm_provider,
20+
list_llm_providers,
21+
create_llm_provider,
22+
get_default_llm_registry,
23+
parse_model_string
24+
)
25+
26+
27+
# Example 1: Simple Custom Provider
28+
# ---------------------------------
29+
30+
class SimpleCustomProvider:
31+
"""A minimal custom provider example."""
32+
33+
provider_id = "simple-custom"
34+
35+
def __init__(self, model_id: str, config: dict = None):
36+
self.model_id = model_id
37+
self.config = config or {}
38+
self.api_endpoint = self.config.get('api_endpoint', 'https://api.example.com')
39+
40+
def generate(self, prompt: str) -> str:
41+
"""Generate a response (simulated)."""
42+
print(f"[SimpleCustomProvider] Generating with model: {self.model_id}")
43+
print(f"[SimpleCustomProvider] Prompt: {prompt[:50]}...")
44+
return f"Response from {self.provider_id}/{self.model_id}: Hello! This is a simulated response."
45+
46+
47+
# Example 2: Ollama Provider
48+
# --------------------------
49+
50+
class OllamaProvider:
51+
"""Custom provider for local Ollama integration."""
52+
53+
provider_id = "ollama"
54+
55+
def __init__(self, model_id: str, config: dict = None):
56+
self.model_id = model_id
57+
self.config = config or {}
58+
self.base_url = self.config.get('base_url', 'http://localhost:11434')
59+
60+
def generate(self, prompt: str) -> str:
61+
"""Generate a response using Ollama API."""
62+
import requests
63+
64+
response = requests.post(
65+
f"{self.base_url}/api/generate",
66+
json={
67+
"model": self.model_id,
68+
"prompt": prompt,
69+
"stream": False
70+
}
71+
)
72+
73+
if response.status_code != 200:
74+
raise Exception(f"Ollama API error: {response.text}")
75+
76+
return response.json().get("response", "")
77+
78+
def generate_stream(self, prompt: str):
79+
"""Generate a streaming response using Ollama API."""
80+
import requests
81+
82+
response = requests.post(
83+
f"{self.base_url}/api/generate",
84+
json={
85+
"model": self.model_id,
86+
"prompt": prompt,
87+
"stream": True
88+
},
89+
stream=True
90+
)
91+
92+
for line in response.iter_lines():
93+
if line:
94+
import json
95+
data = json.loads(line)
96+
yield data.get("response", "")
97+
if data.get("done"):
98+
break
99+
100+
101+
# Example 3: Cloudflare Workers AI Provider
102+
# -----------------------------------------
103+
104+
class CloudflareProvider:
105+
"""Custom provider for Cloudflare Workers AI."""
106+
107+
provider_id = "cloudflare"
108+
109+
def __init__(self, model_id: str, config: dict = None):
110+
self.model_id = model_id
111+
self.config = config or {}
112+
self.account_id = self.config.get('account_id')
113+
self.api_token = self.config.get('api_token')
114+
115+
def generate(self, prompt: str) -> str:
116+
"""Generate a response using Cloudflare Workers AI."""
117+
import requests
118+
119+
if not self.account_id or not self.api_token:
120+
raise ValueError("Cloudflare account_id and api_token are required")
121+
122+
response = requests.post(
123+
f"https://api.cloudflare.com/client/v4/accounts/{self.account_id}/ai/run/{self.model_id}",
124+
headers={
125+
"Authorization": f"Bearer {self.api_token}",
126+
"Content-Type": "application/json"
127+
},
128+
json={"prompt": prompt}
129+
)
130+
131+
if response.status_code != 200:
132+
raise Exception(f"Cloudflare API error: {response.text}")
133+
134+
return response.json().get("result", {}).get("response", "")
135+
136+
137+
def main():
138+
print("=== Provider Registry Example (Python) ===\n")
139+
140+
# Check initial state
141+
print("Initial providers:", list_llm_providers())
142+
print()
143+
144+
# Register custom providers
145+
print("Registering custom providers...")
146+
register_llm_provider("simple-custom", SimpleCustomProvider)
147+
register_llm_provider("ollama", OllamaProvider, aliases=["local"])
148+
register_llm_provider("cloudflare", CloudflareProvider, aliases=["cf", "workers-ai"])
149+
print()
150+
151+
# Check providers after registration
152+
print("Providers after registration:", list_llm_providers())
153+
print("Has ollama:", has_llm_provider("ollama"))
154+
print("Has local (alias):", has_llm_provider("local"))
155+
print("Has cloudflare:", has_llm_provider("cloudflare"))
156+
print("Has cf (alias):", has_llm_provider("cf"))
157+
print()
158+
159+
# Parse model strings
160+
print("=== Model String Parsing ===\n")
161+
162+
test_strings = [
163+
"openai/gpt-4o-mini",
164+
"gpt-4o-mini",
165+
"claude-3-5-sonnet-latest",
166+
"gemini-2.0-flash",
167+
"ollama/llama2",
168+
"cloudflare/workers-ai-model"
169+
]
170+
171+
for model_str in test_strings:
172+
parsed = parse_model_string(model_str)
173+
print(f" '{model_str}' -> provider={parsed['provider_id']}, model={parsed['model_id']}")
174+
print()
175+
176+
# Create and use providers
177+
print("=== Using Custom Providers ===\n")
178+
179+
# Use simple custom provider
180+
provider = create_llm_provider("simple-custom/test-model")
181+
print(f"Created provider: {provider.provider_id}/{provider.model_id}")
182+
response = provider.generate("Hello, world!")
183+
print(f"Response: {response}")
184+
print()
185+
186+
# Use ollama provider via alias
187+
provider = create_llm_provider("local/llama2", config={"base_url": "http://localhost:11434"})
188+
print(f"Created provider: {provider.provider_id}/{provider.model_id}")
189+
print()
190+
191+
# Demonstrate error handling
192+
print("=== Error Handling ===\n")
193+
try:
194+
create_llm_provider("unknown-provider/model")
195+
except ValueError as e:
196+
print(f"Expected error: {e}")
197+
print()
198+
199+
# Demonstrate isolated registries
200+
print("=== Isolated Registries ===\n")
201+
202+
# Create isolated registry
203+
isolated_registry = LLMProviderRegistry()
204+
isolated_registry.register("isolated-provider", SimpleCustomProvider)
205+
206+
print(f"Default registry providers: {list_llm_providers()}")
207+
print(f"Isolated registry providers: {isolated_registry.list()}")
208+
209+
# Use isolated registry
210+
provider = create_llm_provider("isolated-provider/model", registry=isolated_registry)
211+
print(f"Created from isolated registry: {provider.provider_id}/{provider.model_id}")
212+
print()
213+
214+
# Cleanup
215+
print("=== Cleanup ===\n")
216+
unregister_llm_provider("simple-custom")
217+
unregister_llm_provider("ollama")
218+
unregister_llm_provider("cloudflare")
219+
print(f"Providers after cleanup: {list_llm_providers()}")
220+
221+
print("\n=== Example Complete ===")
222+
223+
224+
if __name__ == "__main__":
225+
main()

0 commit comments

Comments
 (0)