Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 103 additions & 10 deletions src/praisonai-agents/praisonaiagents/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from typing import List, Optional, Dict, Any, Union, Literal, Type
from pydantic import BaseModel, ConfigDict
import asyncio
import contextvars
from copy import deepcopy

def _rich():
"""Lazy-load Rich display classes (cached by sys.modules after first call)."""
Expand All @@ -23,15 +25,107 @@ def _rich():

# Logging is already configured in _logging.py via __init__.py

# Global list to store error logs
error_logs = []
# Thread-safe context variables for multi-agent isolation
# Each agent context gets its own isolated state
_error_logs_context: contextvars.ContextVar[list] = contextvars.ContextVar('error_logs', default=[])
_sync_display_callbacks_context: contextvars.ContextVar[dict] = contextvars.ContextVar('sync_display_callbacks', default={})
_async_display_callbacks_context: contextvars.ContextVar[dict] = contextvars.ContextVar('async_display_callbacks', default={})
_approval_callback_context: contextvars.ContextVar = contextvars.ContextVar('approval_callback', default=None)

# Separate registries for sync and async callbacks
sync_display_callbacks = {}
async_display_callbacks = {}

# Global approval callback registry
approval_callback = None
class _ContextList:
"""Thread-safe wrapper for context-local lists with copy-on-write semantics."""

def __init__(self, context_var: contextvars.ContextVar):
self._context_var = context_var

def append(self, item):
current_list = self._context_var.get([])
# Create a copy to avoid modifying shared state
new_list = current_list.copy()
new_list.append(item)
self._context_var.set(new_list)

def extend(self, items):
current_list = self._context_var.get([])
new_list = current_list.copy()
new_list.extend(items)
self._context_var.set(new_list)

def clear(self):
self._context_var.set([])

def __iter__(self):
return iter(self._context_var.get([]))

def __len__(self):
return len(self._context_var.get([]))

def __getitem__(self, index):
return self._context_var.get([])[index]

def __bool__(self):
return bool(self._context_var.get([]))


class _ContextDict:
"""Thread-safe wrapper for context-local dicts with copy-on-write semantics."""

def __init__(self, context_var: contextvars.ContextVar):
self._context_var = context_var

def __getitem__(self, key):
return self._context_var.get({})[key]

def __setitem__(self, key, value):
current_dict = self._context_var.get({})
new_dict = current_dict.copy()
new_dict[key] = value
self._context_var.set(new_dict)

def __contains__(self, key):
return key in self._context_var.get({})

def get(self, key, default=None):
return self._context_var.get({}).get(key, default)

def keys(self):
return self._context_var.get({}).keys()

def values(self):
return self._context_var.get({}).values()

def items(self):
return self._context_var.get({}).items()

def __iter__(self):
return iter(self._context_var.get({}))

def __bool__(self):
return bool(self._context_var.get({}))


class _ContextVar:
"""Thread-safe wrapper for context-local variables."""

def __init__(self, context_var: contextvars.ContextVar):
self._context_var = context_var

def get(self):
return self._context_var.get(None)

def set(self, value):
self._context_var.set(value)

def __bool__(self):
return bool(self._context_var.get(None))


# Create thread-safe, context-isolated instances
error_logs = _ContextList(_error_logs_context)
sync_display_callbacks = _ContextDict(_sync_display_callbacks_context)
async_display_callbacks = _ContextDict(_async_display_callbacks_context)
approval_callback = _ContextVar(_approval_callback_context)

