Skip to content

Commit 413c601

Browse files
committed
fix: revert deferred tool call, add confirmToolCall support with ToolRunnableCallable [JAR-9208]
1 parent f536077 commit 413c601

File tree

11 files changed

+286
-143
lines changed

11 files changed

+286
-143
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "uipath-langchain"
3-
version = "0.9.24"
3+
version = "0.9.25"
44
description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform"
55
readme = { file = "README.md", content-type = "text/markdown" }
66
requires-python = ">=3.11"

src/uipath_langchain/agent/tools/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
create_tools_from_resources,
1212
)
1313
from .tool_node import (
14+
ConversationalToolRunnableCallable,
1415
ToolWrapperMixin,
1516
UiPathToolNode,
1617
create_tool_node,
@@ -28,6 +29,7 @@
2829
"create_ixp_extraction_tool",
2930
"create_ixp_escalation_tool",
3031
"UiPathToolNode",
32+
"ConversationalToolRunnableCallable",
3133
"ToolWrapperMixin",
3234
"wrap_tools_with_error_handling",
3335
]

src/uipath_langchain/agent/tools/tool_node.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
extract_current_tool_call_index,
2323
find_latest_ai_message,
2424
)
25-
from uipath_langchain.chat.hitl import request_conversational_tool_confirmation
25+
from uipath_langchain.chat.hitl import (
26+
REQUIRE_CONVERSATIONAL_CONFIRMATION,
27+
request_conversational_tool_confirmation,
28+
)
2629

2730
# the type safety can be improved with generics
2831
ToolWrapperReturnType = dict[str, Any] | Command[Any] | None
@@ -274,9 +277,27 @@ async def _afunc(state: AgentGraphState) -> OutputType:
274277
raise
275278
return result
276279

280+
tool = getattr(tool_node, "tool", None)
281+
282+
# Preserve tool ref so the runtime can discover which tools need confirmation
283+
# (see runtime.py _get_tool_confirmation_info)
284+
metadata = getattr(tool, "metadata", None) or {}
285+
if isinstance(tool, BaseTool) and metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION):
286+
return ConversationalToolRunnableCallable(
287+
func=_func, afunc=_afunc, name=tool_name, tool=tool
288+
)
289+
277290
return RunnableCallable(func=_func, afunc=_afunc, name=tool_name)
278291

279292

293+
class ConversationalToolRunnableCallable(RunnableCallable):
294+
"""Preserves a reference to the underlying BaseTool for conversational HITL confirmation."""
295+
296+
def __init__(self, *, func: Any, afunc: Any, name: str, tool: BaseTool):
297+
super().__init__(func=func, afunc=afunc, name=name)
298+
self.tool = tool
299+
300+
280301
class ToolWrapperMixin:
281302
wrapper: ToolWrapperType | None = None
282303
awrapper: AsyncToolWrapperType | None = None

src/uipath_langchain/chat/hitl.py

Lines changed: 48 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,41 @@
77
from langchain_core.messages.tool import ToolCall, ToolMessage
88
from langchain_core.tools import BaseTool, InjectedToolCallId
99
from langchain_core.tools import tool as langchain_tool
10-
from uipath.core.chat import (
11-
UiPathConversationToolCallConfirmationValue,
12-
)
13-
1410
from uipath_langchain._utils.durable_interrupt import durable_interrupt
1511

1612
CANCELLED_MESSAGE = "Cancelled by user"
13+
ARGS_MODIFIED_MESSAGE = "User has modified the tool arguments"
1714

1815
CONVERSATIONAL_APPROVED_TOOL_ARGS = "conversational_approved_tool_args"
1916
REQUIRE_CONVERSATIONAL_CONFIRMATION = "require_conversational_confirmation"
2017

2118

