Skip to content

Commit 384ac13

Browse files
feat: add storage and tests
1 parent 800ebc4 commit 384ac13

5 files changed

Lines changed: 2179 additions & 1333 deletions

File tree

packages/uipath-llamaindex/src/uipath_llamaindex/runtime/chat/messages.py

Lines changed: 121 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import logging
23
from datetime import datetime, timezone
34
from typing import Any
@@ -22,23 +23,36 @@
2223
UiPathConversationToolCallStartEvent,
2324
)
2425

26+
from uipath_llamaindex.runtime.storage import SqliteResumableStorage
27+
2528
logger = logging.getLogger(__name__)
2629

30+
STORAGE_NAMESPACE_EVENT_MAPPER = "chat-event-mapper"
31+
STORAGE_KEY_TOOL_ID_TO_MESSAGE_ID_MAP = "tool_id_map"
32+
2733

2834
class UiPathChatMessagesMapper:
2935
"""Stateful mapper that converts LlamaIndex agent events to UiPath message events.
3036
3137
Maintains state across events to properly track:
3238
- The current AI message ID (generated per agent turn, since LlamaIndex doesn't provide one)
3339
- Pending tool calls per message ID for correct message_end timing
40+
41+
When a storage backend is provided, the tool_id → message_id mapping is persisted
42+
so that it survives workflow suspension and can be correctly resolved on resume.
3443
"""
3544

36-
def __init__(self, runtime_id: str) -> None:
45+
def __init__(
46+
self,
47+
runtime_id: str,
48+
storage: SqliteResumableStorage | None = None,
49+
) -> None:
3750
self.runtime_id = runtime_id
51+
self.storage = storage
3852
self._current_message_id: str | None = None
39-
# message_id -> set of tool_ids still pending completion
53+
self._storage_lock = asyncio.Lock()
54+
# In-memory fallback state used when no storage is provided
4055
self._pending_tool_calls: dict[str, set[str]] = {}
41-
# tool_id -> message_id for correlating ToolCallResult with its parent AI message
4256
self._tool_id_to_message_id: dict[str, str] = {}
4357

