-
Notifications
You must be signed in to change notification settings - Fork 34
Expand file tree
/
Copy pathinit_node.py
More file actions
104 lines (92 loc) · 4.44 KB
/
Copy pathinit_node.py
File metadata and controls
104 lines (92 loc) · 4.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""State initialization node for the ReAct Agent graph."""
from typing import Any, Callable, Sequence
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.types import Overwrite
from pydantic import BaseModel
from uipath_langchain.agent.tools.client_side_tool import (
UIPATH_CLIENT_SIDE_TOOLS_INPUT_KEY,
ClientSideToolInfo,
apply_tool_filter,
available_client_side_tools,
)
from .job_attachments import (
get_job_attachments,
parse_attachments_from_conversation_messages,
)
def create_init_node(
    messages: Sequence[SystemMessage | HumanMessage]
    | Callable[[Any], Sequence[SystemMessage | HumanMessage]],
    input_schema: type[BaseModel] | None,
    is_conversational: bool = False,
    client_side_tools: dict[str, ClientSideToolInfo] | None = None,
):
    """Build the state-initialization node of the ReAct agent graph.

    Args:
        messages: The initial system/user messages, or a callable that
            derives them from the graph state.
        input_schema: Schema passed to ``get_job_attachments``; plain
            ``BaseModel`` is substituted when this is ``None``.
        is_conversational: When true, the messages already held in the state
            are kept (minus a leading system message) and the complete
            message list is emitted wrapped in ``Overwrite``.
        client_side_tools: Client-side tools known to the agent. When
            non-empty, the value found on the state under
            ``UIPATH_CLIENT_SIDE_TOOLS_INPUT_KEY`` is applied as a filter.

    Returns:
        A callable that takes the graph state and returns the state update:
        the ``messages`` to write plus ``inner_state`` bookkeeping
        (``job_attachments`` and ``initial_message_count``).
    """

    def _recalled_memory(state: Any) -> str:
        """Return the memory text left by the MEMORY_RECALL node, or ``""``."""
        inner = getattr(state, "inner_state", None)
        return getattr(inner, "memory_injection", None) or ""

    def _index_job_attachments(state: Any) -> dict[str, Any]:
        """Map stringified attachment ids to attachments, skipping id-less ones."""
        schema = BaseModel if input_schema is None else input_schema
        by_id: dict[str, Any] = {}
        for attachment in get_job_attachments(schema, state):
            if attachment.id is not None:
                by_id[str(attachment.id)] = attachment
        return by_id

    def _filter_client_side_tools(state: Any) -> None:
        """Filter available client-side tools based on exchange input declarations.

        Raises:
            ValueError: If the declared value is present but is not a list.
        """
        if not client_side_tools:
            return
        client_tools_input = getattr(state, UIPATH_CLIENT_SIDE_TOOLS_INPUT_KEY, None)
        if client_tools_input is None:
            # Nothing declared on the input: reset to None instead of filtering.
            available_client_side_tools.set(None)
            return
        if not isinstance(client_tools_input, list):
            raise ValueError(
                f"'{UIPATH_CLIENT_SIDE_TOOLS_INPUT_KEY}' must be a list of tool names, "
                f"got {type(client_tools_input).__name__}."
            )
        apply_tool_filter(client_tools_input, client_side_tools)

    def graph_state_init(state: Any) -> Any:
        """Compute the initial ``messages`` / ``inner_state`` update for ``state``."""
        prompt = list(messages(state) if callable(messages) else messages)

        # Recalled memory, when there is any, is appended to the leading
        # system message by replacing it with a new SystemMessage.
        memory_text = _recalled_memory(state)
        if memory_text and prompt and isinstance(prompt[0], SystemMessage):
            prompt[0] = SystemMessage(content=str(prompt[0].content) + memory_text)

        history: Sequence[Any] = []
        flat_messages: Sequence[Any] = prompt
        messages_update: Any = prompt
        if is_conversational:
            # A resumed conversation already carries its full history in the
            # state, system message included. That stale system message is
            # dropped in favour of the freshly generated one (current
            # date/time, up-to-date user settings), which must come first.
            # The messages field uses an add reducer, so a plain list would be
            # appended to the end; Overwrite makes LangGraph replace the whole
            # array instead.
            prior = state.messages
            has_stale_system = len(prior) > 0 and isinstance(prior[0], SystemMessage)
            history = prior[1:] if has_stale_system else prior
            flat_messages = [*prompt, *history]
            messages_update = Overwrite(flat_messages)

        attachments_by_id = _index_job_attachments(state)
        if is_conversational:
            # Merge attachments found in the preserved conversation messages.
            attachments_by_id.update(
                parse_attachments_from_conversation_messages(history)
            )

        _filter_client_side_tools(state)

        # Initial message count, used for tracking new messages.
        initial_message_count = len(flat_messages)

        return {
            "messages": messages_update,
            "inner_state": {
                "job_attachments": attachments_by_id,
                "initial_message_count": initial_message_count,
            },
        }

    return graph_state_init