Skip to content

Commit a702176

Browse files
committed
fix: resume trigger
1 parent 05be23d commit a702176

3 files changed

Lines changed: 50 additions & 30 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "uipath-llamaindex"
3-
version = "0.0.17"
3+
version = "0.0.18"
44
description = "UiPath LlamaIndex SDK"
55
readme = { file = "README.md", content-type = "text/markdown" }
66
requires-python = ">=3.10"
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Optional
22

3-
from llama_index.core.workflow import Workflow
4-
from uipath._cli._runtime._contracts import UiPathRuntimeContext
3+
from llama_index.core.workflow import Context, Workflow
4+
from uipath._cli._runtime._contracts import UiPathResumeTrigger, UiPathRuntimeContext
55

66
from .._utils._config import LlamaIndexConfig
77

@@ -11,3 +11,5 @@ class UiPathLlamaIndexRuntimeContext(UiPathRuntimeContext):
1111

1212
config: Optional[LlamaIndexConfig] = None
1313
workflow: Optional[Workflow] = None
14+
workflow_context: Optional[Context] = None
15+
resume_trigger: Optional[UiPathResumeTrigger] = None

src/uipath_llamaindex/_cli/_runtime/_runtime.py

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -70,21 +70,31 @@ async def execute(self) -> Optional[UiPathRuntimeResult]:
7070
start_event_class = self.context.workflow._start_event_class
7171
ev = start_event_class(**self.context.input_json)
7272

73-
ctx: Context = await self._get_context()
73+
await self.load_context()
7474

7575
handler: WorkflowHandler = self.context.workflow.run(
76-
start_event=ev, ctx=ctx, **self.context.input_json
76+
start_event=ev,
77+
ctx=self.context.workflow_context,
78+
**self.context.input_json,
7779
)
7880

7981
resume_trigger: UiPathResumeTrigger = None
8082

83+
resume_applied = False
8184
async for event in handler.stream_events():
8285
if isinstance(event, InputRequiredEvent):
83-
resume_trigger = UiPathResumeTrigger(
84-
api_resume=UiPathApiTrigger(
85-
inbox_id=str(uuid.uuid4()), request=event.prefix
86+
if self.context.resume and not resume_applied:
87+
# If we are resuming, we need to apply the resume trigger to the event stream.
88+
resume_applied = True
89+
self.context.workflow_context.send_event(
90+
await self.get_resume_event()
91+
)
92+
else:
93+
resume_trigger = UiPathResumeTrigger(
94+
api_resume=UiPathApiTrigger(
95+
inbox_id=str(uuid.uuid4()), request=event.prefix
96+
)
8697
)
87-
)
8898
break
8999
print(event)
90100

@@ -102,7 +112,7 @@ async def execute(self) -> Optional[UiPathRuntimeResult]:
102112

103113
if self.state_file_path:
104114
serializer = JsonPickleSerializer()
105-
ctx_dict = ctx.to_dict(serializer=serializer)
115+
ctx_dict = self.context.workflow_context.to_dict(serializer=serializer)
106116
ctx_dict["uipath_resume_trigger"] = (
107117
serializer.serialize(resume_trigger) if resume_trigger else None
108118
)
@@ -212,45 +222,53 @@ async def cleanup(self) -> None:
212222
"""Clean up all resources."""
213223
pass
214224

215-
async def _get_context(self) -> Context:
225+
async def load_context(self):
216226
"""
217-
Get the context for the LlamaIndex agent.
218-
219-
Returns:
220-
The context object for the LlamaIndex agent.
227+
Load the context for the LlamaIndex agent.
221228
"""
222229
logger.debug(f"Resumed: {self.context.resume} Input: {self.context.input_json}")
223230

224231
if not self.context.resume:
225-
return Context(self.context.workflow)
232+
self.context.workflow_context = Context(self.context.workflow)
233+
return
226234

227235
if not self.state_file_path or not os.path.exists(self.state_file_path):
228-
return Context(self.context.workflow)
236+
self.context.workflow_context = Context(self.context.workflow)
237+
return
229238

230239
serializer = JsonPickleSerializer()
231-
ctx: Context = None
232240

233241
with open(self.state_file_path, "rb") as f:
234242
loaded_ctx_dict = pickle.load(f)
235-
ctx = Context.from_dict(
243+
self.context.workflow_context = Context.from_dict(
236244
self.context.workflow,
237245
loaded_ctx_dict,
238246
serializer=serializer,
239247
)
240248

241-
if self.context.input_json:
242-
ctx.send_event(HumanResponseEvent(response=self.context.input_json))
249+
resumed_trigger_data = loaded_ctx_dict["uipath_resume_trigger"]
250+
if resumed_trigger_data:
251+
self.context.resume_trigger = cast(
252+
UiPathResumeTrigger, serializer.deserialize(resumed_trigger_data)
253+
)
243254

244-
resumed_trigger_data = loaded_ctx_dict["uipath_resume_trigger"]
245-
if resumed_trigger_data:
246-
resumed_trigger = cast(
247-
UiPathResumeTrigger, serializer.deserialize(resumed_trigger_data)
248-
)
249-
inbox_id = resumed_trigger.api_resume.inbox_id
250-
payload = await self._get_api_payload(inbox_id)
251-
ctx.send_event(HumanResponseEvent(response=payload))
255+
async def get_resume_event(self) -> Optional[HumanResponseEvent]:
256+
"""
257+
Get the resume event for the LlamaIndex agent.
252258
253-
return ctx
259+
Returns:
260+
The resume event if available, otherwise None.
261+
"""
262+
if self.context.input_json:
263+
# If input_json is provided, use it to create a HumanResponseEvent
264+
return HumanResponseEvent(response=self.context.input_json)
265+
# If resume_trigger is set, fetch the payload from the API
266+
if self.context.resume_trigger:
267+
inbox_id = self.context.resume_trigger.api_resume.inbox_id
268+
payload = await self._get_api_payload(inbox_id)
269+
if payload:
270+
return HumanResponseEvent(response=payload)
271+
return None
254272

255273
async def _get_api_payload(self, inbox_id: str) -> Any:
256274
"""

0 commit comments

Comments
 (0)