4458
@staticmethod
@@ -96,14 +110,14 @@ async def map_event(
96110
return self._map_agent_stream(event)
97111

98112
if isinstance(event, AgentOutput):
99-
return self._map_agent_output(event)
113+
return await self._map_agent_output(event)
100114

101115
# ToolCall start is handled via AgentOutput to have the message_id available
102116
if isinstance(event, ToolCall):
103117
return None
104118

105119
if isinstance(event, ToolCallResult):
106-
return self._map_tool_call_result(event)
120+
return await self._map_tool_call_result(event)
107121

108122
return None
109123

@@ -124,7 +138,7 @@ def _map_agent_stream(
124138

125139
return events if events else None
126140

127-
def _map_agent_output(
141+
async def _map_agent_output(
128142
self, event: AgentOutput
129143
) -> list[UiPathConversationMessageEvent] | None:
130144
message_id = self._current_message_id
@@ -137,31 +151,58 @@ def _map_agent_output(
137151
events: list[UiPathConversationMessageEvent] = []
138152

139153
if event.tool_calls:
140-
# Emit a tool_call_start event for each tool call and track them as pending
141-
pending: set[str] = set()
142-
for tool_call in event.tool_calls:
143-
self._tool_id_to_message_id[tool_call.tool_id] = message_id
144-
pending.add(tool_call.tool_id)
145-
events.append(
146-
self._create_tool_call_start_event(
147-
message_id=message_id,
148-
tool_call_id=tool_call.tool_id,
149-
tool_name=tool_call.tool_name,
150-
input=tool_call.tool_kwargs,
154+
if self.storage is not None:
155+
async with self._storage_lock:
156+
existing: dict[str, str] | None = await self.storage.get_value(
157+
self.runtime_id,
158+
STORAGE_NAMESPACE_EVENT_MAPPER,
159+
STORAGE_KEY_TOOL_ID_TO_MESSAGE_ID_MAP,
151160
)
152-
)
153-
self._pending_tool_calls[message_id] = pending
161+
tool_id_to_message_id: dict[str, str] = existing or {}
162+
163+
for tool_call in event.tool_calls:
164+
tool_id_to_message_id[tool_call.tool_id] = message_id
165+
events.append(
166+
self._create_tool_call_start_event(
167+
message_id=message_id,
168+
tool_call_id=tool_call.tool_id,
169+
tool_name=tool_call.tool_name,
170+
input=tool_call.tool_kwargs,
171+
)
172+
)
173+
174+
await self.storage.set_value(
175+
self.runtime_id,
176+
STORAGE_NAMESPACE_EVENT_MAPPER,
177+
STORAGE_KEY_TOOL_ID_TO_MESSAGE_ID_MAP,
178+
tool_id_to_message_id,
179+
)
180+
else:
181+
# In-memory fallback (no suspend/resume support)
182+
pending: set[str] = set()
183+
for tool_call in event.tool_calls:
184+
self._tool_id_to_message_id[tool_call.tool_id] = message_id
185+
pending.add(tool_call.tool_id)
186+
events.append(
187+
self._create_tool_call_start_event(
188+
message_id=message_id,
189+
tool_call_id=tool_call.tool_id,
190+
tool_name=tool_call.tool_name,
191+
input=tool_call.tool_kwargs,
192+
)
193+
)
194+
self._pending_tool_calls[message_id] = pending
154195
# message_end will be emitted once the last ToolCallResult comes in
155196
else:
156197
# No tool calls: this is the final text response, close the message now
157198
events.append(self._create_message_end_event(message_id))
158199

159200
return events if events else None
160201

161-
def _map_tool_call_result(
202+
async def _map_tool_call_result(
162203
self, event: ToolCallResult
163204
) -> list[UiPathConversationMessageEvent] | None:
164-
message_id = self._tool_id_to_message_id.pop(event.tool_id, None)
205+
message_id, is_last = await self._get_message_id_for_tool_call(event.tool_id)
165206
if message_id is None:
166207
logger.warning(
167208
"ToolCallResult received for unknown tool_id '%s' — skipping.",
@@ -180,14 +221,70 @@ def _map_tool_call_result(
180221
]
181222

182223
# Close the message once all tool calls for it have completed
224+
if is_last:
225+
events.append(self._create_message_end_event(message_id))
226+
227+
return events
228+
229+
async def _get_message_id_for_tool_call(
230+
self, tool_id: str
231+
) -> tuple[str | None, bool]:
232+
"""Look up the message_id for a tool_id and remove it from the map.
233+
234+
Returns (message_id, is_last) where is_last is True when no other
235+
pending tool calls remain for the same message.
236+
"""
237+
if self.storage is not None:
238+
async with self._storage_lock:
239+
tool_id_to_message_id: dict[str, str] | None = (
240+
await self.storage.get_value(
241+
self.runtime_id,
242+
STORAGE_NAMESPACE_EVENT_MAPPER,
243+
STORAGE_KEY_TOOL_ID_TO_MESSAGE_ID_MAP,
244+
)
245+
)
246+
247+
if tool_id_to_message_id is None:
248+
logger.error(
249+
"attempt to lookup tool_id %s when no map present in storage",
250+
tool_id,
251+
)
252+
return None, False
253+
254+
message_id = tool_id_to_message_id.get(tool_id)
255+
if message_id is None:
256+
logger.error(
257+
"tool_id to message map does not contain tool_id %s",
258+
tool_id,
259+
)
260+
return None, False
261+
262+
del tool_id_to_message_id[tool_id]
263+
264+
await self.storage.set_value(
265+
self.runtime_id,
266+
STORAGE_NAMESPACE_EVENT_MAPPER,
267+
STORAGE_KEY_TOOL_ID_TO_MESSAGE_ID_MAP,
268+
tool_id_to_message_id,
269+
)
270+
271+
is_last = message_id not in tool_id_to_message_id.values()
272+
273+
return message_id, is_last
274+
275+
# In-memory fallback
276+
message_id = self._tool_id_to_message_id.pop(tool_id, None)
277+
if message_id is None:
278+
return None, False
279+
183280
pending = self._pending_tool_calls.get(message_id)
184281
if pending is not None:
185-
pending.discard(event.tool_id)
282+
pending.discard(tool_id)
186283
if not pending:
187284
del self._pending_tool_calls[message_id]
188-
events.append(self._create_message_end_event(message_id))
285+
return message_id, True
189286

190-
return events
287+
return message_id, False
191288

192289
# ── Factory helpers ────────────────────────────────────────────────────────
193290

packages/uipath-llamaindex/src/uipath_llamaindex/runtime/runtime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ async def _run_workflow(
169169

170170
event_stream = handler.stream_events(expose_internal=True)
171171
suspended_event: InputRequiredEvent | None = None
172-
chat = UiPathChatMessagesMapper(runtime_id=self.runtime_id)
172+
chat = UiPathChatMessagesMapper(runtime_id=self.runtime_id, storage=self.storage)
173173

174174
is_resumed: bool = False
175175
async for event in event_stream:

0 commit comments

Comments
 (0)