Skip to content

Commit 8771b31

Browse files
authored
Merge pull request #372 from UiPath/fix/update_storage
fix: support multiple resume triggers + generic kv store
2 parents 316002d + a82fe5c commit 8771b31

File tree

12 files changed

+773
-164
lines changed

12 files changed

+773
-164
lines changed

pyproject.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
[project]
22
name = "uipath-langchain"
3-
version = "0.1.44"
3+
version = "0.2.0"
44
description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform"
55
readme = { file = "README.md", content-type = "text/markdown" }
66
requires-python = ">=3.11"
77
dependencies = [
8-
"uipath>=2.2.44, <2.3.0",
8+
"uipath>=2.3.0, <2.4.0",
9+
"uipath-runtime>=0.3.2, <0.4.0",
910
"langgraph>=1.0.0, <2.0.0",
10-
"langchain-core>=1.0.0, <2.0.0",
11+
"langchain-core>=1.2.5, <2.0.0",
1112
"aiosqlite==0.21.0",
1213
"langgraph-checkpoint-sqlite>=3.0.0, <4.0.0",
1314
"langchain-openai>=1.0.0, <2.0.0",

src/uipath_langchain/runtime/factory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ async def _create_runtime_instance(
275275
delegate=base_runtime,
276276
storage=storage,
277277
trigger_manager=trigger_manager,
278+
runtime_id=runtime_id,
278279
)
279280

280281
async def new_runtime(

src/uipath_langchain/runtime/runtime.py

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -293,29 +293,7 @@ def _extract_graph_result(self, final_chunk: Any) -> Any:
293293

294294
def _is_interrupted(self, state: StateSnapshot) -> bool:
295295
"""Check if execution was interrupted (static or dynamic)."""
296-
# Check for static interrupts (interrupt_before/after)
297-
if hasattr(state, "next") and state.next:
298-
return True
299-
300-
# Check for dynamic interrupts (interrupt() inside node)
301-
if hasattr(state, "tasks"):
302-
for task in state.tasks:
303-
if hasattr(task, "interrupts") and task.interrupts:
304-
return True
305-
306-
return False
307-
308-
def _get_dynamic_interrupt(self, state: StateSnapshot) -> Interrupt | None:
309-
"""Get the first dynamic interrupt if any."""
310-
if not hasattr(state, "tasks"):
311-
return None
312-
313-
for task in state.tasks:
314-
if hasattr(task, "interrupts") and task.interrupts:
315-
for interrupt in task.interrupts:
316-
if isinstance(interrupt, Interrupt):
317-
return interrupt
318-
return None
296+
return bool(state.next)
319297

320298
async def _create_runtime_result(
321299
self,
@@ -344,13 +322,27 @@ async def _create_suspended_result(
344322
graph_state: StateSnapshot,
345323
) -> UiPathRuntimeResult:
346324
"""Create result for suspended execution."""
347-
# Check if it's a dynamic interrupt
348-
dynamic_interrupt = self._get_dynamic_interrupt(graph_state)
349-
350-
if dynamic_interrupt:
351-
# Dynamic interrupt - should create and save resume trigger
325+
interrupt_map: dict[str, Any] = {}
326+
327+
# Get nodes that are still scheduled to run
328+
next_nodes = set(graph_state.next) if graph_state.next else set()
329+
330+
if graph_state.interrupts:
331+
for interrupt in graph_state.interrupts:
332+
if isinstance(interrupt, Interrupt):
333+
# Find which task this interrupt belongs to
334+
for task in graph_state.tasks:
335+
if task.interrupts and interrupt in task.interrupts:
336+
# Only include if this task's node is still in next
337+
if task.name in next_nodes:
338+
interrupt_map[interrupt.id] = interrupt.value
339+
break
340+
341+
# If we have dynamic interrupts, return suspended with interrupt map
342+
# The output is used to create the resume triggers
343+
if interrupt_map:
352344
return UiPathRuntimeResult(
353-
output=dynamic_interrupt.value,
345+
output=interrupt_map,
354346
status=UiPathRuntimeStatus.SUSPENDED,
355347
)
356348
else:
Lines changed: 178 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,115 +1,222 @@
11
"""SQLite implementation of UiPathResumableStorageProtocol."""
22

33
import json
4-
from typing import cast
4+
from typing import Any, cast
55

66
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
77
from pydantic import BaseModel
8-
from uipath.runtime import (
9-
UiPathApiTrigger,
10-
UiPathResumeTrigger,
11-
UiPathResumeTriggerName,
12-
UiPathResumeTriggerType,
13-
)
8+
from uipath.runtime import UiPathResumeTrigger
149

1510

1611
class SqliteResumableStorage:
17-
"""SQLite storage for resume triggers."""
12+
"""SQLite storage for resume triggers and arbitrary kv pairs."""
1813

1914
def __init__(
20-
self, memory: AsyncSqliteSaver, table_name: str = "__uipath_resume_triggers"
15+
self,
16+
memory: AsyncSqliteSaver,
2117
):
2218
self.memory = memory
23-
self.table_name = table_name
19+
self.rs_table_name = "__uipath_resume_triggers"
20+
self.kv_table_name = "__uipath_runtime_kv"
2421
self._initialized = False
2522

2623
async def _ensure_table(self) -> None:
27-
"""Create table if needed."""
24+
"""Create tables if needed."""
2825
if self._initialized:
2926
return
3027

3128
await self.memory.setup()
3229
async with self.memory.lock, self.memory.conn.cursor() as cur:
33-
await cur.execute(f"""
34-
CREATE TABLE IF NOT EXISTS {self.table_name} (
30+
# Enable WAL mode for high concurrency
31+
await cur.execute("PRAGMA journal_mode=WAL")
32+
33+
await cur.execute(
34+
f"""
35+
CREATE TABLE IF NOT EXISTS {self.rs_table_name} (
3536
id INTEGER PRIMARY KEY AUTOINCREMENT,
36-
type TEXT NOT NULL,
37-
name TEXT NOT NULL,
38-
key TEXT,
39-
folder_key TEXT,
40-
folder_path TEXT,
41-
payload TEXT,
37+
runtime_id TEXT NOT NULL,
38+
interrupt_id TEXT NOT NULL,
39+
data TEXT NOT NULL,
4240
timestamp DATETIME DEFAULT (strftime('%Y-%m-%d %H:%M:%S', 'now', 'utc'))
4341
)
44-
""")
45-
await self.memory.conn.commit()
46-
self._initialized = True
42+
"""
43+
)
4744

48-
async def save_trigger(self, trigger: UiPathResumeTrigger) -> None:
49-
"""Save resume trigger to database."""
50-
await self._ensure_table()
45+
await cur.execute(
46+
f"""
47+
CREATE INDEX IF NOT EXISTS idx_{self.rs_table_name}_runtime_id
48+
ON {self.rs_table_name}(runtime_id)
49+
"""
50+
)
5151

52-
trigger_key = (
53-
trigger.api_resume.inbox_id if trigger.api_resume else trigger.item_key
54-
)
55-
payload = trigger.payload
56-
if payload:
57-
payload = (
58-
(
59-
payload.model_dump()
60-
if isinstance(payload, BaseModel)
61-
else json.dumps(payload)
52+
await cur.execute(
53+
f"""
54+
CREATE TABLE IF NOT EXISTS {self.kv_table_name} (
55+
runtime_id TEXT NOT NULL,
56+
namespace TEXT NOT NULL,
57+
key TEXT NOT NULL,
58+
value TEXT,
59+
timestamp DATETIME DEFAULT (strftime('%Y-%m-%d %H:%M:%S', 'now', 'utc')),
60+
PRIMARY KEY (runtime_id, namespace, key)
6261
)
63-
if isinstance(payload, dict)
64-
else str(payload)
62+
"""
6563
)
6664

65+
await self.memory.conn.commit()
66+
67+
self._initialized = True
68+
69+
async def save_triggers(
70+
self, runtime_id: str, triggers: list[UiPathResumeTrigger]
71+
) -> None:
72+
"""Save resume triggers to database, replacing all existing triggers for this runtime_id."""
73+
await self._ensure_table()
74+
6775
async with self.memory.lock, self.memory.conn.cursor() as cur:
76+
# Delete all existing triggers for this runtime_id
6877
await cur.execute(
69-
f"INSERT INTO {self.table_name} (type, key, name, payload, folder_path, folder_key) VALUES (?, ?, ?, ?, ?, ?)",
70-
(
71-
trigger.trigger_type.value,
72-
trigger_key,
73-
trigger.trigger_name.value,
74-
payload,
75-
trigger.folder_path,
76-
trigger.folder_key,
77-
),
78+
f"""
79+
DELETE FROM {self.rs_table_name}
80+
WHERE runtime_id = ?
81+
""",
82+
(runtime_id,),
7883
)
84+
85+
# Insert new triggers
86+
for trigger in triggers:
87+
trigger_data = trigger.model_dump()
88+
trigger_data["payload"] = trigger.payload
89+
trigger_data["trigger_name"] = trigger.trigger_name
90+
91+
await cur.execute(
92+
f"""
93+
INSERT INTO {self.rs_table_name}
94+
(runtime_id, interrupt_id, data)
95+
VALUES (?, ?, ?)
96+
""",
97+
(
98+
runtime_id,
99+
trigger.interrupt_id,
100+
json.dumps(trigger_data),
101+
),
102+
)
79103
await self.memory.conn.commit()
80104

81-
async def get_latest_trigger(self) -> UiPathResumeTrigger | None:
82-
"""Get most recent trigger from database."""
105+
async def get_triggers(self, runtime_id: str) -> list[UiPathResumeTrigger] | None:
106+
"""Get all triggers for runtime_id from database."""
83107
await self._ensure_table()
84108

85109
async with self.memory.lock, self.memory.conn.cursor() as cur:
86-
await cur.execute(f"""
87-
SELECT type, key, name, folder_path, folder_key, payload
88-
FROM {self.table_name}
89-
ORDER BY timestamp DESC
90-
LIMIT 1
91-
""")
92-
result = await cur.fetchone()
110+
await cur.execute(
111+
f"""
112+
SELECT data
113+
FROM {self.rs_table_name}
114+
WHERE runtime_id = ?
115+
ORDER BY timestamp ASC
116+
""",
117+
(runtime_id,),
118+
)
119+
results = await cur.fetchall()
120+
121+
if not results:
122+
return None
93123

94-
if not result:
95-
return None
124+
triggers = []
125+
for result in results:
126+
data_text = cast(str, result[0])
127+
trigger = UiPathResumeTrigger.model_validate_json(data_text)
128+
triggers.append(trigger)
96129

97-
trigger_type, key, name, folder_path, folder_key, payload = cast(
98-
tuple[str, str, str, str, str, str], tuple(result)
130+
return triggers
131+
132+
async def delete_trigger(
133+
self, runtime_id: str, trigger: UiPathResumeTrigger
134+
) -> None:
135+
"""Delete resume trigger from storage."""
136+
await self._ensure_table()
137+
138+
async with self.memory.lock, self.memory.conn.cursor() as cur:
139+
await cur.execute(
140+
f"""
141+
DELETE FROM {self.rs_table_name}
142+
WHERE runtime_id = ? AND interrupt_id = ?
143+
""",
144+
(
145+
runtime_id,
146+
trigger.interrupt_id,
147+
),
99148
)
149+
await self.memory.conn.commit()
150+
151+
async def set_value(
152+
self,
153+
runtime_id: str,
154+
namespace: str,
155+
key: str,
156+
value: Any,
157+
) -> None:
158+
"""Save arbitrary key-value pair to database."""
159+
if not (
160+
isinstance(value, str)
161+
or isinstance(value, dict)
162+
or isinstance(value, BaseModel)
163+
or value is None
164+
):
165+
raise TypeError("Value must be str, dict, BaseModel or None.")
166+
167+
await self._ensure_table()
100168

101-
resume_trigger = UiPathResumeTrigger(
102-
trigger_type=UiPathResumeTriggerType(trigger_type),
103-
trigger_name=UiPathResumeTriggerName(name),
104-
item_key=key,
105-
folder_path=folder_path,
106-
folder_key=folder_key,
107-
payload=payload,
169+
value_text = self._dump_value(value)
170+
171+
async with self.memory.lock, self.memory.conn.cursor() as cur:
172+
await cur.execute(
173+
f"""
174+
INSERT INTO {self.kv_table_name} (runtime_id, namespace, key, value)
175+
VALUES (?, ?, ?, ?)
176+
ON CONFLICT(runtime_id, namespace, key)
177+
DO UPDATE SET
178+
value = excluded.value,
179+
timestamp = (strftime('%Y-%m-%d %H:%M:%S', 'now', 'utc'))
180+
""",
181+
(runtime_id, namespace, key, value_text),
108182
)
183+
await self.memory.conn.commit()
109184

110-
if resume_trigger.trigger_type == UiPathResumeTriggerType.API:
111-
resume_trigger.api_resume = UiPathApiTrigger(
112-
inbox_id=resume_trigger.item_key, request=resume_trigger.payload
113-
)
185+
async def get_value(self, runtime_id: str, namespace: str, key: str) -> Any:
186+
"""Get arbitrary key-value pair from database (scoped by runtime_id + namespace)."""
187+
await self._ensure_table()
114188

115-
return resume_trigger
189+
async with self.memory.lock, self.memory.conn.cursor() as cur:
190+
await cur.execute(
191+
f"""
192+
SELECT value
193+
FROM {self.kv_table_name}
194+
WHERE runtime_id = ? AND namespace = ? AND key = ?
195+
LIMIT 1
196+
""",
197+
(runtime_id, namespace, key),
198+
)
199+
row = await cur.fetchone()
200+
201+
if not row:
202+
return None
203+
204+
return self._load_value(cast(str | None, row[0]))
205+
206+
def _dump_value(self, value: str | dict[str, Any] | BaseModel | None) -> str | None:
207+
if value is None:
208+
return None
209+
if isinstance(value, BaseModel):
210+
return "j:" + json.dumps(value.model_dump())
211+
if isinstance(value, dict):
212+
return "j:" + json.dumps(value)
213+
return "s:" + value
214+
215+
def _load_value(self, raw: str | None) -> Any:
216+
if raw is None:
217+
return None
218+
if raw.startswith("s:"):
219+
return raw[2:]
220+
if raw.startswith("j:"):
221+
return json.loads(raw[2:])
222+
return raw

0 commit comments

Comments
 (0)