Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 32 additions & 15 deletions src/praisonai-agents/praisonaiagents/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
import json
import logging
import threading
from praisonaiagents._logging import get_logger
from typing import List, Optional, Dict, Any, Union, Literal, Type
from pydantic import BaseModel, ConfigDict
Expand Down Expand Up @@ -33,6 +34,9 @@ def _rich():
# Global approval callback registry
approval_callback = None

# Thread locks for callback registry protection
_callbacks_lock = threading.RLock()

# ─────────────────────────────────────────────────────────────────────────────
# PraisonAI Unique Color Palette: "Elegant Intelligence"
# Creates a visual narrative flow: Agent → Task → Working → Response
Expand Down Expand Up @@ -128,10 +132,11 @@ def register_display_callback(display_type: str, callback_fn, is_async: bool = F
callback_fn: The callback function to register
is_async (bool): Whether the callback is asynchronous
"""
if is_async:
async_display_callbacks[display_type] = callback_fn
else:
sync_display_callbacks[display_type] = callback_fn
with _callbacks_lock:
if is_async:
async_display_callbacks[display_type] = callback_fn
else:
sync_display_callbacks[display_type] = callback_fn

def register_approval_callback(callback_fn):
"""Register a global approval callback function for dangerous tool operations.
Expand All @@ -140,7 +145,8 @@ def register_approval_callback(callback_fn):
callback_fn: Function that takes (function_name, arguments, risk_level) and returns ApprovalDecision
"""
global approval_callback
approval_callback = callback_fn
with _callbacks_lock:
approval_callback = callback_fn

# Simplified aliases (consistent naming convention)
add_display_callback = register_display_callback
Expand All @@ -155,8 +161,12 @@ def execute_sync_callback(display_type: str, **kwargs):
display_type (str): Type of display event
**kwargs: Arguments to pass to the callback function
"""
if display_type in sync_display_callbacks:
callback = sync_display_callbacks[display_type]
with _callbacks_lock:
if display_type in sync_display_callbacks:
callback = sync_display_callbacks[display_type]

# Execute callback outside the lock to avoid potential deadlocks
if 'callback' in locals():
import inspect
sig = inspect.signature(callback)

Expand All @@ -179,10 +189,18 @@ async def execute_callback(display_type: str, **kwargs):
"""
import inspect

# Get callbacks under lock
sync_callback = None
async_callback = None
with _callbacks_lock:
if display_type in sync_display_callbacks:
sync_callback = sync_display_callbacks[display_type]
if display_type in async_display_callbacks:
async_callback = async_display_callbacks[display_type]

# Execute synchronous callback if registered
if display_type in sync_display_callbacks:
callback = sync_display_callbacks[display_type]
sig = inspect.signature(callback)
if sync_callback:
sig = inspect.signature(sync_callback)

# Filter kwargs to what the callback accepts to maintain backward compatibility
if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()):
Expand All @@ -193,12 +211,11 @@ async def execute_callback(display_type: str, **kwargs):
supported_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters}

loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: callback(**supported_kwargs))
await loop.run_in_executor(None, lambda: sync_callback(**supported_kwargs))

# Execute asynchronous callback if registered
if display_type in async_display_callbacks:
callback = async_display_callbacks[display_type]
sig = inspect.signature(callback)
if async_callback:
sig = inspect.signature(async_callback)

# Filter kwargs to what the callback accepts to maintain backward compatibility
if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()):
Expand All @@ -208,7 +225,7 @@ async def execute_callback(display_type: str, **kwargs):
# Only pass arguments that the callback signature supports
supported_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters}

await callback(**supported_kwargs)
await async_callback(**supported_kwargs)

def _clean_display_content(content: str, max_length: int = 20000) -> str:
"""Helper function to clean and truncate content for display."""
Expand Down
154 changes: 154 additions & 0 deletions src/praisonai-agents/tests/unit/test_main_thread_safety.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""
Additional thread safety tests for main.py callback registries.

Tests the fixes made to global mutable state in main.py.
"""
import threading
import time
import pytest


def test_callback_registry_has_locks():
"""Test that callback registries are protected by locks."""
import praisonaiagents.main as main

# Verify lock exists
assert hasattr(main, '_callbacks_lock')
assert isinstance(main._callbacks_lock, type(threading.RLock()))


def test_concurrent_callback_registration():
"""Test concurrent callback registration is thread-safe."""
import praisonaiagents.main as main

# Clear initial state
main.sync_display_callbacks.clear()
main.async_display_callbacks.clear()

errors = []
num_threads = 5
num_iterations = 10

def callback_worker(thread_id):
try:
for i in range(num_iterations):
# Register callbacks
main.register_display_callback(
f"sync_{thread_id}_{i}",
lambda **kwargs: None,
is_async=False
)
main.register_display_callback(
f"async_{thread_id}_{i}",
lambda **kwargs: None,
is_async=True
)
time.sleep(0.001) # Small delay
except Exception as e:
errors.append(f"Thread {thread_id}: {e}")

# Start threads
threads = []
for i in range(num_threads):
t = threading.Thread(target=callback_worker, args=(i,))
threads.append(t)
t.start()

# Wait for completion
for t in threads:
t.join()

# Check results
assert len(errors) == 0, f"Thread errors: {errors}"
assert len(main.sync_display_callbacks) == num_threads * num_iterations
assert len(main.async_display_callbacks) == num_threads * num_iterations


def test_concurrent_approval_callback_registration():
"""Test concurrent approval callback registration is thread-safe."""
import praisonaiagents.main as main

errors = []
num_threads = 5

def approval_worker(thread_id):
try:
for i in range(10):
main.register_approval_callback(
lambda func, args, risk: f"approval_{thread_id}_{i}"
)
time.sleep(0.001)
except Exception as e:
errors.append(f"Thread {thread_id}: {e}")

# Start threads
threads = []
for i in range(num_threads):
t = threading.Thread(target=approval_worker, args=(i,))
threads.append(t)
t.start()

# Wait for completion
for t in threads:
t.join()

# Check results
assert len(errors) == 0, f"Thread errors: {errors}"
assert main.approval_callback is not None


def test_callback_execution_with_concurrent_registration():
"""Test callback execution while registration happens concurrently."""
import praisonaiagents.main as main

# Set up test callback
call_count = [0]
call_lock = threading.Lock()

def test_callback(**kwargs):
with call_lock:
call_count[0] += 1

main.register_display_callback('test_concurrent', test_callback, is_async=False)

errors = []

def execution_worker():
try:
for _ in range(20):
main.execute_sync_callback('test_concurrent', data='test')
time.sleep(0.001)
except Exception as e:
errors.append(f"Execution error: {e}")

def registration_worker():
try:
for i in range(20):
main.register_display_callback(
f'test_reg_{i}',
lambda **kwargs: None,
is_async=False
)
time.sleep(0.001)
except Exception as e:
errors.append(f"Registration error: {e}")

# Start both execution and registration threads
threads = [
threading.Thread(target=execution_worker),
threading.Thread(target=registration_worker),
]

for t in threads:
t.start()

for t in threads:
t.join()

# Check results
assert len(errors) == 0, f"Thread errors: {errors}"
assert call_count[0] == 20 # All executions should succeed


if __name__ == "__main__":
pytest.main([__file__, "-v"])