Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 71 additions & 22 deletions src/praisonai-agents/praisonaiagents/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
import json
import logging
import threading
from praisonaiagents._logging import get_logger
from typing import List, Optional, Dict, Any, Union, Literal, Type
from pydantic import BaseModel, ConfigDict
Expand All @@ -23,15 +24,64 @@ def _rich():

# Logging is already configured in _logging.py via __init__.py

# Global list to store error logs
error_logs = []

# Separate registries for sync and async callbacks
sync_display_callbacks = {}
async_display_callbacks = {}
# Thread-safe global state protection
_main_globals_lock = threading.Lock()

# Global list to store error logs (protected by _main_globals_lock)
_error_logs = []

# Separate registries for sync and async callbacks (protected by _main_globals_lock)
_sync_display_callbacks = {}
_async_display_callbacks = {}

# Global approval callback registry (protected by _main_globals_lock)
_approval_callback = None

# Thread-safe accessor functions
def _get_error_logs():
"""Thread-safe access to error logs."""
with _main_globals_lock:
return _error_logs.copy()

def _add_error_log(error):
"""Thread-safe addition to error logs."""
with _main_globals_lock:
_error_logs.append(error)

def _get_sync_display_callbacks():
"""Thread-safe access to sync display callbacks."""
with _main_globals_lock:
return _sync_display_callbacks.copy()

def _get_async_display_callbacks():
"""Thread-safe access to async display callbacks."""
with _main_globals_lock:
return _async_display_callbacks.copy()

def _get_approval_callback():
"""Thread-safe access to approval callback."""
with _main_globals_lock:
return _approval_callback

def _set_approval_callback(callback):
"""Thread-safe setting of approval callback."""
with _main_globals_lock:
global _approval_callback
_approval_callback = callback

def _register_display_callback(display_type: str, callback_fn, is_async: bool = False):
"""Thread-safe registration of display callback."""
with _main_globals_lock:
if is_async:
_async_display_callbacks[display_type] = callback_fn
else:
_sync_display_callbacks[display_type] = callback_fn

# Global approval callback registry
approval_callback = None
# Backward compatibility - expose as module-level references
error_logs = _error_logs # Direct reference for read access, use helper functions for writes
sync_display_callbacks = _sync_display_callbacks # Direct reference for read access, use helper functions for writes
async_display_callbacks = _async_display_callbacks # Direct reference for read access, use helper functions for writes
approval_callback = _approval_callback # Direct reference for read access, use helper functions for writes

# ─────────────────────────────────────────────────────────────────────────────
# PraisonAI Unique Color Palette: "Elegant Intelligence"
Expand Down Expand Up @@ -128,19 +178,15 @@ def register_display_callback(display_type: str, callback_fn, is_async: bool = F
callback_fn: The callback function to register
is_async (bool): Whether the callback is asynchronous
"""
if is_async:
async_display_callbacks[display_type] = callback_fn
else:
sync_display_callbacks[display_type] = callback_fn
_register_display_callback(display_type, callback_fn, is_async)

def register_approval_callback(callback_fn):
"""Register a global approval callback function for dangerous tool operations.

Args:
callback_fn: Function that takes (function_name, arguments, risk_level) and returns ApprovalDecision
"""
global approval_callback
approval_callback = callback_fn
_set_approval_callback(callback_fn)

# Simplified aliases (consistent naming convention)
add_display_callback = register_display_callback
Expand All @@ -155,8 +201,9 @@ def execute_sync_callback(display_type: str, **kwargs):
display_type (str): Type of display event
**kwargs: Arguments to pass to the callback function
"""
if display_type in sync_display_callbacks:
callback = sync_display_callbacks[display_type]
sync_callbacks = _get_sync_display_callbacks()
if display_type in sync_callbacks:
callback = sync_callbacks[display_type]
import inspect
sig = inspect.signature(callback)

Expand All @@ -180,8 +227,9 @@ async def execute_callback(display_type: str, **kwargs):
import inspect

# Execute synchronous callback if registered
if display_type in sync_display_callbacks:
callback = sync_display_callbacks[display_type]
sync_callbacks = _get_sync_display_callbacks()
if display_type in sync_callbacks:
callback = sync_callbacks[display_type]
sig = inspect.signature(callback)

# Filter kwargs to what the callback accepts to maintain backward compatibility
Expand All @@ -196,8 +244,9 @@ async def execute_callback(display_type: str, **kwargs):
await loop.run_in_executor(None, lambda: callback(**supported_kwargs))

# Execute asynchronous callback if registered
if display_type in async_display_callbacks:
callback = async_display_callbacks[display_type]
async_callbacks = _get_async_display_callbacks()
if display_type in async_callbacks:
callback = async_callbacks[display_type]
sig = inspect.signature(callback)

# Filter kwargs to what the callback accepts to maintain backward compatibility
Expand Down Expand Up @@ -418,7 +467,7 @@ def display_error(message: str, console=None):
title="⚠ Error",
border_style=PRAISON_COLORS["error"]
))
error_logs.append(message)
_add_error_log(message)

