|
1 | 1 | """SQLite implementation of UiPathResumableStorageProtocol.""" |
2 | 2 |
|
3 | 3 | import json |
4 | | -from typing import cast |
| 4 | +from typing import Any, cast |
5 | 5 |
|
6 | 6 | from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver |
7 | 7 | from pydantic import BaseModel |
8 | | -from uipath.runtime import ( |
9 | | - UiPathApiTrigger, |
10 | | - UiPathResumeTrigger, |
11 | | - UiPathResumeTriggerName, |
12 | | - UiPathResumeTriggerType, |
13 | | -) |
| 8 | +from uipath.runtime import UiPathResumeTrigger |
14 | 9 |
|
15 | 10 |
|
16 | 11 | class SqliteResumableStorage: |
17 | | - """SQLite storage for resume triggers.""" |
| 12 | + """SQLite storage for resume triggers and arbitrary kv pairs.""" |
18 | 13 |
|
19 | 14 | def __init__( |
20 | | - self, memory: AsyncSqliteSaver, table_name: str = "__uipath_resume_triggers" |
| 15 | + self, |
| 16 | + memory: AsyncSqliteSaver, |
21 | 17 | ): |
22 | 18 | self.memory = memory |
23 | | - self.table_name = table_name |
| 19 | + self.rs_table_name = "__uipath_resume_triggers" |
| 20 | + self.kv_table_name = "__uipath_runtime_kv" |
24 | 21 | self._initialized = False |
25 | 22 |
|
26 | 23 | async def _ensure_table(self) -> None: |
27 | | - """Create table if needed.""" |
| 24 | + """Create tables if needed.""" |
28 | 25 | if self._initialized: |
29 | 26 | return |
30 | 27 |
|
31 | 28 | await self.memory.setup() |
32 | 29 | async with self.memory.lock, self.memory.conn.cursor() as cur: |
33 | | - await cur.execute(f""" |
34 | | - CREATE TABLE IF NOT EXISTS {self.table_name} ( |
| 30 | + # Enable WAL mode for high concurrency |
| 31 | + await cur.execute("PRAGMA journal_mode=WAL") |
| 32 | + |
| 33 | + await cur.execute( |
| 34 | + f""" |
| 35 | + CREATE TABLE IF NOT EXISTS {self.rs_table_name} ( |
35 | 36 | id INTEGER PRIMARY KEY AUTOINCREMENT, |
36 | | - type TEXT NOT NULL, |
37 | | - name TEXT NOT NULL, |
38 | | - key TEXT, |
39 | | - folder_key TEXT, |
40 | | - folder_path TEXT, |
41 | | - payload TEXT, |
| 37 | + runtime_id TEXT NOT NULL, |
| 38 | + interrupt_id TEXT NOT NULL, |
| 39 | + data TEXT NOT NULL, |
42 | 40 | timestamp DATETIME DEFAULT (strftime('%Y-%m-%d %H:%M:%S', 'now', 'utc')) |
43 | 41 | ) |
44 | | - """) |
45 | | - await self.memory.conn.commit() |
46 | | - self._initialized = True |
| 42 | + """ |
| 43 | + ) |
47 | 44 |
|
48 | | - async def save_trigger(self, trigger: UiPathResumeTrigger) -> None: |
49 | | - """Save resume trigger to database.""" |
50 | | - await self._ensure_table() |
| 45 | + await cur.execute( |
| 46 | + f""" |
| 47 | + CREATE INDEX IF NOT EXISTS idx_{self.rs_table_name}_runtime_id |
| 48 | + ON {self.rs_table_name}(runtime_id) |
| 49 | + """ |
| 50 | + ) |
51 | 51 |
|
52 | | - trigger_key = ( |
53 | | - trigger.api_resume.inbox_id if trigger.api_resume else trigger.item_key |
54 | | - ) |
55 | | - payload = trigger.payload |
56 | | - if payload: |
57 | | - payload = ( |
58 | | - ( |
59 | | - payload.model_dump() |
60 | | - if isinstance(payload, BaseModel) |
61 | | - else json.dumps(payload) |
| 52 | + await cur.execute( |
| 53 | + f""" |
| 54 | + CREATE TABLE IF NOT EXISTS {self.kv_table_name} ( |
| 55 | + runtime_id TEXT NOT NULL, |
| 56 | + namespace TEXT NOT NULL, |
| 57 | + key TEXT NOT NULL, |
| 58 | + value TEXT, |
| 59 | + timestamp DATETIME DEFAULT (strftime('%Y-%m-%d %H:%M:%S', 'now', 'utc')), |
| 60 | + PRIMARY KEY (runtime_id, namespace, key) |
62 | 61 | ) |
63 | | - if isinstance(payload, dict) |
64 | | - else str(payload) |
| 62 | + """ |
65 | 63 | ) |
66 | 64 |
|
| 65 | + await self.memory.conn.commit() |
| 66 | + |
| 67 | + self._initialized = True |
| 68 | + |
| 69 | + async def save_triggers( |
| 70 | + self, runtime_id: str, triggers: list[UiPathResumeTrigger] |
| 71 | + ) -> None: |
| 72 | + """Save resume triggers to database, replacing all existing triggers for this runtime_id.""" |
| 73 | + await self._ensure_table() |
| 74 | + |
67 | 75 | async with self.memory.lock, self.memory.conn.cursor() as cur: |
| 76 | + # Delete all existing triggers for this runtime_id |
68 | 77 | await cur.execute( |
69 | | - f"INSERT INTO {self.table_name} (type, key, name, payload, folder_path, folder_key) VALUES (?, ?, ?, ?, ?, ?)", |
70 | | - ( |
71 | | - trigger.trigger_type.value, |
72 | | - trigger_key, |
73 | | - trigger.trigger_name.value, |
74 | | - payload, |
75 | | - trigger.folder_path, |
76 | | - trigger.folder_key, |
77 | | - ), |
| 78 | + f""" |
| 79 | + DELETE FROM {self.rs_table_name} |
| 80 | + WHERE runtime_id = ? |
| 81 | + """, |
| 82 | + (runtime_id,), |
78 | 83 | ) |
| 84 | + |
| 85 | + # Insert new triggers |
| 86 | + for trigger in triggers: |
| 87 | + trigger_data = trigger.model_dump() |
| 88 | + trigger_data["payload"] = trigger.payload |
| 89 | + trigger_data["trigger_name"] = trigger.trigger_name |
| 90 | + |
| 91 | + await cur.execute( |
| 92 | + f""" |
| 93 | + INSERT INTO {self.rs_table_name} |
| 94 | + (runtime_id, interrupt_id, data) |
| 95 | + VALUES (?, ?, ?) |
| 96 | + """, |
| 97 | + ( |
| 98 | + runtime_id, |
| 99 | + trigger.interrupt_id, |
| 100 | + json.dumps(trigger_data), |
| 101 | + ), |
| 102 | + ) |
79 | 103 | await self.memory.conn.commit() |
80 | 104 |
|
81 | | - async def get_latest_trigger(self) -> UiPathResumeTrigger | None: |
82 | | - """Get most recent trigger from database.""" |
| 105 | + async def get_triggers(self, runtime_id: str) -> list[UiPathResumeTrigger] | None: |
| 106 | + """Get all triggers for runtime_id from database.""" |
83 | 107 | await self._ensure_table() |
84 | 108 |
|
85 | 109 | async with self.memory.lock, self.memory.conn.cursor() as cur: |
86 | | - await cur.execute(f""" |
87 | | - SELECT type, key, name, folder_path, folder_key, payload |
88 | | - FROM {self.table_name} |
89 | | - ORDER BY timestamp DESC |
90 | | - LIMIT 1 |
91 | | - """) |
92 | | - result = await cur.fetchone() |
| 110 | + await cur.execute( |
| 111 | + f""" |
| 112 | + SELECT data |
| 113 | + FROM {self.rs_table_name} |
| 114 | + WHERE runtime_id = ? |
| 115 | + ORDER BY timestamp ASC |
| 116 | + """, |
| 117 | + (runtime_id,), |
| 118 | + ) |
| 119 | + results = await cur.fetchall() |
| 120 | + |
| 121 | + if not results: |
| 122 | + return None |
93 | 123 |
|
94 | | - if not result: |
95 | | - return None |
| 124 | + triggers = [] |
| 125 | + for result in results: |
| 126 | + data_text = cast(str, result[0]) |
| 127 | + trigger = UiPathResumeTrigger.model_validate_json(data_text) |
| 128 | + triggers.append(trigger) |
96 | 129 |
|
97 | | - trigger_type, key, name, folder_path, folder_key, payload = cast( |
98 | | - tuple[str, str, str, str, str, str], tuple(result) |
| 130 | + return triggers |
| 131 | + |
| 132 | + async def delete_trigger( |
| 133 | + self, runtime_id: str, trigger: UiPathResumeTrigger |
| 134 | + ) -> None: |
| 135 | + """Delete resume trigger from storage.""" |
| 136 | + await self._ensure_table() |
| 137 | + |
| 138 | + async with self.memory.lock, self.memory.conn.cursor() as cur: |
| 139 | + await cur.execute( |
| 140 | + f""" |
| 141 | + DELETE FROM {self.rs_table_name} |
| 142 | + WHERE runtime_id = ? AND interrupt_id = ? |
| 143 | + """, |
| 144 | + ( |
| 145 | + runtime_id, |
| 146 | + trigger.interrupt_id, |
| 147 | + ), |
99 | 148 | ) |
| 149 | + await self.memory.conn.commit() |
| 150 | + |
| 151 | + async def set_value( |
| 152 | + self, |
| 153 | + runtime_id: str, |
| 154 | + namespace: str, |
| 155 | + key: str, |
| 156 | + value: Any, |
| 157 | + ) -> None: |
| 158 | + """Save arbitrary key-value pair to database.""" |
| 159 | + if not ( |
| 160 | + isinstance(value, str) |
| 161 | + or isinstance(value, dict) |
| 162 | + or isinstance(value, BaseModel) |
| 163 | + or value is None |
| 164 | + ): |
| 165 | + raise TypeError("Value must be str, dict, BaseModel or None.") |
| 166 | + |
| 167 | + await self._ensure_table() |
100 | 168 |
|
101 | | - resume_trigger = UiPathResumeTrigger( |
102 | | - trigger_type=UiPathResumeTriggerType(trigger_type), |
103 | | - trigger_name=UiPathResumeTriggerName(name), |
104 | | - item_key=key, |
105 | | - folder_path=folder_path, |
106 | | - folder_key=folder_key, |
107 | | - payload=payload, |
| 169 | + value_text = self._dump_value(value) |
| 170 | + |
| 171 | + async with self.memory.lock, self.memory.conn.cursor() as cur: |
| 172 | + await cur.execute( |
| 173 | + f""" |
| 174 | + INSERT INTO {self.kv_table_name} (runtime_id, namespace, key, value) |
| 175 | + VALUES (?, ?, ?, ?) |
| 176 | + ON CONFLICT(runtime_id, namespace, key) |
| 177 | + DO UPDATE SET |
| 178 | + value = excluded.value, |
| 179 | + timestamp = (strftime('%Y-%m-%d %H:%M:%S', 'now', 'utc')) |
| 180 | + """, |
| 181 | + (runtime_id, namespace, key, value_text), |
108 | 182 | ) |
| 183 | + await self.memory.conn.commit() |
109 | 184 |
|
110 | | - if resume_trigger.trigger_type == UiPathResumeTriggerType.API: |
111 | | - resume_trigger.api_resume = UiPathApiTrigger( |
112 | | - inbox_id=resume_trigger.item_key, request=resume_trigger.payload |
113 | | - ) |
| 185 | + async def get_value(self, runtime_id: str, namespace: str, key: str) -> Any: |
| 186 | + """Get arbitrary key-value pair from database (scoped by runtime_id + namespace).""" |
| 187 | + await self._ensure_table() |
114 | 188 |
|
115 | | - return resume_trigger |
| 189 | + async with self.memory.lock, self.memory.conn.cursor() as cur: |
| 190 | + await cur.execute( |
| 191 | + f""" |
| 192 | + SELECT value |
| 193 | + FROM {self.kv_table_name} |
| 194 | + WHERE runtime_id = ? AND namespace = ? AND key = ? |
| 195 | + LIMIT 1 |
| 196 | + """, |
| 197 | + (runtime_id, namespace, key), |
| 198 | + ) |
| 199 | + row = await cur.fetchone() |
| 200 | + |
| 201 | + if not row: |
| 202 | + return None |
| 203 | + |
| 204 | + return self._load_value(cast(str | None, row[0])) |
| 205 | + |
| 206 | + def _dump_value(self, value: str | dict[str, Any] | BaseModel | None) -> str | None: |
| 207 | + if value is None: |
| 208 | + return None |
| 209 | + if isinstance(value, BaseModel): |
| 210 | + return "j:" + json.dumps(value.model_dump()) |
| 211 | + if isinstance(value, dict): |
| 212 | + return "j:" + json.dumps(value) |
| 213 | + return "s:" + value |
| 214 | + |
| 215 | + def _load_value(self, raw: str | None) -> Any: |
| 216 | + if raw is None: |
| 217 | + return None |
| 218 | + if raw.startswith("s:"): |
| 219 | + return raw[2:] |
| 220 | + if raw.startswith("j:"): |
| 221 | + return json.loads(raw[2:]) |
| 222 | + return raw |
0 commit comments