Skip to content

Commit 9353d2b

Browse files
committed
fix: Interrupt for low coded CAS agents [JAR-9208]
1 parent ad49d48 commit 9353d2b

File tree

10 files changed

+675
-30
lines changed

10 files changed

+675
-30
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "uipath-langchain"
3-
version = "0.8.22"
3+
version = "0.8.23"
44
description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform"
55
readme = { file = "README.md", content-type = "text/markdown" }
66
requires-python = ">=3.11"

src/uipath_langchain/agent/tools/tool_factory.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
LowCodeAgentDefinition,
1717
)
1818

19+
from uipath_langchain.chat.hitl import REQUIRE_CONVERSATIONAL_CONFIRMATION
20+
1921
from .context_tool import create_context_tool
2022
from .escalation_tool import create_escalation_tool
2123
from .extraction_tool import create_ixp_extraction_tool
@@ -54,6 +56,15 @@ async def create_tools_from_resources(
5456
else:
5557
tools.append(tool)
5658

59+
if agent.is_conversational:
60+
props = getattr(resource, "properties", None)
61+
if props and getattr(
62+
props, REQUIRE_CONVERSATIONAL_CONFIRMATION, False
63+
):
64+
if tool.metadata is None:
65+
tool.metadata = {}
66+
tool.metadata[REQUIRE_CONVERSATIONAL_CONFIRMATION] = True
67+
5768
return tools
5869

5970

src/uipath_langchain/agent/tools/tool_node.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
extract_current_tool_call_index,
2323
find_latest_ai_message,
2424
)
25+
from uipath_langchain.chat.hitl import request_tool_confirmation
2526

2627
# the type safety can be improved with generics
2728
ToolWrapperReturnType = dict[str, Any] | Command[Any] | None
@@ -80,6 +81,13 @@ def _func(self, state: AgentGraphState) -> OutputType:
8081
if call is None:
8182
return None
8283

