Skip to content

Commit 3a029a8

Browse files
committed
fix: return modified arguments
1 parent ebbf522 commit 3a029a8

6 files changed

Lines changed: 64 additions & 37 deletions

File tree

src/uipath_langchain/agent/tools/tool_node.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
extract_current_tool_call_index,
2222
find_latest_ai_message,
2323
)
24-
from uipath_langchain.chat.hitl import check_tool_confirmation
24+
from uipath_langchain.chat.hitl import request_tool_confirmation
2525

2626
# the type safety can be improved with generics
2727
ToolWrapperReturnType = dict[str, Any] | Command[Any] | None
@@ -80,7 +80,10 @@ def _func(self, state: AgentGraphState) -> OutputType:
8080
if call is None:
8181
return None
8282

83-
confirmation = check_tool_confirmation(call, self.tool)
83+
# HITL: prompt user for approval if tool requires confirmation
84+
confirmation = request_tool_confirmation(call, self.tool)
85+
86+
# HITL cancelled: user rejected the tool call
8487
if confirmation is not None and confirmation.cancelled:
8588
return self._process_result(call, confirmation.cancelled)
8689

@@ -93,6 +96,7 @@ def _func(self, state: AgentGraphState) -> OutputType:
9396
else:
9497
result = self.tool.invoke(call)
9598
output = self._process_result(call, result)
99+
# HITL approved: tag result with approved args (and whether they were modified)
96100
if confirmation is not None:
97101
confirmation.annotate_result(output)
98102
return output
@@ -106,7 +110,10 @@ async def _afunc(self, state: AgentGraphState) -> OutputType:
106110
if call is None:
107111
return None
108112

109-
confirmation = check_tool_confirmation(call, self.tool)
113+
# HITL: prompt user for approval if tool requires confirmation
114+
confirmation = request_tool_confirmation(call, self.tool)
115+
116+
# HITL cancelled: user rejected the tool call
110117
if confirmation is not None and confirmation.cancelled:
111118
return self._process_result(call, confirmation.cancelled)
112119

@@ -119,6 +126,7 @@ async def _afunc(self, state: AgentGraphState) -> OutputType:
119126
else:
120127
result = await self.tool.ainvoke(call)
121128
output = self._process_result(call, result)
129+
# HITL approved: tag result with approved args (and whether they were modified)
122130
if confirmation is not None:
123131
confirmation.annotate_result(output)
124132
return output

src/uipath_langchain/chat/hitl.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import functools
22
import inspect
3+
import json
34
from inspect import Parameter
45
from typing import Annotated, Any, Callable, NamedTuple
56

@@ -12,7 +13,6 @@
1213
)
1314

1415
CANCELLED_MESSAGE = "Cancelled by user"
15-
ARGS_MODIFIED_MESSAGE = "Tool arguments were modified by the user"
1616

