Skip to content

Commit ebbf522

Browse files
committed
feat: add HITL confirmation tool deferral
1 parent 2126bba commit ebbf522

File tree

10 files changed

+639
-27
lines changed

10 files changed

+639
-27
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "uipath-langchain"
3-
version = "0.8.6"
3+
version = "0.8.7"
44
description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform"
55
readme = { file = "README.md", content-type = "text/markdown" }
66
requires-python = ">=3.11"

src/uipath_langchain/agent/tools/tool_factory.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
LowCodeAgentDefinition,
1717
)
1818

19+
from uipath_langchain.chat.hitl import REQUIRE_CONVERSATIONAL_CONFIRMATION
20+
1921
from .context_tool import create_context_tool
2022
from .escalation_tool import create_escalation_tool
2123
from .extraction_tool import create_ixp_extraction_tool
@@ -54,6 +56,15 @@ async def create_tools_from_resources(
5456
else:
5557
tools.append(tool)
5658

59+
if agent.is_conversational:
60+
props = getattr(resource, "properties", None)
61+
if props and getattr(
62+
props, REQUIRE_CONVERSATIONAL_CONFIRMATION, False
63+
):
64+
if tool.metadata is None:
65+
tool.metadata = {}
66+
tool.metadata[REQUIRE_CONVERSATIONAL_CONFIRMATION] = True
67+
5768
return tools
5869

5970

src/uipath_langchain/agent/tools/tool_node.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
extract_current_tool_call_index,
2222
find_latest_ai_message,
2323
)
24+
from uipath_langchain.chat.hitl import check_tool_confirmation
2425

2526
# the type safety can be improved with generics
2627
ToolWrapperReturnType = dict[str, Any] | Command[Any] | None
@@ -79,6 +80,10 @@ def _func(self, state: AgentGraphState) -> OutputType:
7980
if call is None:
8081
return None
8182