84+
# prompt user for approval if tool requires confirmation
85+
confirmation = request_tool_confirmation(call, self.tool)
86+
87+
# user rejected the tool call
88+
if confirmation is not None and confirmation.cancelled:
89+
return self._process_result(call, confirmation.cancelled)
90+
8391
try:
8492
if self.wrapper:
8593
inputs = self._prepare_wrapper_inputs(
@@ -88,7 +96,11 @@ def _func(self, state: AgentGraphState) -> OutputType:
8896
result = self.wrapper(*inputs)
8997
else:
9098
result = self.tool.invoke(call)
91-
return self._process_result(call, result)
99+
output = self._process_result(call, result)
100+
# HITL approved - apply confirmation metadata to tool result message
101+
if confirmation is not None:
102+
confirmation.annotate_result(output)
103+
return output
92104
except GraphBubbleUp:
93105
# LangGraph uses exceptions for interrupt control flow — re-raise so
94106
# handle_tool_errors doesn't swallow expected interrupts as errors.
@@ -104,6 +116,13 @@ async def _afunc(self, state: AgentGraphState) -> OutputType:
104116
if call is None:
105117
return None
106118

119+
# prompt user for approval if tool requires confirmation
120+
confirmation = request_tool_confirmation(call, self.tool)
121+
122+
# user rejected the tool call
123+
if confirmation is not None and confirmation.cancelled:
124+
return self._process_result(call, confirmation.cancelled)
125+
107126
try:
108127
if self.awrapper:
109128
inputs = self._prepare_wrapper_inputs(
@@ -112,7 +131,11 @@ async def _afunc(self, state: AgentGraphState) -> OutputType:
112131
result = await self.awrapper(*inputs)
113132
else:
114133
result = await self.tool.ainvoke(call)
115-
return self._process_result(call, result)
134+
output = self._process_result(call, result)
135+
# HITL approved - apply confirmation metadata to tool result message
136+
if confirmation is not None:
137+
confirmation.annotate_result(output)
138+
return output
116139
except GraphBubbleUp:
117140
# LangGraph uses exceptions for interrupt control flow — re-raise so
118141
# handle_tool_errors doesn't swallow expected interrupts as errors.

src/uipath_langchain/chat/hitl.py

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,57 @@
11
import functools
22
import inspect
3+
import json
34
from inspect import Parameter
4-
from typing import Annotated, Any, Callable
5+
from typing import Annotated, Any, Callable, NamedTuple
56

7+
from langchain_core.messages.tool import ToolCall, ToolMessage
68
from langchain_core.tools import BaseTool, InjectedToolCallId
79
from langchain_core.tools import tool as langchain_tool
810
from langgraph.types import interrupt
911
from uipath.core.chat import (
1012
UiPathConversationToolCallConfirmationValue,
1113
)
1214

13-
_CANCELLED_MESSAGE = "Cancelled by user"
15+
CANCELLED_MESSAGE = "Cancelled by user"
16+
17+
CONVERSATIONAL_APPROVED_TOOL_ARGS = "conversational_approved_tool_args"
18+
REQUIRE_CONVERSATIONAL_CONFIRMATION = "require_conversational_confirmation"
19+
20+
21+
class ConfirmationResult(NamedTuple):
22+
"""Result of a tool confirmation check."""
23+
24+
cancelled: ToolMessage | None # ToolMessage if cancelled, None if approved
25+
args_modified: bool
26+
approved_args: dict[str, Any] | None = None
27+
28+
def annotate_result(self, output: dict[str, Any] | Any) -> None:
29+
"""Apply confirmation metadata to a tool result message."""
30+
msg = None
31+
if isinstance(output, dict):
32+
messages = output.get("messages")
33+
if messages:
34+
msg = messages[0]
35+
if msg is None:
36+
return
37+
if self.approved_args is not None:
38+
msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] = (
39+
self.approved_args
40+
)
41+
if self.args_modified:
42+
try:
43+
result_value = json.loads(msg.content)
44+
except (json.JSONDecodeError, TypeError):
45+
result_value = msg.content
46+
msg.content = json.dumps(
47+
{
48+
"meta": {
49+
"args_modified_by_user": True,
50+
"executed_args": self.approved_args,
51+
},
52+
"result": result_value,
53+
}
54+
)
1455

1556

1657
def _patch_span_input(approved_args: dict[str, Any]) -> None:
@@ -53,7 +94,7 @@ def _patch_span_input(approved_args: dict[str, Any]) -> None:
5394
pass
5495

5596

56-
def _request_approval(
97+
def request_approval(
5798
tool_args: dict[str, Any],
5899
tool: BaseTool,
59100
) -> dict[str, Any] | None:
@@ -89,7 +130,41 @@ def _request_approval(
89130
if not confirmation.get("approved", True):
90131
return None
91132

92-
return confirmation.get("input") or tool_args
133+
return (
134+
confirmation.get("input")
135+
if confirmation.get("input") is not None
136+
else tool_args
137+
)
138+
139+
140+
def request_tool_confirmation(
141+
call: ToolCall, tool: BaseTool
142+
) -> ConfirmationResult | None:
143+
"""Check whether a tool requires user confirmation and request approval"""
144+
if not (tool.metadata and tool.metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION)):
145+
return None
146+
147+
original_args = call["args"]
148+
approved_args = request_approval(
149+
{**original_args, "tool_call_id": call["id"]}, tool
150+
)
151+
if approved_args is None:
152+
cancelled_msg = ToolMessage(
153+
content=CANCELLED_MESSAGE,
154+
name=call["name"],
155+
tool_call_id=call["id"],
156+
)
157+
cancelled_msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] = (
158+
original_args
159+
)
160+
return ConfirmationResult(cancelled=cancelled_msg, args_modified=False)
161+
# Mutate call args so the tool executes with the approved values
162+
call["args"] = approved_args
163+
return ConfirmationResult(
164+
cancelled=None,
165+
args_modified=approved_args != original_args,
166+
approved_args=approved_args,
167+
)
93168

94169

95170
def requires_approval(
@@ -107,9 +182,9 @@ def decorator(fn: Callable[..., Any]) -> BaseTool:
107182
# wrap the tool/function
108183
@functools.wraps(fn)
109184
def wrapper(**tool_args: Any) -> Any:
110-
approved_args = _request_approval(tool_args, _created_tool[0])
185+
approved_args = request_approval(tool_args, _created_tool[0])
111186
if approved_args is None:
112-
return _CANCELLED_MESSAGE
187+
return {"meta": CANCELLED_MESSAGE}
113188
_patch_span_input(approved_args)
114189
return fn(**approved_args)
115190

src/uipath_langchain/runtime/messages.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
)
4040
from uipath.runtime import UiPathRuntimeStorageProtocol
4141

42+
from uipath_langchain.chat.hitl import CONVERSATIONAL_APPROVED_TOOL_ARGS
43+
4244
from ._citations import CitationStreamProcessor, extract_citations_from_text
4345

4446
logger = logging.getLogger(__name__)
@@ -58,6 +60,7 @@ def __init__(self, runtime_id: str, storage: UiPathRuntimeStorageProtocol | None
5860
"""Initialize the mapper with empty state."""
5961
self.runtime_id = runtime_id
6062
self.storage = storage
63+
self.tool_names_requiring_confirmation: set[str] = set()
6164
self.current_message: AIMessageChunk
6265
self.seen_message_ids: set[str] = set()
6366
self._storage_lock = asyncio.Lock()
@@ -389,11 +392,17 @@ async def map_current_message_to_start_tool_call_events(self):
389392
tool_call_id_to_message_id_map[tool_call_id] = (
390393
self.current_message.id
391394
)
392-
events.append(
393-
self.map_tool_call_to_tool_call_start_event(
394-
self.current_message.id, tool_call
395+
396+
# if tool requires confirmation, we skip start tool call
397+
if (
398+
tool_call["name"]
399+
not in self.tool_names_requiring_confirmation
400+
):
401+
events.append(
402+
self.map_tool_call_to_tool_call_start_event(
403+
self.current_message.id, tool_call
404+
)
395405
)
396-
)
397406

398407
if self.storage is not None:
399408
await self.storage.set_value(
@@ -426,7 +435,19 @@ async def map_tool_message_to_events(
426435
# Keep as string if not valid JSON
427436
pass
428437

429-
events = [
438+
events: list[UiPathConversationMessageEvent] = []
439+
440+
# emit startToolCall for tools requiring confirmation after it's approved
441+
approved_args = message.response_metadata.get(CONVERSATIONAL_APPROVED_TOOL_ARGS)
442+
if approved_args is not None:
443+
tool_call = ToolCall(
444+
name=message.name or "", args=approved_args, id=message.tool_call_id
445+
)
446+
events.append(
447+
self.map_tool_call_to_tool_call_start_event(message_id, tool_call)
448+
)
449+
450+
events.append(
430451
UiPathConversationMessageEvent(
431452
message_id=message_id,
432453
tool_call=UiPathConversationToolCallEvent(
@@ -438,7 +459,7 @@ async def map_tool_message_to_events(
438459
),
439460
),
440461
)
441-
]
462+
)
442463

443464
if is_last_tool_call:
444465
events.append(self.map_to_message_end_event(message_id))
@@ -665,7 +686,7 @@ def _map_langchain_ai_message_to_uipath_message_data(
665686
role="assistant",
666687
content_parts=content_parts,
667688
tool_calls=uipath_tool_calls,
668-
interrupts=[], # TODO: Interrupts
689+
interrupts=[],
669690
)
670691

671692

src/uipath_langchain/runtime/runtime.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
)
3030
from uipath.runtime.schema import UiPathRuntimeSchema
3131

32+
from uipath_langchain.chat.hitl import REQUIRE_CONVERSATIONAL_CONFIRMATION
3233
from uipath_langchain.runtime.errors import LangGraphErrorCode, LangGraphRuntimeError
3334
from uipath_langchain.runtime.messages import UiPathChatMessagesMapper
3435
from uipath_langchain.runtime.schema import get_entrypoints_schema, get_graph_schema
@@ -64,6 +65,9 @@ def __init__(
6465
self.entrypoint: str | None = entrypoint
6566
self.callbacks: list[BaseCallbackHandler] = callbacks or []
6667
self.chat = UiPathChatMessagesMapper(self.runtime_id, storage)
68+
self.chat.tool_names_requiring_confirmation = (
69+
self._get_tool_names_requiring_confirmation()
70+
)
6771
self._middleware_node_names: set[str] = self._detect_middleware_nodes()
6872

6973
async def execute(
@@ -486,6 +490,18 @@ def _detect_middleware_nodes(self) -> set[str]:
486490

487491
return middleware_nodes
488492

493+
def _get_tool_names_requiring_confirmation(self) -> set[str]:
494+
names: set[str] = set()
495+
for node_name, node_spec in self.graph.nodes.items():
496+
# langgraph's processing node.bound -> runnable.tool -> baseTool (if tool node)
497+
tool = getattr(getattr(node_spec, "bound", None), "tool", None)
498+
if tool is None:
499+
continue
500+
metadata = getattr(tool, "metadata", None) or {}
501+
if metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION):
502+
names.add(getattr(tool, "name", node_name))
503+
return names
504+
489505
def _is_middleware_node(self, node_name: str) -> bool:
490506
"""Check if a node name represents a middleware node."""
491507
return node_name in self._middleware_node_names

0 commit comments

Comments
 (0)