|
11 | 11 |
|
12 | 12 | logger = logging.getLogger(__name__) |
13 | 13 |
|
| 14 | +# Module-level cached implementation class |
| 15 | +_impl_cls = None |
14 | 16 |
|
15 | | -class PraisonAISessionDataStore: |
16 | | - """Adapter that bridges PraisonAI SessionStoreProtocol to aiui BaseDataStore. |
| 17 | + |
| 18 | +def _build_impl_cls(): |
| 19 | + """Build the implementation class with BaseDataStore inheritance. |
17 | 20 | |
18 | | - Note: This class defers dependency imports until instantiation to allow |
19 | | - test collection without optional modules installed. |
| 21 | + Thread-safe lazy class factory that imports BaseDataStore and creates |
| 22 | + a proper subclass. This avoids runtime __bases__ mutation and ensures |
| 23 | + BaseDataStore.__init__ is properly called. |
20 | 24 | """ |
| 25 | + global _impl_cls |
| 26 | + if _impl_cls is not None: |
| 27 | + return _impl_cls |
| 28 | + |
| 29 | + try: |
| 30 | + from praisonaiui.datastore import BaseDataStore |
| 31 | + except ImportError as e: |
| 32 | + raise ImportError( |
| 33 | + "praisonaiui is required for PraisonAISessionDataStore. " |
| 34 | + "Install with: pip install 'praisonai[ui]'" |
| 35 | + ) from e |
21 | 36 |
|
22 | | - def __init__(self, store: Optional[Any] = None): |
23 | | - """Initialize with an optional session store, defaults to hierarchical store. |
| 37 | + try: |
| 38 | + from praisonaiagents.session import get_hierarchical_session_store |
| 39 | + except ImportError as e: |
| 40 | + raise ImportError( |
| 41 | + "praisonaiagents is required for PraisonAISessionDataStore. " |
| 42 | + "Install with: pip install praisonaiagents" |
| 43 | + ) from e |
| 44 | + |
| 45 | + class _PraisonAISessionDataStoreImpl(BaseDataStore): |
| 46 | + """Implementation class that properly inherits from BaseDataStore.""" |
24 | 47 |
|
25 | | - Args: |
26 | | - store: Optional SessionStoreProtocol implementation |
27 | | - """ |
28 | | - # Lazy import dependencies only when class is instantiated |
29 | | - try: |
30 | | - from praisonaiui.datastore import BaseDataStore |
31 | | - # Make this class inherit from BaseDataStore at runtime |
32 | | - if BaseDataStore not in self.__class__.__bases__: |
33 | | - self.__class__.__bases__ = (BaseDataStore,) + self.__class__.__bases__ |
34 | | - except ImportError as e: |
35 | | - raise ImportError( |
36 | | - "praisonaiui is required for PraisonAISessionDataStore. " |
37 | | - "Install with: pip install 'praisonai[ui]'" |
38 | | - ) from e |
39 | | - |
40 | | - try: |
41 | | - from praisonaiagents.session import get_hierarchical_session_store |
42 | | - except ImportError as e: |
43 | | - raise ImportError( |
44 | | - "praisonaiagents is required for PraisonAISessionDataStore. " |
45 | | - "Install with: pip install praisonaiagents" |
46 | | - ) from e |
| 48 | + def __init__(self, store: Optional[Any] = None): |
| 49 | + """Initialize with an optional session store, defaults to hierarchical store. |
47 | 50 | |
48 | | - self._store = store or get_hierarchical_session_store() |
49 | | - |
50 | | - def _new_id(self) -> str: |
51 | | - """Generate a new session ID.""" |
52 | | - return str(uuid.uuid4()) |
53 | | - |
54 | | - async def list_sessions(self) -> list[dict[str, Any]]: |
55 | | - """List all available sessions.""" |
56 | | - # Check if store supports listing (DefaultSessionStore/HierarchicalSessionStore do) |
57 | | - list_fn = getattr(self._store, "list_sessions", None) |
58 | | - if list_fn is None: |
59 | | - return [] # Protocol implementation doesn't support listing |
60 | | - |
61 | | - try: |
62 | | - # DefaultSessionStore/HierarchicalSessionStore return list[dict] |
63 | | - return list_fn(limit=50) or [] |
64 | | - except Exception: |
65 | | - logger.exception("Failed to list sessions") |
66 | | - return [] |
67 | | - |
68 | | - async def get_session(self, session_id: str) -> Optional[dict[str, Any]]: |
69 | | - """Get a specific session by ID.""" |
70 | | - if not self._store.session_exists(session_id): |
71 | | - return None |
72 | | - |
73 | | - try: |
74 | | - chat_history = self._store.get_chat_history(session_id) |
| 51 | + Args: |
| 52 | + store: Optional SessionStoreProtocol implementation |
| 53 | + """ |
| 54 | + super().__init__() # Properly call BaseDataStore.__init__() |
| 55 | + self._store = store or get_hierarchical_session_store() |
| 56 | + |
| 57 | + def _new_id(self) -> str: |
| 58 | + """Generate a new session ID.""" |
| 59 | + return str(uuid.uuid4()) |
| 60 | + |
| 61 | + async def list_sessions(self) -> list[dict[str, Any]]: |
| 62 | + """List all available sessions.""" |
| 63 | + # Check if store supports listing (DefaultSessionStore/HierarchicalSessionStore do) |
| 64 | + list_fn = getattr(self._store, "list_sessions", None) |
| 65 | + if list_fn is None: |
| 66 | + return [] # Protocol implementation doesn't support listing |
| 67 | + |
| 68 | + try: |
| 69 | + # DefaultSessionStore/HierarchicalSessionStore return list[dict] |
| 70 | + return list_fn(limit=50) or [] |
| 71 | + except Exception: |
| 72 | + logger.exception("Failed to list sessions") |
| 73 | + return [] |
| 74 | + |
| 75 | + async def get_session(self, session_id: str) -> Optional[dict[str, Any]]: |
| 76 | + """Get a specific session by ID.""" |
| 77 | + if not self._store.session_exists(session_id): |
| 78 | + return None |
| 79 | + |
| 80 | + try: |
| 81 | + chat_history = self._store.get_chat_history(session_id) |
| 82 | + return { |
| 83 | + "id": session_id, |
| 84 | + "messages": chat_history or [], |
| 85 | + } |
| 86 | + except Exception: |
| 87 | + logger.exception("Failed to load session %s", session_id) |
| 88 | + return None |
| 89 | + |
| 90 | + async def create_session(self, session_id: Optional[str] = None) -> dict[str, Any]: |
| 91 | + """Create a new session.""" |
| 92 | + sid = session_id or self._new_id() |
| 93 | + # Sessions are created lazily on first add_message |
75 | 94 | return { |
76 | | - "id": session_id, |
77 | | - "messages": chat_history or [], |
| 95 | + "id": sid, |
| 96 | + "messages": [] |
78 | 97 | } |
79 | | - except Exception: |
80 | | - logger.exception("Failed to load session %s", session_id) |
81 | | - return None |
82 | | - |
83 | | - async def create_session(self, session_id: Optional[str] = None) -> dict[str, Any]: |
84 | | - """Create a new session.""" |
85 | | - sid = session_id or self._new_id() |
86 | | - # Sessions are created lazily on first add_message |
87 | | - return { |
88 | | - "id": sid, |
89 | | - "messages": [] |
90 | | - } |
91 | | - |
92 | | - async def delete_session(self, session_id: str) -> bool: |
93 | | - """Delete a session and return success status.""" |
94 | | - try: |
95 | | - return self._store.delete_session(session_id) |
96 | | - except Exception: |
97 | | - logger.exception("Failed to delete session %s", session_id) |
98 | | - return False |
99 | | - |
100 | | - async def add_message(self, session_id: str, message: dict[str, Any]): |
101 | | - """Add a message to a session.""" |
102 | | - self._store.add_message( |
103 | | - session_id=session_id, |
104 | | - role=message.get("role", "user"), |
105 | | - content=message.get("content", ""), |
106 | | - metadata=message.get("metadata") |
107 | | - ) |
108 | | - |
109 | | - async def get_messages(self, session_id: str) -> list[dict[str, Any]]: |
110 | | - """Get all messages for a session.""" |
111 | | - if not self._store.session_exists(session_id): |
112 | | - return [] |
113 | | - |
114 | | - try: |
115 | | - return self._store.get_chat_history(session_id) or [] |
116 | | - except Exception: |
117 | | - logger.exception("Failed to load messages for session %s", session_id) |
118 | | - return [] |
| 98 | + |
| 99 | + async def delete_session(self, session_id: str) -> bool: |
| 100 | + """Delete a session and return success status.""" |
| 101 | + try: |
| 102 | + return self._store.delete_session(session_id) |
| 103 | + except Exception: |
| 104 | + logger.exception("Failed to delete session %s", session_id) |
| 105 | + return False |
| 106 | + |
| 107 | + async def add_message(self, session_id: str, message: dict[str, Any]): |
| 108 | + """Add a message to a session.""" |
| 109 | + self._store.add_message( |
| 110 | + session_id=session_id, |
| 111 | + role=message.get("role", "user"), |
| 112 | + content=message.get("content", ""), |
| 113 | + metadata=message.get("metadata") |
| 114 | + ) |
| 115 | + |
| 116 | + async def get_messages(self, session_id: str) -> list[dict[str, Any]]: |
| 117 | + """Get all messages for a session.""" |
| 118 | + if not self._store.session_exists(session_id): |
| 119 | + return [] |
| 120 | + |
| 121 | + try: |
| 122 | + return self._store.get_chat_history(session_id) or [] |
| 123 | + except Exception: |
| 124 | + logger.exception("Failed to load messages for session %s", session_id) |
| 125 | + return [] |
| 126 | + |
| 127 | + _impl_cls = _PraisonAISessionDataStoreImpl |
| 128 | + return _impl_cls |
| 129 | + |
| 130 | + |
| 131 | +class PraisonAISessionDataStore: |
| 132 | + """Adapter that bridges PraisonAI SessionStoreProtocol to aiui BaseDataStore. |
| 133 | + |
| 134 | + Dependency imports are deferred until instantiation to allow test |
| 135 | + collection without optional modules installed. |
| 136 | + """ |
| 137 | + |
| 138 | + def __new__(cls, store: Optional[Any] = None): |
| 139 | + """Factory method that returns a properly configured implementation instance.""" |
| 140 | + impl_cls = _build_impl_cls() |
| 141 | + return impl_cls(store) |
0 commit comments