From 985d987237558f73d9263230bdbd382d74af8a47 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 31 Mar 2026 09:20:34 +0000 Subject: [PATCH] fix: add thread safety to callback registries in main.py - Add threading.RLock() protection for global callback dictionaries - Protect sync_display_callbacks and async_display_callbacks from race conditions - Protect approval_callback registration - Use lock/execute separation pattern to avoid deadlocks - Add comprehensive thread safety tests for callback registration Co-authored-by: Mervin Praison --- src/praisonai-agents/praisonaiagents/main.py | 47 ++++-- .../tests/unit/test_main_thread_safety.py | 154 ++++++++++++++++++ 2 files changed, 186 insertions(+), 15 deletions(-) create mode 100644 src/praisonai-agents/tests/unit/test_main_thread_safety.py diff --git a/src/praisonai-agents/praisonaiagents/main.py b/src/praisonai-agents/praisonaiagents/main.py index 42003ca7c..c7f52c99e 100644 --- a/src/praisonai-agents/praisonaiagents/main.py +++ b/src/praisonai-agents/praisonaiagents/main.py @@ -2,6 +2,7 @@ import time import json import logging +import threading from praisonaiagents._logging import get_logger from typing import List, Optional, Dict, Any, Union, Literal, Type from pydantic import BaseModel, ConfigDict @@ -33,6 +34,9 @@ def _rich(): # Global approval callback registry approval_callback = None +# Thread locks for callback registry protection +_callbacks_lock = threading.RLock() + # ───────────────────────────────────────────────────────────────────────────── # PraisonAI Unique Color Palette: "Elegant Intelligence" # Creates a visual narrative flow: Agent → Task → Working → Response @@ -128,10 +132,11 @@ def register_display_callback(display_type: str, callback_fn, is_async: bool = F callback_fn: The callback function to register is_async (bool): Whether the callback is asynchronous """ - if is_async: - async_display_callbacks[display_type] = callback_fn - else: - sync_display_callbacks[display_type] = callback_fn + with _callbacks_lock: + if is_async: + async_display_callbacks[display_type] = callback_fn + else: + sync_display_callbacks[display_type] = callback_fn def register_approval_callback(callback_fn): """Register a global approval callback function for dangerous tool operations. @@ -140,7 +145,8 @@ def register_approval_callback(callback_fn): callback_fn: Function that takes (function_name, arguments, risk_level) and returns ApprovalDecision """ global approval_callback - approval_callback = callback_fn + with _callbacks_lock: + approval_callback = callback_fn # Simplified aliases (consistent naming convention) add_display_callback = register_display_callback @@ -155,8 +161,12 @@ def execute_sync_callback(display_type: str, **kwargs): display_type (str): Type of display event **kwargs: Arguments to pass to the callback function """ - if display_type in sync_display_callbacks: - callback = sync_display_callbacks[display_type] + with _callbacks_lock: + if display_type in sync_display_callbacks: + callback = sync_display_callbacks[display_type] + + # Execute callback outside the lock to avoid potential deadlocks + if 'callback' in locals(): import inspect sig = inspect.signature(callback) @@ -179,10 +189,18 @@ async def execute_callback(display_type: str, **kwargs): """ import inspect + # Get callbacks under lock + sync_callback = None + async_callback = None + with _callbacks_lock: + if display_type in sync_display_callbacks: + sync_callback = sync_display_callbacks[display_type] + if display_type in async_display_callbacks: + async_callback = async_display_callbacks[display_type] + # Execute synchronous callback if registered - if display_type in sync_display_callbacks: - callback = sync_display_callbacks[display_type] - sig = inspect.signature(callback) + if sync_callback: + sig = inspect.signature(sync_callback) # Filter kwargs to what the callback accepts to maintain backward compatibility if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()): @@ -193,12 +211,11 @@ async def execute_callback(display_type: str, **kwargs): supported_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters} loop = asyncio.get_event_loop() - await loop.run_in_executor(None, lambda: callback(**supported_kwargs)) + await loop.run_in_executor(None, lambda: sync_callback(**supported_kwargs)) # Execute asynchronous callback if registered - if display_type in async_display_callbacks: - callback = async_display_callbacks[display_type] - sig = inspect.signature(callback) + if async_callback: + sig = inspect.signature(async_callback) # Filter kwargs to what the callback accepts to maintain backward compatibility if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()): @@ -208,7 +225,7 @@ async def execute_callback(display_type: str, **kwargs): # Only pass arguments that the callback signature supports supported_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters} - await callback(**supported_kwargs) + await async_callback(**supported_kwargs) def _clean_display_content(content: str, max_length: int = 20000) -> str: """Helper function to clean and truncate content for display.""" diff --git a/src/praisonai-agents/tests/unit/test_main_thread_safety.py b/src/praisonai-agents/tests/unit/test_main_thread_safety.py new file mode 100644 index 000000000..e859a8800 --- /dev/null +++ b/src/praisonai-agents/tests/unit/test_main_thread_safety.py @@ -0,0 +1,154 @@ +""" +Additional thread safety tests for main.py callback registries. + +Tests the fixes made to global mutable state in main.py. +""" +import threading +import time +import pytest + + +def test_callback_registry_has_locks(): + """Test that callback registries are protected by locks.""" + import praisonaiagents.main as main + + # Verify lock exists + assert hasattr(main, '_callbacks_lock') + assert isinstance(main._callbacks_lock, type(threading.RLock())) + + +def test_concurrent_callback_registration(): + """Test concurrent callback registration is thread-safe.""" + import praisonaiagents.main as main + + # Clear initial state + main.sync_display_callbacks.clear() + main.async_display_callbacks.clear() + + errors = [] + num_threads = 5 + num_iterations = 10 + + def callback_worker(thread_id): + try: + for i in range(num_iterations): + # Register callbacks + main.register_display_callback( + f"sync_{thread_id}_{i}", + lambda **kwargs: None, + is_async=False + ) + main.register_display_callback( + f"async_{thread_id}_{i}", + lambda **kwargs: None, + is_async=True + ) + time.sleep(0.001) # Small delay + except Exception as e: + errors.append(f"Thread {thread_id}: {e}") + + # Start threads + threads = [] + for i in range(num_threads): + t = threading.Thread(target=callback_worker, args=(i,)) + threads.append(t) + t.start() + + # Wait for completion + for t in threads: + t.join() + + # Check results + assert len(errors) == 0, f"Thread errors: {errors}" + assert len(main.sync_display_callbacks) == num_threads * num_iterations + assert len(main.async_display_callbacks) == num_threads * num_iterations + + +def test_concurrent_approval_callback_registration(): + """Test concurrent approval callback registration is thread-safe.""" + import praisonaiagents.main as main + + errors = [] + num_threads = 5 + + def approval_worker(thread_id): + try: + for i in range(10): + main.register_approval_callback( + lambda func, args, risk: f"approval_{thread_id}_{i}" + ) + time.sleep(0.001) + except Exception as e: + errors.append(f"Thread {thread_id}: {e}") + + # Start threads + threads = [] + for i in range(num_threads): + t = threading.Thread(target=approval_worker, args=(i,)) + threads.append(t) + t.start() + + # Wait for completion + for t in threads: + t.join() + + # Check results + assert len(errors) == 0, f"Thread errors: {errors}" + assert main.approval_callback is not None + + +def test_callback_execution_with_concurrent_registration(): + """Test callback execution while registration happens concurrently.""" + import praisonaiagents.main as main + + # Set up test callback + call_count = [0] + call_lock = threading.Lock() + + def test_callback(**kwargs): + with call_lock: + call_count[0] += 1 + + main.register_display_callback('test_concurrent', test_callback, is_async=False) + + errors = [] + + def execution_worker(): + try: + for _ in range(20): + main.execute_sync_callback('test_concurrent', data='test') + time.sleep(0.001) + except Exception as e: + errors.append(f"Execution error: {e}") + + def registration_worker(): + try: + for i in range(20): + main.register_display_callback( + f'test_reg_{i}', + lambda **kwargs: None, + is_async=False + ) + time.sleep(0.001) + except Exception as e: + errors.append(f"Registration error: {e}") + + # Start both execution and registration threads + threads = [ + threading.Thread(target=execution_worker), + threading.Thread(target=registration_worker), + ] + + for t in threads: + t.start() + + for t in threads: + t.join() + + # Check results + assert len(errors) == 0, f"Thread errors: {errors}" + assert call_count[0] == 20 # All executions should succeed + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file