19+
def _wrap_with_args_modified_meta(result: Any, approved_args: dict[str, Any]) -> str:
20+
"""Wrap a tool result with metadata indicating the user modified the args."""
21+
try:
22+
result_value = json.loads(result) if isinstance(result, str) else result
23+
except (json.JSONDecodeError, TypeError):
24+
result_value = result
25+
return json.dumps(
26+
{
27+
"meta": {
28+
"message": ARGS_MODIFIED_MESSAGE,
29+
"executed_args": approved_args,
30+
},
31+
"result": result_value,
32+
}
33+
)
34+
35+
36+
def get_confirmation_schema(tool: Any) -> dict[str, Any] | None:
37+
"""Return the JSON input schema if this tool requires confirmation, else None."""
38+
metadata = getattr(tool, "metadata", None) or {}
39+
if not metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION):
40+
return None
41+
tool_call_schema = getattr(tool, "tool_call_schema", None)
42+
return tool_call_schema.model_json_schema() if tool_call_schema is not None else {}
43+
44+
2245
class ConfirmationResult(NamedTuple):
2346
"""Result of a tool confirmation check."""
2447

@@ -47,20 +70,8 @@ def annotate_result(self, output: dict[str, Any] | Any) -> None:
4770
msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] = (
4871
self.approved_args
4972
)
50-
if self.args_modified:
51-
try:
52-
result_value = json.loads(msg.content)
53-
except (json.JSONDecodeError, TypeError):
54-
result_value = msg.content
55-
msg.content = json.dumps(
56-
{
57-
"meta": {
58-
"args_modified_by_user": True,
59-
"executed_args": self.approved_args,
60-
},
61-
"result": result_value,
62-
}
63-
)
73+
if self.args_modified and self.approved_args is not None:
74+
msg.content = _wrap_with_args_modified_meta(msg.content, self.approved_args)
6475

6576

