forked from lightspeed-core/lightspeed-stack
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsqlite_context_store.py
More file actions
144 lines (115 loc) · 4.78 KB
/
sqlite_context_store.py
File metadata and controls
144 lines (115 loc) · 4.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""SQLite implementation of A2A context store."""
from typing import Optional
from sqlalchemy import Column, String, Table, MetaData, select, delete
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker
from a2a_storage.context_store import A2AContextStore
from log import get_logger
logger = get_logger(__name__)
# Define the table metadata.
# Module-level MetaData holding the single table used by this store;
# SQLiteA2AContextStore.initialize() passes it to create_all.
metadata = MetaData()
# Mapping table: one row per A2A context.
a2a_context_table = Table(
    "a2a_contexts",
    metadata,
    # Primary key, so each A2A context maps to at most one conversation.
    Column("context_id", String, primary_key=True),
    # The Llama Stack conversation ID the context is bound to.
    Column("conversation_id", String, nullable=False),
)
class SQLiteA2AContextStore(A2AContextStore):
    """SQLite implementation of A2A context-to-conversation store.

    Stores context mappings in a SQLite database for persistence across
    restarts and sharing across workers (when using a shared database file).

    The store creates a table 'a2a_contexts' with the following schema:
        context_id (TEXT, PRIMARY KEY): The A2A context ID
        conversation_id (TEXT, NOT NULL): The Llama Stack conversation ID

    Initialization is lazy: ``get``, ``set`` and ``delete`` each call
    ``_ensure_initialized`` first, so calling ``initialize()`` explicitly
    is optional.
    """

    def __init__(
        self,
        engine: AsyncEngine,
        create_table: bool = True,
    ) -> None:
        """Initialize the SQLite context store.

        No database I/O happens here; the schema is created on the first
        call to ``initialize`` (directly or via any data-access method).

        Args:
            engine: SQLAlchemy async engine connected to the SQLite database.
            create_table: If True, create the table on initialization.
        """
        logger.debug("Initializing SQLiteA2AContextStore")
        self._engine = engine
        self._session_maker = async_sessionmaker(engine, expire_on_commit=False)
        self._create_table = create_table
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize the store and create tables if needed.

        Returns immediately once a previous call has completed. There is no
        lock here, so concurrent first callers may each run
        ``metadata.create_all``.
        """
        if self._initialized:
            return
        logger.debug("Initializing SQLite A2A context store schema")
        if self._create_table:
            async with self._engine.begin() as conn:
                # create_all is a synchronous API; run it on the async
                # connection via run_sync.
                await conn.run_sync(metadata.create_all)
        self._initialized = True
        logger.info("SQLiteA2AContextStore initialized successfully")

    async def _ensure_initialized(self) -> None:
        """Ensure the store is initialized before use."""
        if not self._initialized:
            await self.initialize()

    async def get(self, context_id: str) -> Optional[str]:
        """Retrieve the conversation ID for an A2A context.

        Args:
            context_id: The A2A context ID.

        Returns:
            The Llama Stack conversation ID, or None if not found.
        """
        await self._ensure_initialized()
        async with self._session_maker() as session:
            stmt = select(a2a_context_table.c.conversation_id).where(
                a2a_context_table.c.context_id == context_id
            )
            result = await session.execute(stmt)
            # context_id is the primary key, so at most one row can match.
            conversation_id = result.scalar_one_or_none()
        # Compare against None explicitly: an empty-string conversation ID is
        # a stored mapping and must not be reported as missing.
        if conversation_id is not None:
            logger.debug(
                "Context %s maps to conversation %s", context_id, conversation_id
            )
            return conversation_id
        logger.debug("Context %s not found in store", context_id)
        return None

    async def set(self, context_id: str, conversation_id: str) -> None:
        """Store a context-to-conversation mapping.

        Uses delete-then-insert to handle both new and existing mappings:
        any existing mapping for ``context_id`` is replaced.

        Args:
            context_id: The A2A context ID.
            conversation_id: The Llama Stack conversation ID.
        """
        await self._ensure_initialized()
        # Both statements run in the single transaction opened by begin(),
        # which commits on normal exit of the block.
        async with self._session_maker.begin() as session:
            # Upsert by deleting existing row and inserting new values
            await session.execute(
                delete(a2a_context_table).where(
                    a2a_context_table.c.context_id == context_id
                )
            )
            await session.execute(
                a2a_context_table.insert().values(
                    context_id=context_id,
                    conversation_id=conversation_id,
                )
            )
        logger.debug(
            "Stored mapping: context %s -> conversation %s",
            context_id,
            conversation_id,
        )

    async def delete(self, context_id: str) -> None:
        """Delete a context-to-conversation mapping.

        Deleting a context that has no mapping is not an error; the
        statement simply affects no rows.

        Args:
            context_id: The A2A context ID to delete.
        """
        await self._ensure_initialized()
        async with self._session_maker.begin() as session:
            stmt = delete(a2a_context_table).where(
                a2a_context_table.c.context_id == context_id
            )
            await session.execute(stmt)
        logger.debug("Deleted context mapping for %s", context_id)

    def ready(self) -> bool:
        """Check if the store is ready for use.

        Returns:
            True if the store is initialized, False otherwise.
        """
        return self._initialized