# ─────────────────────────────────────────────────────────────────────────────
# PraisonAI Unique Color Palette: "Elegant Intelligence"
Expand Down Expand Up @@ -134,13 +228,12 @@ def register_display_callback(display_type: str, callback_fn, is_async: bool = F
sync_display_callbacks[display_type] = callback_fn

def register_approval_callback(callback_fn):
"""Register a global approval callback function for dangerous tool operations.
"""Register a context-local approval callback function for dangerous tool operations.

Args:
callback_fn: Function that takes (function_name, arguments, risk_level) and returns ApprovalDecision
"""
global approval_callback
approval_callback = callback_fn
approval_callback.set(callback_fn)

# Simplified aliases (consistent naming convention)
add_display_callback = register_display_callback
Expand Down
250 changes: 250 additions & 0 deletions test_thread_safety_main_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
#!/usr/bin/env python3
"""
Comprehensive thread safety test for main.py global variables fix.

This test verifies that the contextvars-based fix properly isolates state
between concurrent agents/contexts, preventing race conditions and cross-contamination.
"""

import asyncio
import contextvars
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict

# Import the thread-safe globals
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src', 'praisonai-agents'))

from praisonaiagents.main import (
error_logs,
sync_display_callbacks,
async_display_callbacks,
approval_callback,
register_display_callback,
register_approval_callback
)


def test_error_logs_isolation():
"""Test that error logs are isolated between different contexts."""
results = {}

def worker(worker_id):
# Each worker runs in its own context - need to create a fresh context for each worker
def worker_task():
worker_errors = [f"Error {i} from worker {worker_id}" for i in range(10)]

for error in worker_errors:
error_logs.append(error)

# Verify only this worker's errors are visible
current_errors = list(error_logs)
return current_errors

# Run the task in a new context
ctx = contextvars.copy_context()
current_errors = ctx.run(worker_task)
results[worker_id] = current_errors

# Verify no contamination from other workers
for error in current_errors:
assert f"worker {worker_id}" in error, f"Cross-contamination detected: {error}"

# Run 10 concurrent workers
with ThreadPoolExecutor(max_workers=10) as executor:
futures = [executor.submit(worker, i) for i in range(10)]
for future in futures:
future.result()

# Verify each worker saw only its own errors
for worker_id, errors in results.items():
assert len(errors) == 10, f"Worker {worker_id} saw {len(errors)} errors, expected 10"
for error in errors:
assert f"worker {worker_id}" in error, f"Worker {worker_id} saw foreign error: {error}"

print("✅ Error logs isolation test passed")


def test_callback_isolation():
"""Test that display callbacks are isolated between contexts."""
results = {}

def worker(worker_id):
def worker_task():
# Each worker registers its own callback
def my_callback(message=None, **kwargs):
return f"Worker {worker_id} processed: {message}"

register_display_callback('test_event', my_callback)

# Verify only this worker's callback is registered
callback = sync_display_callbacks.get('test_event')
assert callback is not None, f"Worker {worker_id} callback not found"

# Test the callback
result = callback(message=f"test message {worker_id}")
return result

# Run in fresh context
ctx = contextvars.copy_context()
result = ctx.run(worker_task)
results[worker_id] = result

# Verify isolation
assert f"Worker {worker_id}" in result, f"Callback isolation failed for worker {worker_id}"

# Run 5 concurrent workers
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(worker, i) for i in range(5)]
for future in futures:
future.result()

# Verify each worker got its own callback result
for worker_id, result in results.items():
assert f"Worker {worker_id}" in result, f"Callback result contamination for worker {worker_id}"
assert f"test message {worker_id}" in result, f"Message processing failed for worker {worker_id}"

print("✅ Callback isolation test passed")


def test_approval_callback_isolation():
"""Test that approval callbacks are isolated between contexts."""
results = {}

def worker(worker_id):
def worker_task():
# Each worker sets its own approval callback
def my_approval(func_name, args, risk_level):
return f"Worker {worker_id} approved {func_name}"

register_approval_callback(my_approval)

# Verify this worker's callback is set
current_callback = approval_callback.get()
assert current_callback is not None, f"Worker {worker_id} approval callback not set"

# Test the callback
result = current_callback("test_func", {}, "low")
return result

# Run in fresh context
ctx = contextvars.copy_context()
result = ctx.run(worker_task)
results[worker_id] = result

# Verify isolation
assert f"Worker {worker_id}" in result, f"Approval callback isolation failed for worker {worker_id}"

# Run 5 concurrent workers
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(worker, i) for i in range(5)]
for future in futures:
future.result()

# Verify each worker got its own approval result
for worker_id, result in results.items():
assert f"Worker {worker_id}" in result, f"Approval callback result contamination for worker {worker_id}"

print("✅ Approval callback isolation test passed")


async def test_async_callback_isolation():
"""Test that async callbacks are also properly isolated."""
results = {}

async def worker(worker_id):
# Create a task within the current context - asyncio automatically copies context
async def worker_task():
# Each worker registers its own async callback
async def my_async_callback(message=None, **kwargs):
await asyncio.sleep(0.01) # Small async operation
return f"Async Worker {worker_id} processed: {message}"

register_display_callback('async_test_event', my_async_callback, is_async=True)

# Verify only this worker's callback is registered
callback = async_display_callbacks.get('async_test_event')
assert callback is not None, f"Async Worker {worker_id} callback not found"

# Test the callback
result = await callback(message=f"async test message {worker_id}")
return result

result = await worker_task()
results[worker_id] = result

# Verify isolation
assert f"Async Worker {worker_id}" in result, f"Async callback isolation failed for worker {worker_id}"

# Run 5 concurrent async workers - each will run in its own context copy
tasks = [worker(i) for i in range(5)]
await asyncio.gather(*tasks)

# Verify each worker got its own callback result
for worker_id, result in results.items():
assert f"Async Worker {worker_id}" in result, f"Async callback result contamination for worker {worker_id}"
assert f"async test message {worker_id}" in result, f"Async message processing failed for worker {worker_id}"

print("✅ Async callback isolation test passed")


def test_context_inheritance():
"""Test that context variables provide proper isolation between parent and child contexts."""

# Test that contexts start clean by default
initial_errors = list(error_logs)
assert len(initial_errors) == 0, f"Expected clean context, found {len(initial_errors)} errors"

# Add something in the current context
error_logs.append("Current context error")
current_errors = list(error_logs)
assert len(current_errors) == 1, "Current context should have 1 error"

# Test that a new context starts fresh (doesn't inherit from this context unless explicitly copied)
def new_context_task():
new_errors = list(error_logs)
return len(new_errors)

ctx = contextvars.Context() # Create truly empty context
new_context_error_count = ctx.run(new_context_task)

# The new context should start with empty state (our ContextVar has default=[])
assert new_context_error_count == 0, f"New context should start empty, but had {new_context_error_count} errors"

# Original context should still have its error
final_current_errors = list(error_logs)
assert len(final_current_errors) == 1, "Original context should still have 1 error"

print("✅ Context inheritance test passed")


def run_all_tests():
"""Run all thread safety tests."""
print("🧪 Running thread safety tests for main.py globals fix...\n")

# Run each test in its own clean context
def run_test_in_clean_context(test_func):
ctx = contextvars.copy_context()
ctx.run(test_func)

# Run sync tests
run_test_in_clean_context(test_error_logs_isolation)
run_test_in_clean_context(test_callback_isolation)
run_test_in_clean_context(test_approval_callback_isolation)
run_test_in_clean_context(test_context_inheritance)

# Run async tests in clean context too
def run_async_test():
asyncio.run(test_async_callback_isolation())

run_test_in_clean_context(run_async_test)

print("\n🎉 All thread safety tests passed!")
print("✅ main.py global variables are now thread-safe with proper context isolation")


if __name__ == "__main__":
run_all_tests()