Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions simple_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/usr/bin/env python3
"""
Simple test to verify thread safety fixes work correctly.
"""

import sys
import os
import contextvars

# Add the src directory to the path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src', 'praisonai-agents'))

from praisonaiagents.main import error_logs, sync_display_callbacks

def test_basic_context_isolation():
"""Test basic context isolation"""
print("Testing basic context isolation...")

# Clear any existing state
error_logs.clear()
sync_display_callbacks.clear()

print(f"Initial error_logs length: {len(error_logs)}")

# Add an error in main context
error_logs.append("Main context error")
print(f"After adding main error: {len(error_logs)}")

# Create a new context
ctx = contextvars.copy_context()

def in_new_context():
print(f"In new context, error_logs length: {len(error_logs)}")
error_logs.append("New context error")
print(f"After adding new context error: {len(error_logs)}")
return list(error_logs)

result = ctx.run(in_new_context)
print(f"Result from new context: {result}")
print(f"Back in main context, error_logs length: {len(error_logs)}")
print(f"Main context errors: {list(error_logs)}")

if __name__ == "__main__":
test_basic_context_isolation()
196 changes: 187 additions & 9 deletions src/praisonai-agents/praisonaiagents/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import time
import json
import logging
import contextvars
import threading
from typing import List, Optional, Dict, Any, Union, Literal, Type
from pydantic import BaseModel, ConfigDict
import asyncio
Expand All @@ -23,15 +25,190 @@ def _rich():

# Logging is already configured in _logging.py via __init__.py

# Global list to store error logs
error_logs = []
# Thread-safe context variables for multi-agent safety
# Each agent context gets isolated state via context variables
# Note: Default values should be immutable to avoid sharing state between contexts
_error_logs: contextvars.ContextVar[List[Any]] = contextvars.ContextVar('error_logs')
_sync_display_callbacks: contextvars.ContextVar[Dict[str, Any]] = contextvars.ContextVar('sync_display_callbacks')
_async_display_callbacks: contextvars.ContextVar[Dict[str, Any]] = contextvars.ContextVar('async_display_callbacks')
_approval_callback: contextvars.ContextVar[Optional[Any]] = contextvars.ContextVar('approval_callback')

# Backward compatibility wrappers that maintain existing API
def _get_error_logs():
try:
return _error_logs.get()
except LookupError:
# Initialize with empty list for this context
_error_logs.set([])
return _error_logs.get()

def _set_error_logs(logs):
_error_logs.set(logs)

def _get_sync_display_callbacks():
try:
return _sync_display_callbacks.get()
except LookupError:
# Initialize with empty dict for this context
_sync_display_callbacks.set({})
return _sync_display_callbacks.get()

def _set_sync_display_callbacks(callbacks):
_sync_display_callbacks.set(callbacks)

def _get_async_display_callbacks():
try:
return _async_display_callbacks.get()
except LookupError:
# Initialize with empty dict for this context
_async_display_callbacks.set({})
return _async_display_callbacks.get()

def _set_async_display_callbacks(callbacks):
_async_display_callbacks.set(callbacks)

def _get_approval_callback():
try:
return _approval_callback.get()
except LookupError:
# No default callback
return None

def _set_approval_callback(callback):
_approval_callback.set(callback)

# Separate registries for sync and async callbacks
sync_display_callbacks = {}
async_display_callbacks = {}
# Thread-safe wrapper classes for backward compatibility
class _ContextList:
"""Thread-safe list wrapper using contextvars"""

def __init__(self, context_var):
self._context_var = context_var

def _get_current(self):
"""Get current list, initializing if needed"""
if self._context_var is _error_logs:
return _get_error_logs()
else:
try:
return self._context_var.get()
except LookupError:
self._context_var.set([])
return self._context_var.get()

def append(self, item):
current = self._get_current()
# Create a new list to avoid modifying shared state
new_list = current.copy()
new_list.append(item)
self._context_var.set(new_list)

def extend(self, items):
current = self._get_current()
# Create a new list to avoid modifying shared state
new_list = current.copy()
new_list.extend(items)
self._context_var.set(new_list)

def clear(self):
self._context_var.set([])

def __iter__(self):
return iter(self._get_current())

def __len__(self):
return len(self._get_current())

def __getitem__(self, index):
return self._get_current()[index]

def __setitem__(self, index, value):
current = self._get_current()
# Create a new list to avoid modifying shared state
new_list = current.copy()
new_list[index] = value
self._context_var.set(new_list)

class _ContextDict:
"""Thread-safe dict wrapper using contextvars"""

def __init__(self, context_var):
self._context_var = context_var

def _get_current(self):
"""Get current dict, initializing if needed"""
if self._context_var is _sync_display_callbacks:
return _get_sync_display_callbacks()
elif self._context_var is _async_display_callbacks:
return _get_async_display_callbacks()
else:
try:
return self._context_var.get()
except LookupError:
self._context_var.set({})
return self._context_var.get()

def __getitem__(self, key):
return self._get_current()[key]

def __setitem__(self, key, value):
current = self._get_current()
# Create a new dict to avoid modifying shared state
new_dict = current.copy()
new_dict[key] = value
self._context_var.set(new_dict)

def __contains__(self, key):
return key in self._get_current()

def get(self, key, default=None):
return self._get_current().get(key, default)

def keys(self):
return self._get_current().keys()

def values(self):
return self._get_current().values()

def items(self):
return self._get_current().items()

def clear(self):
self._context_var.set({})

class _ContextVar:
"""Thread-safe variable wrapper using contextvars"""

def __init__(self, context_var):
self._context_var = context_var

def _get_current(self):
"""Get current value"""
if self._context_var is _approval_callback:
return _get_approval_callback()
else:
try:
return self._context_var.get()
except LookupError:
return None

def __call__(self, *args, **kwargs):
# Allow calling the callback if it exists
callback = self._get_current()
if callback:
return callback(*args, **kwargs)
return None

def __bool__(self):
return self._get_current() is not None

def set(self, value):
self._context_var.set(value)

# Global approval callback registry
approval_callback = None
# For backward compatibility, expose wrapped globals that maintain thread safety
error_logs = _ContextList(_error_logs)
sync_display_callbacks = _ContextDict(_sync_display_callbacks)
async_display_callbacks = _ContextDict(_async_display_callbacks)
approval_callback = _ContextVar(_approval_callback)

# ─────────────────────────────────────────────────────────────────────────────
# PraisonAI Unique Color Palette: "Elegant Intelligence"
Expand Down Expand Up @@ -139,8 +316,7 @@ def register_approval_callback(callback_fn):
Args:
callback_fn: Function that takes (function_name, arguments, risk_level) and returns ApprovalDecision
"""
global approval_callback
approval_callback = callback_fn
approval_callback.set(callback_fn)


# Simplified aliases (consistent naming convention)
Expand Down Expand Up @@ -420,6 +596,7 @@ def display_error(message: str, console=None):
title="⚠ Error",
border_style=PRAISON_COLORS["error"]
))
# Use thread-safe context-aware error logging
error_logs.append(message)

def display_generating(content: str = "", start_time: Optional[float] = None):
Expand Down Expand Up @@ -617,6 +794,7 @@ async def adisplay_error(message: str, console=None):
await execute_callback('error', message=message)

console.print(Panel.fit(Text(message, style="bold red"), title="Error", border_style="red"))
# Use thread-safe context-aware error logging
error_logs.append(message)

async def adisplay_generating(content: str = "", start_time: Optional[float] = None):
Expand Down
Loading