Skip to content

Commit 4b9d70d

Browse files
fix messages when resuming
1 parent 9a0abeb commit 4b9d70d

3 files changed

Lines changed: 95 additions & 5 deletions

File tree

src/uipath_langchain/agent/react/init_node.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,16 @@ def graph_state_init(state: Any) -> Any:
2525
resolved_messages = list(messages)
2626
if is_conversational:
2727
# For conversational agents we need to reorder the messages so that the system message is first, followed by
28-
# the initial user message. The initial user message is put in the state by UiPathLangGraphRuntime. The add
29-
# reducer is used for the messages property in the state, so by default new messages are appended to the end.
30-
resolved_messages = Overwrite([*resolved_messages, *state.messages])
28+
# the initial user message. When resuming the conversation, the state will have the entire message history,
29+
# including the system message. In this case, we need to replace the system message from the state with the
30+
# newly generated one. It will have the current date/time and reflect any changes to user settings. The add
31+
# reducer is used for the messages property in the state, so by default new messages are appended to the end
32+
# and using Overwrite will cause LangGraph to replace the entire array instead.
33+
if len(state.messages) > 0 and isinstance(state.messages[0], SystemMessage):
34+
preserved_messages = state.messages[1:]
35+
else:
36+
preserved_messages = state.messages
37+
resolved_messages = Overwrite([*resolved_messages, *preserved_messages])
3138

3239
schema = input_schema if input_schema is not None else BaseModel
3340
job_attachments = get_job_attachments(schema, state)

src/uipath_langchain/agent/react/router_conversational.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def create_route_agent_conversational():
2323

2424
def route_agent_conversational(
2525
state: AgentGraphState,
26-
) -> list[str] | Literal[AgentGraphNode.USER_MESSAGE_WAIT]:
26+
) -> list[str] | Literal[AgentGraphNode.TERMINATE]:
2727
"""Route after agent
2828
2929
Routing logic:
@@ -41,6 +41,6 @@ def route_agent_conversational(
4141
if last_message.tool_calls:
4242
return [tc["name"] for tc in last_message.tool_calls]
4343
else:
44-
return AgentGraphNode.USER_MESSAGE_WAIT
44+
return AgentGraphNode.TERMINATE
4545

4646
return route_agent_conversational
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
"""Termination node for the Agent graph."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Any, NoReturn
6+
7+
from langchain_core.messages import AIMessage
8+
from pydantic import BaseModel
9+
from uipath.agent.react import END_EXECUTION_TOOL, RAISE_ERROR_TOOL
10+
from uipath.runtime.errors import UiPathErrorCode
11+
12+
from ..exceptions import (
13+
AgentNodeRoutingException,
14+
AgentTerminationException,
15+
)
16+
from .types import AgentGraphState, AgentTermination
17+
18+
19+
def _handle_end_execution(
20+
args: dict[str, Any], response_schema: type[BaseModel] | None
21+
) -> dict[str, Any]:
22+
"""Handle LLM-initiated termination via END_EXECUTION_TOOL."""
23+
output_schema = response_schema or END_EXECUTION_TOOL.args_schema
24+
validated = output_schema.model_validate(args)
25+
return validated.model_dump()
26+
27+
28+
def _handle_raise_error(args: dict[str, Any]) -> NoReturn:
    """Handle LLM-initiated error via RAISE_ERROR_TOOL.

    Always raises: converts the tool-call arguments into an
    AgentTerminationException so the error propagates to the caller.
    """
    raise AgentTerminationException(
        code=UiPathErrorCode.EXECUTION_ERROR,
        title=args.get("message", "The LLM did not set the error message"),
        detail=args.get("details", ""),
    )
37+
38+
39+
def _handle_agent_termination(termination: AgentTermination) -> NoReturn:
    """Handle Command-based termination.

    Always raises: surfaces the state-carried termination info (e.g. from an
    escalation) as an AgentTerminationException with an execution-error code.
    """
    failure = AgentTerminationException(
        code=UiPathErrorCode.EXECUTION_ERROR,
        title=termination.title,
        detail=termination.detail,
    )
    raise failure
46+
47+
48+
def create_terminate_node(
    response_schema: type[BaseModel] | None = None, is_conversational: bool = False
):
    """Handles Agent Graph termination for multiple sources and output or error propagation to Orchestrator.

    Termination scenarios:
    1. Command based termination with information in state (e.g: escalation)
    2. LLM-initiated termination (END_EXECUTION_TOOL)
    3. LLM-initiated error (RAISE_ERROR_TOOL)

    Args:
        response_schema: Optional schema used to validate and shape the final
            output; END_EXECUTION_TOOL's args schema is used when omitted.
        is_conversational: Conversational agents reach this node without a
            control-flow tool call, so tool-call inspection is skipped and the
            node produces no output payload.

    Returns:
        The ``terminate_node`` graph-node callable.
    """

    def terminate_node(state: AgentGraphState):
        # Command-based termination (e.g. escalation) takes precedence; raises.
        if state.inner_state.termination:
            _handle_agent_termination(state.inner_state.termination)

        if not is_conversational:
            # Guard: reaching terminate with no messages is an invalid graph
            # state — fail with a routing error instead of an opaque IndexError.
            if not state.messages:
                raise AgentNodeRoutingException(
                    "Terminate node reached with an empty message history. Unexpected state."
                )

            last_message = state.messages[-1]
            if not isinstance(last_message, AIMessage):
                raise AgentNodeRoutingException(
                    f"Expected last message to be AIMessage, got {type(last_message).__name__}"
                )

            for tool_call in last_message.tool_calls:
                tool_name = tool_call["name"]

                # LLM signalled normal completion: validate and return output.
                if tool_name == END_EXECUTION_TOOL.name:
                    return _handle_end_execution(tool_call["args"], response_schema)

                # LLM signalled an error: convert to a termination exception.
                if tool_name == RAISE_ERROR_TOOL.name:
                    _handle_raise_error(tool_call["args"])

            raise AgentNodeRoutingException(
                "No control flow tool call found in terminate node. Unexpected state."
            )
        # Conversational turn ends here with no output payload (implicit None).

    return terminate_node

0 commit comments

Comments
 (0)