Skip to content

Commit cfc221f

Browse files
arch: Fix 3 key architectural gaps for DRY, protocol-driven, multi-agent safety
Gap 1: Enable unified LLM dispatch by default - Changed _use_unified_llm_dispatch default from False to True - Removed dead if False: code block in chat_mixin.py:583-607 - Eliminates triple execution paths for better sync/async feature parity Gap 2: Extend provider adapter protocol for complete dispatch replacement - Added 5 new methods to LLMProviderAdapterProtocol: parse_tool_calls, should_skip_streaming_with_tools, recover_tool_calls_from_text, inject_cache_control, extract_reasoning_tokens - Implemented Ollama-specific tool recovery and Gemini streaming skip logic - Foundation to replace 24+ hardcoded provider checks with adapter calls Gap 3: Replace singletons with instance-based registries - Converted IndexRegistry from __new__ singleton to normal class with default() - Replaced module-level _memory_registry with get_default_memory_registry() - Added ServerRegistry class to replace global _server_* dicts in agent.py - Enables multi-agent safe execution and independent registries per agent Fixes #1362 Co-authored-by: MervinPraison <MervinPraison@users.noreply.github.com>
1 parent be72b83 commit cfc221f

File tree

6 files changed

+154
-64
lines changed

6 files changed

+154
-64
lines changed

src/praisonai-agents/praisonaiagents/agent/agent.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,11 +169,51 @@ def _is_file_path(value: str) -> bool:
169169
# Applied even when context management is disabled to prevent runaway tool outputs
170170
DEFAULT_TOOL_OUTPUT_LIMIT = 16000
171171

172-
# Global variables for API server (protected by _server_lock for thread safety)
173-
_server_lock = threading.Lock()
174-
_server_started = {} # Dict of port -> started boolean
175-
_registered_agents = {} # Dict of port -> Dict of path -> agent_id
176-
_shared_apps = {} # Dict of port -> FastAPI app
172+
class ServerRegistry:
173+
"""Registry for API server state per-port."""
174+
175+
def __init__(self):
176+
self._lock = threading.Lock()
177+
self._server_started = {} # Dict of port -> started boolean
178+
self._registered_agents = {} # Dict of port -> Dict of path -> agent_id
179+
self._shared_apps = {} # Dict of port -> FastAPI app
180+
181+
@staticmethod
182+
def get_default_instance():
183+
"""Get default global registry for backward compatibility."""
184+
if not hasattr(ServerRegistry, '_default_instance'):
185+
ServerRegistry._default_instance = ServerRegistry()
186+
return ServerRegistry._default_instance
187+
188+
def is_server_started(self, port: int) -> bool:
189+
with self._lock:
190+
return self._server_started.get(port, False)
191+
192+
def set_server_started(self, port: int, started: bool) -> None:
193+
with self._lock:
194+
self._server_started[port] = started
195+
196+
def get_shared_app(self, port: int):
197+
with self._lock:
198+
return self._shared_apps.get(port)
199+
200+
def set_shared_app(self, port: int, app) -> None:
201+
with self._lock:
202+
self._shared_apps[port] = app
203+
204+
def register_agent(self, port: int, path: str, agent_id: str) -> None:
205+
with self._lock:
206+
if port not in self._registered_agents:
207+
self._registered_agents[port] = {}
208+
self._registered_agents[port][path] = agent_id
209+
210+
def get_registered_agents(self, port: int) -> dict:
211+
with self._lock:
212+
return self._registered_agents.get(port, {}).copy()
213+
214+
# Backward compatibility - use default instance
215+
def _get_default_server_registry() -> ServerRegistry:
216+
return ServerRegistry.get_default_instance()
177217

178218
# Don't import FastAPI dependencies here - use lazy loading instead
179219

src/praisonai-agents/praisonaiagents/agent/chat_mixin.py