def display_generating(content: str = "", start_time: Optional[float] = None):
if not content or not str(content).strip():
Expand Down Expand Up @@ -615,7 +664,7 @@ async def adisplay_error(message: str, console=None):
await execute_callback('error', message=message)

console.print(Panel.fit(Text(message, style="bold red"), title="Error", border_style="red"))
error_logs.append(message)
_add_error_log(message)

async def adisplay_generating(content: str = "", start_time: Optional[float] = None):
"""Async version of display_generating."""
Expand Down
69 changes: 69 additions & 0 deletions src/praisonai-agents/test_thread_safety.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#!/usr/bin/env python
"""
Test thread safety of main.py global state fixes.
"""
import threading
import time
import concurrent.futures

def test_thread_safe_callback_registration():
"""Test that display callback registration is thread-safe."""
from praisonaiagents.main import register_display_callback, _get_sync_display_callbacks, _get_async_display_callbacks

def register_callbacks_concurrently(worker_id):
"""Register both sync and async callbacks from multiple threads."""
register_display_callback(f'sync_test_{worker_id}', lambda: f'sync_{worker_id}', is_async=False)
time.sleep(0.001)
register_display_callback(f'async_test_{worker_id}', lambda: f'async_{worker_id}', is_async=True)
return worker_id

# Run 20 concurrent workers
with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
futures = [executor.submit(register_callbacks_concurrently, i) for i in range(20)]
concurrent.futures.wait(futures)

# Verify all callbacks were registered successfully
sync_callbacks = _get_sync_display_callbacks()
async_callbacks = _get_async_display_callbacks()

sync_count = len([k for k in sync_callbacks.keys() if k.startswith('sync_test_')])
async_count = len([k for k in async_callbacks.keys() if k.startswith('async_test_')])

assert sync_count == 20, f"Expected 20 sync callbacks, got {sync_count}"
assert async_count == 20, f"Expected 20 async callbacks, got {async_count}"
print(f"✅ Thread-safe callback registration: {sync_count} sync + {async_count} async callbacks")


def test_thread_safe_error_logging():
"""Test that error logging is thread-safe."""
from praisonaiagents.main import _add_error_log, _get_error_logs

def log_errors_concurrently(worker_id):
"""Log multiple errors from different threads."""
for i in range(10):
_add_error_log(f"Thread-{worker_id}: Error message #{i}")
time.sleep(0.0001)
return worker_id

# Run 15 concurrent workers
with concurrent.futures.ThreadPoolExecutor(max_workers=15) as executor:
futures = [executor.submit(log_errors_concurrently, i) for i in range(15)]
concurrent.futures.wait(futures)

# Verify all errors were logged
all_errors = _get_error_logs()
thread_errors = [e for e in all_errors if e.startswith('Thread-')]

assert len(thread_errors) == 150, f"Expected 150 error logs, got {len(thread_errors)}"
print(f"✅ Thread-safe error logging: {len(thread_errors)} errors from 15 threads")


if __name__ == '__main__':
print("Testing thread safety fixes for main.py global state...")
print("=" * 60)

test_thread_safe_callback_registration()
test_thread_safe_error_logging()

print("=" * 60)
print("🎉 All thread safety tests passed!")