83+
confirmation = check_tool_confirmation(call, self.tool)
84+
if confirmation is not None and confirmation.cancelled:
85+
return self._process_result(call, confirmation.cancelled)
86+
8287
try:
8388
if self.wrapper:
8489
inputs = self._prepare_wrapper_inputs(
@@ -87,7 +92,10 @@ def _func(self, state: AgentGraphState) -> OutputType:
8792
result = self.wrapper(*inputs)
8893
else:
8994
result = self.tool.invoke(call)
90-
return self._process_result(call, result)
95+
output = self._process_result(call, result)
96+
if confirmation is not None:
97+
confirmation.annotate_result(output)
98+
return output
9199
except Exception as e:
92100
if self.handle_tool_errors:
93101
return self._process_error_result(call, e)
@@ -98,6 +106,10 @@ async def _afunc(self, state: AgentGraphState) -> OutputType:
98106
if call is None:
99107
return None
100108

109+
confirmation = check_tool_confirmation(call, self.tool)
110+
if confirmation is not None and confirmation.cancelled:
111+
return self._process_result(call, confirmation.cancelled)
112+
101113
try:
102114
if self.awrapper:
103115
inputs = self._prepare_wrapper_inputs(
@@ -106,7 +118,10 @@ async def _afunc(self, state: AgentGraphState) -> OutputType:
106118
result = await self.awrapper(*inputs)
107119
else:
108120
result = await self.tool.ainvoke(call)
109-
return self._process_result(call, result)
121+
output = self._process_result(call, result)
122+
if confirmation is not None:
123+
confirmation.annotate_result(output)
124+
return output
110125
except Exception as e:
111126
if self.handle_tool_errors:
112127
return self._process_error_result(call, e)

src/uipath_langchain/chat/hitl.py

Lines changed: 69 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,47 @@
11
import functools
22
import inspect
33
from inspect import Parameter
4-
from typing import Annotated, Any, Callable
4+
from typing import Annotated, Any, Callable, NamedTuple
55

6+
from langchain_core.messages.tool import ToolCall, ToolMessage
67
from langchain_core.tools import BaseTool, InjectedToolCallId
78
from langchain_core.tools import tool as langchain_tool
89
from langgraph.types import interrupt
910
from uipath.core.chat import (
1011
UiPathConversationToolCallConfirmationValue,
1112
)
1213

13-
_CANCELLED_MESSAGE = "Cancelled by user"
14+
CANCELLED_MESSAGE = "Cancelled by user"
15+
ARGS_MODIFIED_MESSAGE = "Tool arguments were modified by the user"
16+
17+
CONVERSATIONAL_APPROVED_TOOL_ARGS = "conversational_approved_tool_args"
18+
REQUIRE_CONVERSATIONAL_CONFIRMATION = "require_conversational_confirmation"
19+
20+
21+
class ConfirmationResult(NamedTuple):
22+
"""Result of a tool confirmation check."""
23+
24+
cancelled: ToolMessage | None # ToolMessage if cancelled, None if approved
25+
args_modified: bool
26+
approved_args: dict[str, Any] | None = None
27+
28+
def annotate_result(self, output: dict[str, Any] | Any) -> None:
29+
"""Apply confirmation metadata to a tool result message."""
30+
msg = None
31+
if isinstance(output, dict):
32+
messages = output.get("messages")
33+
if messages:
34+
msg = messages[0]
35+
if msg is None:
36+
return
37+
if self.approved_args is not None:
38+
msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] = (
39+
self.approved_args
40+
)
41+
if self.args_modified:
42+
msg.content = (
43+
f'{{"meta": "{ARGS_MODIFIED_MESSAGE}", "result": {msg.content}}}'
44+
)
1445

1546

1647
def _patch_span_input(approved_args: dict[str, Any]) -> None:
@@ -53,7 +84,7 @@ def _patch_span_input(approved_args: dict[str, Any]) -> None:
5384
pass
5485

5586

56-
def _request_approval(
87+
def request_approval(
5788
tool_args: dict[str, Any],
5889
tool: BaseTool,
5990
) -> dict[str, Any] | None:
@@ -89,7 +120,39 @@ def _request_approval(
89120
if not confirmation.get("approved", True):
90121
return None
91122

92-
return confirmation.get("input") or tool_args
123+
return (
124+
confirmation.get("input")
125+
if confirmation.get("input") is not None
126+
else tool_args
127+
)
128+
129+
130+
def check_tool_confirmation(
131+
call: ToolCall, tool: BaseTool
132+
) -> ConfirmationResult | None:
133+
if not (tool.metadata and tool.metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION)):
134+
return None
135+
136+
original_args = call["args"]
137+
approved_args = request_approval(
138+
{**original_args, "tool_call_id": call["id"]}, tool
139+
)
140+
if approved_args is None:
141+
cancelled_msg = ToolMessage(
142+
content=CANCELLED_MESSAGE,
143+
name=call["name"],
144+
tool_call_id=call["id"],
145+
)
146+
cancelled_msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] = (
147+
original_args
148+
)
149+
return ConfirmationResult(cancelled=cancelled_msg, args_modified=False)
150+
call["args"] = approved_args
151+
return ConfirmationResult(
152+
cancelled=None,
153+
args_modified=approved_args != original_args,
154+
approved_args=approved_args,
155+
)
93156

94157

95158
def requires_approval(
@@ -107,9 +170,9 @@ def decorator(fn: Callable[..., Any]) -> BaseTool:
107170
# wrap the tool/function
108171
@functools.wraps(fn)
109172
def wrapper(**tool_args: Any) -> Any:
110-
approved_args = _request_approval(tool_args, _created_tool[0])
173+
approved_args = request_approval(tool_args, _created_tool[0])
111174
if approved_args is None:
112-
return _CANCELLED_MESSAGE
175+
return {"meta": CANCELLED_MESSAGE}
113176
_patch_span_input(approved_args)
114177
return fn(**approved_args)
115178

src/uipath_langchain/runtime/messages.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
)
4040
from uipath.runtime import UiPathRuntimeStorageProtocol
4141

42+
from uipath_langchain.chat.hitl import CONVERSATIONAL_APPROVED_TOOL_ARGS
43+
4244
from ._citations import CitationStreamProcessor, extract_citations_from_text
4345

4446
logger = logging.getLogger(__name__)
@@ -58,6 +60,7 @@ def __init__(self, runtime_id: str, storage: UiPathRuntimeStorageProtocol | None
5860
"""Initialize the mapper with empty state."""
5961
self.runtime_id = runtime_id
6062
self.storage = storage
63+
self.confirmation_tool_names: set[str] = set[str]()
6164
self.current_message: AIMessageChunk
6265
self.seen_message_ids: set[str] = set()
6366
self._storage_lock = asyncio.Lock()
@@ -389,11 +392,14 @@ async def map_current_message_to_start_tool_call_events(self):
389392
tool_call_id_to_message_id_map[tool_call_id] = (
390393
self.current_message.id
391394
)
392-
events.append(
393-
self.map_tool_call_to_tool_call_start_event(
394-
self.current_message.id, tool_call
395+
396+
if tool_call["name"] not in self.confirmation_tool_names:
397+
# defer tool call for HITL
398+
events.append(
399+
self.map_tool_call_to_tool_call_start_event(
400+
self.current_message.id, tool_call
401+
)
395402
)
396-
)
397403

398404
if self.storage is not None:
399405
await self.storage.set_value(
@@ -426,7 +432,19 @@ async def map_tool_message_to_events(
426432
# Keep as string if not valid JSON
427433
pass
428434

429-
events = [
435+
events: list[UiPathConversationMessageEvent] = []
436+
437+
# Emit deferred startToolCall for confirmation tools (skipped in Pass 1)
438+
approved_args = message.response_metadata.get(CONVERSATIONAL_APPROVED_TOOL_ARGS)
439+
if approved_args is not None:
440+
tool_call = ToolCall(
441+
name=message.name or "", args=approved_args, id=message.tool_call_id
442+
)
443+
events.append(
444+
self.map_tool_call_to_tool_call_start_event(message_id, tool_call)
445+
)
446+
447+
events.append(
430448
UiPathConversationMessageEvent(
431449
message_id=message_id,
432450
tool_call=UiPathConversationToolCallEvent(
@@ -438,7 +456,7 @@ async def map_tool_message_to_events(
438456
),
439457
),
440458
)
441-
]
459+
)
442460

443461
if is_last_tool_call:
444462
events.append(self.map_to_message_end_event(message_id))
@@ -665,7 +683,7 @@ def _map_langchain_ai_message_to_uipath_message_data(
665683
role="assistant",
666684
content_parts=content_parts,
667685
tool_calls=uipath_tool_calls,
668-
interrupts=[], # TODO: Interrupts
686+
interrupts=[], # interrupts are skipped during eval mode
669687
)
670688

671689

src/uipath_langchain/runtime/runtime.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
)
3030
from uipath.runtime.schema import UiPathRuntimeSchema
3131

32+
from uipath_langchain.chat.hitl import REQUIRE_CONVERSATIONAL_CONFIRMATION
3233
from uipath_langchain.runtime.errors import LangGraphErrorCode, LangGraphRuntimeError
3334
from uipath_langchain.runtime.messages import UiPathChatMessagesMapper
3435
from uipath_langchain.runtime.schema import get_entrypoints_schema, get_graph_schema
@@ -64,6 +65,7 @@ def __init__(
6465
self.entrypoint: str | None = entrypoint
6566
self.callbacks: list[BaseCallbackHandler] = callbacks or []
6667
self.chat = UiPathChatMessagesMapper(self.runtime_id, storage)
68+
self.chat.confirmation_tool_names = self._detect_confirmation_tools()
6769
self._middleware_node_names: set[str] = self._detect_middleware_nodes()
6870

6971
async def execute(
@@ -486,6 +488,20 @@ def _detect_middleware_nodes(self) -> set[str]:
486488

487489
return middleware_nodes
488490

491+
def _detect_confirmation_tools(self) -> set[str]:
492+
confirmation_tools: set[str] = set()
493+
for node_name, node_spec in self.graph.nodes.items():
494+
bound = getattr(node_spec, "bound", None)
495+
if bound is None:
496+
continue
497+
tool = getattr(bound, "tool", None)
498+
if tool is None:
499+
continue
500+
metadata = getattr(tool, "metadata", None) or {}
501+
if metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION):
502+
confirmation_tools.add(getattr(tool, "name", node_name))
503+
return confirmation_tools
504+
489505
def _is_middleware_node(self, node_name: str) -> bool:
490506
"""Check if a node name represents a middleware node."""
491507
return node_name in self._middleware_node_names

0 commit comments

Comments
 (0)