Lines changed: 21 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -536,9 +536,9 @@ def _chat_completion(self, messages, temperature=1.0, tools=None, stream=True, r
536536
formatted_tools = self._format_tools_for_completion(tools)
537537

538538
try:
539-
# NEW: Unified protocol dispatch path (Issue #1304)
540-
# Check if unified dispatch is enabled (opt-in for backward compatibility)
541-
if getattr(self, '_use_unified_llm_dispatch', False):
539+
# NEW: Unified protocol dispatch path (Issue #1304, #1362)
540+
# Enable unified dispatch by default for DRY and feature parity
541+
if getattr(self, '_use_unified_llm_dispatch', True):
542542
# Use composition instead of runtime class mutation for safety
543543
final_response = self._execute_unified_chat_completion(
544544
messages=messages,
@@ -579,50 +579,24 @@ def _chat_completion(self, messages, temperature=1.0, tools=None, stream=True, r
579579
reasoning_steps=reasoning_steps
580580
)
581581
else:
582-
# Non-streaming with custom LLM - don't show streaming-like behavior
583-
if False: # Don't use display_generating when stream=False to avoid streaming-like behavior
584-
# This block is disabled to maintain consistency with the OpenAI path fix
585-
with _get_live()(
586-
_get_display_functions()['display_generating']("", start_time),
587-
console=self.console,
588-
refresh_per_second=4,
589-
) as live:
590-
final_response = self.llm_instance.get_response(
591-
prompt=messages[1:],
592-
system_prompt=messages[0]['content'] if messages and messages[0]['role'] == 'system' else None,
593-
temperature=temperature,
594-
tools=formatted_tools if formatted_tools else None,
595-
verbose=self.verbose,
596-
markdown=self.markdown,
597-
stream=stream,
598-
console=self.console,
599-
execute_tool_fn=self.execute_tool,
600-
agent_name=self.name,
601-
agent_role=self.role,
602-
agent_tools=[getattr(t, '__name__', str(t)) for t in self.tools] if self.tools else None,
603-
task_name=task_name,
604-
task_description=task_description,
605-
task_id=task_id,
606-
reasoning_steps=reasoning_steps
607-
)
608-
else:
609-
final_response = self.llm_instance.get_response(
610-
prompt=messages[1:],
611-
system_prompt=messages[0]['content'] if messages and messages[0]['role'] == 'system' else None,
612-
temperature=temperature,
613-
tools=formatted_tools if formatted_tools else None,
614-
verbose=self.verbose,
615-
markdown=self.markdown,
616-
stream=stream,
617-
console=self.console,
618-
execute_tool_fn=self.execute_tool,
619-
agent_name=self.name,
620-
agent_role=self.role,
621-
agent_tools=[getattr(t, '__name__', str(t)) for t in self.tools] if self.tools else None,
622-
task_name=task_name,
623-
task_description=task_description,
624-
task_id=task_id,
625-
reasoning_steps=reasoning_steps
582+
# Non-streaming with custom LLM - direct execution
583+
final_response = self.llm_instance.get_response(
584+
prompt=messages[1:],
585+
system_prompt=messages[0]['content'] if messages and messages[0]['role'] == 'system' else None,
586+
temperature=temperature,
587+
tools=formatted_tools if formatted_tools else None,
588+
verbose=self.verbose,
589+
markdown=self.markdown,
590+
stream=stream,
591+
console=self.console,
592+
execute_tool_fn=self.execute_tool,
593+
agent_name=self.name,
594+
agent_role=self.role,
595+
agent_tools=[getattr(t, '__name__', str(t)) for t in self.tools] if self.tools else None,
596+
task_name=task_name,
597+
task_description=task_description,
598+
task_id=task_id,
599+
reasoning_steps=reasoning_steps
626600
)
627601
else:
628602
# Use the standard OpenAI client approach with tool support

src/praisonai-agents/praisonaiagents/knowledge/index.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,16 @@ def stats(self) -> IndexStats:
101101
class IndexRegistry:
    """Registry mapping index-type names to factory callables.

    Every instance owns an independent mapping, so separate agents (or
    tests) can hold separate registries; ``default()`` exposes one shared
    process-wide instance for convenience.
    """

    def __init__(self):
        """Create an independent registry with no index types registered."""
        self._indices: Dict[str, Callable[..., IndexProtocol]] = {}

    @classmethod
    def default(cls) -> "IndexRegistry":
        """Return the shared default registry, creating it on first access.

        NOTE(review): creation is not lock-guarded; two threads racing the
        very first call could briefly construct two instances — confirm
        this is acceptable for current callers.
        """
        try:
            return cls._default_instance
        except AttributeError:
            cls._default_instance = cls()
            return cls._default_instance
111114

112115
def register(self, name: str, factory: Callable[..., IndexProtocol]) -> None:
113116
"""Register an index factory."""

src/praisonai-agents/praisonaiagents/llm/adapters/__init__.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,25 @@ def handle_empty_response_with_tools(self, state: Dict[str, Any]) -> bool:
5757

5858
def get_default_settings(self) -> Dict[str, Any]:
    """Return provider-specific default settings; the base adapter has none."""
    return dict()
60+
61+
def parse_tool_calls(self, raw_response: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
    """Extract tool calls from an OpenAI-style response, or None if absent."""
    try:
        first_choice = raw_response["choices"][0]
    except (KeyError, IndexError):
        return None
    return first_choice.get("message", {}).get("tool_calls")

def should_skip_streaming_with_tools(self) -> bool:
    """Streaming stays enabled alongside tools for most providers."""
    return False

def recover_tool_calls_from_text(self, response_text: str, tools: List[Dict[str, Any]]) -> Optional[List[Dict[str, Any]]]:
    """The base adapter performs no text-based tool-call recovery."""
    return None

def inject_cache_control(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Pass messages through untouched; no provider cache control by default."""
    return messages

def extract_reasoning_tokens(self, response: Dict[str, Any]) -> int:
    """The base adapter reports zero reasoning tokens."""
    return 0
6079

6180

6281
class OllamaAdapter(DefaultAdapter):
@@ -99,6 +118,29 @@ def handle_empty_response_with_tools(self, state: Dict[str, Any]) -> bool:
99118
return True # Signal that special handling is needed
100119
return False
101120

121+
def recover_tool_calls_from_text(self, response_text: str, tools: List[Dict[str, Any]]) -> Optional[List[Dict[str, Any]]]:
122+
"""Ollama-specific tool call recovery from response text."""
123+
if not response_text or not tools:
124+
return None
125+
126+
try:
127+
import json
128+
response_json = json.loads(response_text.strip())
129+
if isinstance(response_json, dict) and "name" in response_json:
130+
# Convert Ollama format to standard tool_calls format
131+
return [{
132+
"id": f"call_{response_json['name']}_{hash(response_text) % 10000}",
133+
"type": "function",
134+
"function": {
135+
"name": response_json["name"],
136+
"arguments": json.dumps(response_json.get("arguments", {}))
137+
}
138+
}]
139+
except (json.JSONDecodeError, TypeError, KeyError):
140+
pass
141+
142+
return None
143+
102144
def post_tool_iteration(self, state: Dict[str, Any]) -> None:
103145
# Replaces: Ollama-specific post-tool summary branches
104146
if (not state.get('response_text', '').strip() and
@@ -142,6 +184,13 @@ class GeminiAdapter(DefaultAdapter):
142184
- Supports structured output
143185
"""
144186

187+
def should_skip_streaming_with_tools(self) -> bool:
    """Gemini: disable streaming whenever tools are attached to the request."""
    return True

def supports_structured_output(self) -> bool:
    """Gemini supports structured (schema-constrained) output."""
    return True
193+
145194
def format_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
146195
# Replaces: gemini_internal_tools handling in llm.py
147196
# Internal tool names match GEMINI_INTERNAL_TOOLS: {'googleSearch', 'urlContext', 'codeExecution'}

src/praisonai-agents/praisonaiagents/llm/protocols.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,26 @@ def handle_empty_response_with_tools(self, state: Dict[str, Any]) -> bool:
326326
def get_default_settings(self) -> Dict[str, Any]:
327327
"""Get provider-specific default settings."""
328328
...
329+
330+
def parse_tool_calls(self, raw_response: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
    """Parse tool calls out of the provider's raw response payload."""
    ...

def should_skip_streaming_with_tools(self) -> bool:
    """Report whether streaming must be disabled when tools are attached."""
    ...

def recover_tool_calls_from_text(self, response_text: str, tools: List[Dict[str, Any]]) -> Optional[List[Dict[str, Any]]]:
    """Best-effort recovery of tool calls from free-form response text, for providers that do not format them properly."""
    ...

def inject_cache_control(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Apply provider-specific prompt-caching markers to the messages."""
    ...

def extract_reasoning_tokens(self, response: Dict[str, Any]) -> int:
    """Count the reasoning tokens reported in the provider's response."""
    ...
329349

330350

331351
@runtime_checkable

src/praisonai-agents/praisonaiagents/memory/adapters/registry.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,28 +33,32 @@ def __init__(self):
3333
super().__init__(adapter_type_name="Memory")
3434

3535

36-
# Global registry instance
37-
_memory_registry = MemoryAdapterRegistry()
36+
# Default registry instance for backward compatibility
def get_default_memory_registry() -> MemoryAdapterRegistry:
    """Return the lazily-created, process-wide default memory adapter registry."""
    try:
        return get_default_memory_registry._default_instance
    except AttributeError:
        get_default_memory_registry._default_instance = MemoryAdapterRegistry()
        return get_default_memory_registry._default_instance
3842

3943

4044
def register_memory_adapter(name: str, adapter_class: Type[MemoryProtocol]) -> None:
    """Register *adapter_class* under *name* in the default registry."""
    registry = get_default_memory_registry()
    registry.register_adapter(name, adapter_class)
4347

4448

4549
def register_memory_factory(name: str, factory_func: Callable[..., MemoryProtocol]) -> None:
    """Register *factory_func* as a memory-adapter factory in the default registry."""
    registry = get_default_memory_registry()
    registry.register_factory(name, factory_func)
4852

4953

5054
def get_memory_adapter(name: str, **kwargs) -> Optional[MemoryProtocol]:
    """Build and return the memory adapter registered under *name*, if any."""
    registry = get_default_memory_registry()
    return registry.get_adapter(name, **kwargs)
5357

5458

5559
def list_memory_adapters() -> List[str]:
    """Return the names of every adapter known to the default registry."""
    registry = get_default_memory_registry()
    return registry.list_adapters()
5862

5963

6064
# Canonical aliases per AGENTS.md naming conventions

0 commit comments

Comments
 (0)