6677
def _patch_span_input(approved_args: dict[str, Any]) -> None:
@@ -113,39 +124,23 @@ def request_approval(
113124
"""
114125
tool_call_id: str = tool_args.pop("tool_call_id")
115126

116-
input_schema: dict[str, Any] = {}
117-
tool_call_schema = getattr(
118-
tool, "tool_call_schema", None
119-
) # doesn't include InjectedToolCallId (tool id from claude/oai/etc.)
120-
if tool_call_schema is not None:
121-
input_schema = tool_call_schema.model_json_schema()
122-
123127
@durable_interrupt
124128
def ask_confirmation():
125-
return UiPathConversationToolCallConfirmationValue(
126-
tool_call_id=tool_call_id,
127-
tool_name=tool.name,
128-
input_schema=input_schema,
129-
input_value=tool_args,
130-
)
129+
return {
130+
"tool_call_id": tool_call_id,
131+
"tool_name": tool.name,
132+
"input": tool_args,
133+
}
131134

132135
response = ask_confirmation()
133136

134-
# The resume payload from CAS has shape:
135-
# {"type": "uipath_cas_tool_call_confirmation",
136-
# "value": {"approved": bool, "input": <edited args | None>}}
137137
if not isinstance(response, dict):
138138
return tool_args
139139

140-
confirmation = response.get("value", response)
141-
if not confirmation.get("approved", True):
140+
if not response.get("approved", True):
142141
return None
143142

144-
return (
145-
confirmation.get("input")
146-
if confirmation.get("input") is not None
147-
else tool_args
148-
)
143+
return response.get("input") if response.get("input") is not None else tool_args
149144

150145

151146
# for conversational low code agents
@@ -200,8 +195,15 @@ def wrapper(**tool_args: Any) -> Any:
200195
if approved_args is None:
201196
return json.dumps({"meta": CANCELLED_MESSAGE})
202197

198+
args_modified = approved_args != tool_args
199+
203200
_patch_span_input(approved_args)
204-
return fn(**approved_args)
201+
result = fn(**approved_args)
202+
203+
if args_modified:
204+
return _wrap_with_args_modified_meta(result, approved_args)
205+
206+
return result
205207

206208
# rewrite the signature: e.g. (query: str) -> (query: str, *, tool_call_id: str)
207209
original_sig = inspect.signature(fn)
@@ -234,6 +236,10 @@ def wrapper(**tool_args: Any) -> Any:
234236
return_direct=return_direct,
235237
)
236238

239+
if result.metadata is None:
240+
result.metadata = {}
241+
result.metadata[REQUIRE_CONVERSATIONAL_CONFIRMATION] = True
242+
237243
_created_tool.append(result)
238244
return result
239245

src/uipath_langchain/runtime/messages.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(self, runtime_id: str, storage: UiPathRuntimeStorageProtocol | None
5959
self.runtime_id = runtime_id
6060
self.storage = storage
6161
self.current_message: AIMessageChunk | AIMessage
62-
self.tool_names_requiring_confirmation: set[str] = set()
62+
self.tool_confirmation_schemas: dict[str, Any] = {}
6363
self.seen_message_ids: set[str] = set()
6464
self._storage_lock = asyncio.Lock()
6565
self._citation_stream_processor = CitationStreamProcessor()
@@ -320,7 +320,6 @@ async def map_ai_message_chunk_to_events(
320320

321321
events: list[UiPathConversationMessageEvent] = []
322322

323-
# For every new message_id, start a new message
324323
if message.id not in self.seen_message_ids:
325324
self.current_message = message
326325
self.seen_message_ids.add(message.id)
@@ -339,8 +338,13 @@ async def map_ai_message_chunk_to_events(
339338
self._chunk_to_message_event(message.id, chunk)
340339
)
341340
case "tool_call_chunk":
342-
# Accumulate the message chunk. Note that we assume no interweaving of AIMessage and AIMessageChunks for a given message.
343-
if isinstance(self.current_message, AIMessageChunk):
341+
# Skip the first chunk — it's already assigned as current_message above,
342+
# so accumulating it with itself would duplicate fields via string concat
343+
# (e.g. tool name "search_web" becomes "search_websearch_web").
344+
if (
345+
isinstance(self.current_message, AIMessageChunk)
346+
and self.current_message is not message
347+
):
344348
self.current_message = self.current_message + message
345349

346350
elif isinstance(message.content, str) and message.content:
@@ -425,16 +429,19 @@ async def map_current_message_to_start_tool_call_events(self):
425429
self.current_message.id
426430
)
427431

428-
# if tool requires confirmation, we skip start tool call
429-
if (
430-
tool_call["name"]
431-
not in self.tool_names_requiring_confirmation
432-
):
433-
events.append(
434-
self.map_tool_call_to_tool_call_start_event(
435-
self.current_message.id, tool_call
436-
)
432+
tool_name = tool_call["name"]
433+
require_confirmation = (
434+
tool_name in self.tool_confirmation_schemas
435+
)
436+
input_schema = self.tool_confirmation_schemas.get(tool_name)
437+
events.append(
438+
self.map_tool_call_to_tool_call_start_event(
439+
self.current_message.id,
440+
tool_call,
441+
require_confirmation=require_confirmation or None,
442+
input_schema=input_schema,
437443
)
444+
)
438445

439446
if self.storage is not None:
440447
await self.storage.set_value(
@@ -531,7 +538,12 @@ async def get_message_id_for_tool_call(
531538
return message_id, is_last
532539

533540
def map_tool_call_to_tool_call_start_event(
534-
self, message_id: str, tool_call: ToolCall
541+
self,
542+
message_id: str,
543+
tool_call: ToolCall,
544+
*,
545+
require_confirmation: bool | None = None,
546+
input_schema: Any | None = None,
535547
) -> UiPathConversationMessageEvent:
536548
return UiPathConversationMessageEvent(
537549
message_id=message_id,
@@ -541,6 +553,8 @@ def map_tool_call_to_tool_call_start_event(
541553
tool_name=tool_call["name"],
542554
timestamp=self.get_timestamp(),
543555
input=tool_call["args"],
556+
require_confirmation=require_confirmation,
557+
input_schema=input_schema,
544558
),
545559
),
546560
)

src/uipath_langchain/runtime/runtime.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
)
3030
from uipath.runtime.schema import UiPathRuntimeSchema
3131

32-
from uipath_langchain.chat.hitl import REQUIRE_CONVERSATIONAL_CONFIRMATION
32+
from uipath_langchain.chat.hitl import get_confirmation_schema
3333
from uipath_langchain.runtime.errors import LangGraphErrorCode, LangGraphRuntimeError
3434
from uipath_langchain.runtime.messages import UiPathChatMessagesMapper
3535
from uipath_langchain.runtime.schema import get_entrypoints_schema, get_graph_schema
@@ -65,9 +65,7 @@ def __init__(
6565
self.entrypoint: str | None = entrypoint
6666
self.callbacks: list[BaseCallbackHandler] = callbacks or []
6767
self.chat = UiPathChatMessagesMapper(self.runtime_id, storage)
68-
self.chat.tool_names_requiring_confirmation = (
69-
self._get_tool_names_requiring_confirmation()
70-
)
68+
self.chat.tool_confirmation_schemas = self._get_tool_confirmation_info()
7169
self._middleware_node_names: set[str] = self._detect_middleware_nodes()
7270

7371
async def execute(
@@ -490,17 +488,36 @@ def _detect_middleware_nodes(self) -> set[str]:
490488

491489
return middleware_nodes
492490

493-
def _get_tool_names_requiring_confirmation(self) -> set[str]:
494-
names: set[str] = set()
491+
def _get_tool_confirmation_info(self) -> dict[str, Any]:
492+
"""Build {tool_name: input_schema} for tools requiring confirmation.
493+
494+
Walks compiled graph nodes once at runtime init. This is needed because coded agents
495+
(create_agent) export a compiled graph as the only artifact — there's no side channel
496+
to pass confirmation metadata from the build step to the runtime.
497+
"""
498+
schemas: dict[str, Any] = {}
495499
for node_name, node_spec in self.graph.nodes.items():
496-
# langgraph's processing node.bound -> runnable.tool -> baseTool (if tool node)
497-
tool = getattr(getattr(node_spec, "bound", None), "tool", None)
498-
if tool is None:
500+
bound = getattr(node_spec, "bound", None)
501+
if bound is None:
499502
continue
500-
metadata = getattr(tool, "metadata", None) or {}
501-
if metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION):
502-
names.add(getattr(tool, "name", node_name))
503-
return names
503+
504+
# Coded agents: one tool per node
505+
tool = getattr(bound, "tool", None)
506+
if tool is not None:
507+
schema = get_confirmation_schema(tool)
508+
if schema is not None:
509+
schemas[getattr(tool, "name", node_name)] = schema
510+
continue
511+
512+
# Low-code agents: multiple tools in one node
513+
tools_by_name = getattr(bound, "tools_by_name", None)
514+
if isinstance(tools_by_name, dict):
515+
for name, tool in tools_by_name.items():
516+
schema = get_confirmation_schema(tool)
517+
if schema is not None:
518+
schemas[str(getattr(tool, "name", name))] = schema
519+
520+
return schemas
504521

505522
def _is_middleware_node(self, node_name: str) -> bool:
506523
"""Check if a node name represents a middleware node."""

tests/agent/tools/test_tool_node.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
wrap_tools_with_error_handling,
2424
)
2525
from uipath_langchain.chat.hitl import (
26+
ARGS_MODIFIED_MESSAGE,
2627
CANCELLED_MESSAGE,
2728
CONVERSATIONAL_APPROVED_TOOL_ARGS,
2829
)
@@ -507,7 +508,7 @@ def test_approved_same_args_no_meta(
507508
assert result is not None
508509
assert isinstance(result, dict)
509510
msg = result["messages"][0]
510-
assert "args_modified_by_user" not in msg.content
511+
assert ARGS_MODIFIED_MESSAGE not in msg.content
511512
assert "Mock result:" in msg.content
512513

513514
@patch(
@@ -528,7 +529,7 @@ def test_approved_modified_args_injects_meta(
528529

529530
assert isinstance(msg.content, str)
530531
wrapped = json.loads(msg.content)
531-
assert wrapped["meta"]["args_modified_by_user"] is True
532+
assert wrapped["meta"]["message"] == ARGS_MODIFIED_MESSAGE
532533
assert wrapped["meta"]["executed_args"] == {"input_text": "edited"}
533534
assert "Mock result: edited" in wrapped["result"]
534535

@@ -564,7 +565,7 @@ async def test_async_approved_modified_args(
564565

565566
assert isinstance(msg.content, str)
566567
wrapped = json.loads(msg.content)
567-
assert wrapped["meta"]["args_modified_by_user"] is True
568+
assert wrapped["meta"]["message"] == ARGS_MODIFIED_MESSAGE
568569
assert wrapped["meta"]["executed_args"] == {"input_text": "async edited"}
569570
assert "Async mock result: async edited" in wrapped["result"]
570571

0 commit comments

Comments
 (0)