1717
CONVERSATIONAL_APPROVED_TOOL_ARGS = "conversational_approved_tool_args"
1818
REQUIRE_CONVERSATIONAL_CONFIRMATION = "require_conversational_confirmation"
@@ -39,8 +39,18 @@ def annotate_result(self, output: dict[str, Any] | Any) -> None:
3939
self.approved_args
4040
)
4141
if self.args_modified:
42-
msg.content = (
43-
f'{{"meta": "{ARGS_MODIFIED_MESSAGE}", "result": {msg.content}}}'
42+
try:
43+
result_value = json.loads(msg.content)
44+
except (json.JSONDecodeError, TypeError):
45+
result_value = msg.content
46+
msg.content = json.dumps(
47+
{
48+
"meta": {
49+
"args_modified_by_user": True,
50+
"executed_args": self.approved_args,
51+
},
52+
"result": result_value,
53+
}
4454
)
4555

4656

@@ -127,9 +137,10 @@ def request_approval(
127137
)
128138

129139

130-
def check_tool_confirmation(
140+
def request_tool_confirmation(
131141
call: ToolCall, tool: BaseTool
132142
) -> ConfirmationResult | None:
143+
"""Check whether a tool requires user confirmation and request approval"""
133144
if not (tool.metadata and tool.metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION)):
134145
return None
135146

@@ -147,6 +158,7 @@ def check_tool_confirmation(
147158
original_args
148159
)
149160
return ConfirmationResult(cancelled=cancelled_msg, args_modified=False)
161+
# Mutate call args so the tool executes with the approved values
150162
call["args"] = approved_args
151163
return ConfirmationResult(
152164
cancelled=None,

src/uipath_langchain/runtime/messages.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(self, runtime_id: str, storage: UiPathRuntimeStorageProtocol | None
6060
"""Initialize the mapper with empty state."""
6161
self.runtime_id = runtime_id
6262
self.storage = storage
63-
self.confirmation_tool_names: set[str] = set[str]()
63+
self.confirmation_tool_names: set[str] = set()
6464
self.current_message: AIMessageChunk
6565
self.seen_message_ids: set[str] = set()
6666
self._storage_lock = asyncio.Lock()
@@ -394,7 +394,7 @@ async def map_current_message_to_start_tool_call_events(self):
394394
)
395395

396396
if tool_call["name"] not in self.confirmation_tool_names:
397-
# defer tool call for HITL
397+
# if tool requires HITL, we skip start tool call
398398
events.append(
399399
self.map_tool_call_to_tool_call_start_event(
400400
self.current_message.id, tool_call
@@ -683,7 +683,6 @@ def _map_langchain_ai_message_to_uipath_message_data(
683683
role="assistant",
684684
content_parts=content_parts,
685685
tool_calls=uipath_tool_calls,
686-
interrupts=[], # interrupts are skipped during eval mode
687686
)
688687

689688

src/uipath_langchain/runtime/runtime.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(
6565
self.entrypoint: str | None = entrypoint
6666
self.callbacks: list[BaseCallbackHandler] = callbacks or []
6767
self.chat = UiPathChatMessagesMapper(self.runtime_id, storage)
68-
self.chat.confirmation_tool_names = self._detect_confirmation_tools()
68+
self.chat.confirmation_tool_names = self._get_confirmation_tool_names()
6969
self._middleware_node_names: set[str] = self._detect_middleware_nodes()
7070

7171
async def execute(
@@ -488,19 +488,17 @@ def _detect_middleware_nodes(self) -> set[str]:
488488

489489
return middleware_nodes
490490

491-
def _detect_confirmation_tools(self) -> set[str]:
492-
confirmation_tools: set[str] = set()
491+
def _get_confirmation_tool_names(self) -> set[str]:
492+
names: set[str] = set()
493493
for node_name, node_spec in self.graph.nodes.items():
494-
bound = getattr(node_spec, "bound", None)
495-
if bound is None:
496-
continue
497-
tool = getattr(bound, "tool", None)
494+
# PregelNode.bound -> Runnable, Runnable.tool -> BaseTool (if tool node)
495+
tool = getattr(getattr(node_spec, "bound", None), "tool", None)
498496
if tool is None:
499497
continue
500498
metadata = getattr(tool, "metadata", None) or {}
501499
if metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION):
502-
confirmation_tools.add(getattr(tool, "name", node_name))
503-
return confirmation_tools
500+
names.add(getattr(tool, "name", node_name))
501+
return names
504502

505503
def _is_middleware_node(self, node_name: str) -> bool:
506504
"""Check if a node name represents a middleware node."""

tests/agent/tools/test_tool_node.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
create_tool_node,
2222
)
2323
from uipath_langchain.chat.hitl import (
24-
ARGS_MODIFIED_MESSAGE,
2524
CANCELLED_MESSAGE,
2625
CONVERSATIONAL_APPROVED_TOOL_ARGS,
2726
)
@@ -551,7 +550,7 @@ def test_approved_same_args_no_meta(
551550
assert result is not None
552551
assert isinstance(result, dict)
553552
msg = result["messages"][0]
554-
assert ARGS_MODIFIED_MESSAGE not in msg.content
553+
assert "args_modified_by_user" not in msg.content
555554
assert "Mock result:" in msg.content
556555

557556
@patch(
@@ -569,8 +568,12 @@ def test_approved_modified_args_injects_meta(
569568
assert result is not None
570569
assert isinstance(result, dict)
571570
msg = result["messages"][0]
572-
assert ARGS_MODIFIED_MESSAGE in msg.content
573-
assert "Mock result: edited" in msg.content
571+
import json
572+
573+
wrapped = json.loads(msg.content)
574+
assert wrapped["meta"]["args_modified_by_user"] is True
575+
assert wrapped["meta"]["executed_args"] == {"input_text": "edited"}
576+
assert "Mock result: edited" in wrapped["result"]
574577

575578
@patch("uipath_langchain.chat.hitl.request_approval", return_value=None)
576579
async def test_async_cancelled(
@@ -601,8 +604,12 @@ async def test_async_approved_modified_args(
601604
assert result is not None
602605
assert isinstance(result, dict)
603606
msg = result["messages"][0]
604-
assert ARGS_MODIFIED_MESSAGE in msg.content
605-
assert "Async mock result: async edited" in msg.content
607+
import json
608+
609+
wrapped = json.loads(msg.content)
610+
assert wrapped["meta"]["args_modified_by_user"] is True
611+
assert wrapped["meta"]["executed_args"] == {"input_text": "async edited"}
612+
assert "Async mock result: async edited" in wrapped["result"]
606613

607614
@patch(
608615
"uipath_langchain.chat.hitl.request_approval",

tests/chat/test_hitl.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,11 @@
77
from langchain_core.tools import BaseTool
88

99
from uipath_langchain.chat.hitl import (
10-
ARGS_MODIFIED_MESSAGE,
1110
CANCELLED_MESSAGE,
1211
CONVERSATIONAL_APPROVED_TOOL_ARGS,
1312
ConfirmationResult,
14-
check_tool_confirmation,
1513
request_approval,
14+
request_tool_confirmation,
1615
)
1716

1817

@@ -29,33 +28,33 @@ def _make_call(args: dict[str, Any] | None = None) -> ToolCall:
2928

3029

3130
class TestCheckToolConfirmation:
32-
"""Tests for check_tool_confirmation."""
31+
"""Tests for request_tool_confirmation."""
3332

3433
def test_returns_none_when_no_metadata(self):
3534
"""No metadata → no confirmation needed."""
3635
tool = MockTool()
3736
call = _make_call()
38-
assert check_tool_confirmation(call, tool) is None
37+
assert request_tool_confirmation(call, tool) is None
3938

4039
def test_returns_none_when_flag_not_set(self):
4140
"""Metadata exists but flag is missing → no confirmation needed."""
4241
tool = MockTool(metadata={"other_key": True})
4342
call = _make_call()
44-
assert check_tool_confirmation(call, tool) is None
43+
assert request_tool_confirmation(call, tool) is None
4544

4645
def test_returns_none_when_flag_false(self):
4746
"""Flag explicitly False → no confirmation needed."""
4847
tool = MockTool(metadata={"require_conversational_confirmation": False})
4948
call = _make_call()
50-
assert check_tool_confirmation(call, tool) is None
49+
assert request_tool_confirmation(call, tool) is None
5150

5251
@patch("uipath_langchain.chat.hitl.request_approval", return_value=None)
5352
def test_cancelled_returns_tool_message(self, mock_approval):
5453
"""User rejects → ConfirmationResult with cancelled ToolMessage and metadata."""
5554
tool = MockTool(metadata={"require_conversational_confirmation": True})
5655
call = _make_call()
5756

58-
result = check_tool_confirmation(call, tool)
57+
result = request_tool_confirmation(call, tool)
5958

6059
assert result is not None
6160
assert isinstance(result, ConfirmationResult)
@@ -78,7 +77,7 @@ def test_approved_same_args(self, mock_approval):
7877
tool = MockTool(metadata={"require_conversational_confirmation": True})
7978
call = _make_call({"query": "test"})
8079

81-
result = check_tool_confirmation(call, tool)
80+
result = request_tool_confirmation(call, tool)
8281

8382
assert result is not None
8483
assert result.cancelled is None
@@ -94,7 +93,7 @@ def test_approved_modified_args(self, mock_approval):
9493
tool = MockTool(metadata={"require_conversational_confirmation": True})
9594
call = _make_call({"query": "original"})
9695

97-
result = check_tool_confirmation(call, tool)
96+
result = request_tool_confirmation(call, tool)
9897

9998
assert result is not None
10099
assert result.cancelled is None
@@ -122,7 +121,7 @@ def test_annotate_sets_metadata(self):
122121
assert msg.content == "result"
123122

124123
def test_annotate_wraps_content_when_modified(self):
125-
"""annotate_result wraps content when args were modified."""
124+
"""annotate_result wraps content with structured meta when args were modified."""
126125
confirmation = ConfirmationResult(
127126
cancelled=None, args_modified=True, approved_args={"query": "edited"}
128127
)
@@ -134,8 +133,12 @@ def test_annotate_wraps_content_when_modified(self):
134133
assert msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] == {
135134
"query": "edited"
136135
}
137-
assert ARGS_MODIFIED_MESSAGE in msg.content
138-
assert "result" in msg.content
136+
import json
137+
138+
wrapped = json.loads(msg.content)
139+
assert wrapped["meta"]["args_modified_by_user"] is True
140+
assert wrapped["meta"]["executed_args"] == {"query": "edited"}
141+
assert wrapped["result"] == "result"
139142

140143

141144
class TestRequestApprovalTruthiness:

0 commit comments

Comments
 (0)