|
7 | 7 | from langchain_core.messages.tool import ToolCall, ToolMessage |
8 | 8 | from langchain_core.tools import BaseTool, InjectedToolCallId |
9 | 9 | from langchain_core.tools import tool as langchain_tool |
10 | | -from uipath.core.chat import ( |
11 | | - UiPathConversationToolCallConfirmationValue, |
12 | | -) |
13 | | - |
14 | 10 | from uipath_langchain._utils.durable_interrupt import durable_interrupt |
15 | 11 |
|
16 | 12 | CANCELLED_MESSAGE = "Cancelled by user" |
| 13 | +ARGS_MODIFIED_MESSAGE = "User has modified the tool arguments" |
17 | 14 |
|
18 | 15 | CONVERSATIONAL_APPROVED_TOOL_ARGS = "conversational_approved_tool_args" |
19 | 16 | REQUIRE_CONVERSATIONAL_CONFIRMATION = "require_conversational_confirmation" |
20 | 17 |
|
21 | 18 |
|
| 19 | +def _wrap_with_args_modified_meta(result: Any, approved_args: dict[str, Any]) -> str: |
| 20 | + """Wrap a tool result with metadata indicating the user modified the args.""" |
| 21 | + try: |
| 22 | + result_value = json.loads(result) if isinstance(result, str) else result |
| 23 | + except (json.JSONDecodeError, TypeError): |
| 24 | + result_value = result |
| 25 | + return json.dumps( |
| 26 | + { |
| 27 | + "meta": { |
| 28 | + "message": ARGS_MODIFIED_MESSAGE, |
| 29 | + "executed_args": approved_args, |
| 30 | + }, |
| 31 | + "result": result_value, |
| 32 | + } |
| 33 | + ) |
| 34 | + |
| 35 | + |
| 36 | +def get_confirmation_schema(tool: Any) -> dict[str, Any] | None: |
| 37 | + """Return the JSON input schema if this tool requires confirmation, else None.""" |
| 38 | + metadata = getattr(tool, "metadata", None) or {} |
| 39 | + if not metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION): |
| 40 | + return None |
| 41 | + tool_call_schema = getattr(tool, "tool_call_schema", None) |
| 42 | + return tool_call_schema.model_json_schema() if tool_call_schema is not None else {} |
| 43 | + |
| 44 | + |
22 | 45 | class ConfirmationResult(NamedTuple): |
23 | 46 | """Result of a tool confirmation check.""" |
24 | 47 |
|
@@ -47,20 +70,8 @@ def annotate_result(self, output: dict[str, Any] | Any) -> None: |
47 | 70 | msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] = ( |
48 | 71 | self.approved_args |
49 | 72 | ) |
50 | | - if self.args_modified: |
51 | | - try: |
52 | | - result_value = json.loads(msg.content) |
53 | | - except (json.JSONDecodeError, TypeError): |
54 | | - result_value = msg.content |
55 | | - msg.content = json.dumps( |
56 | | - { |
57 | | - "meta": { |
58 | | - "args_modified_by_user": True, |
59 | | - "executed_args": self.approved_args, |
60 | | - }, |
61 | | - "result": result_value, |
62 | | - } |
63 | | - ) |
| 73 | + if self.args_modified and self.approved_args is not None: |
| 74 | + msg.content = _wrap_with_args_modified_meta(msg.content, self.approved_args) |
64 | 75 |
|
65 | 76 |
|
66 | 77 | def _patch_span_input(approved_args: dict[str, Any]) -> None: |
@@ -113,39 +124,23 @@ def request_approval( |
113 | 124 | """ |
114 | 125 | tool_call_id: str = tool_args.pop("tool_call_id") |
115 | 126 |
|
116 | | - input_schema: dict[str, Any] = {} |
117 | | - tool_call_schema = getattr( |
118 | | - tool, "tool_call_schema", None |
119 | | - ) # doesn't include InjectedToolCallId (tool id from claude/oai/etc.) |
120 | | - if tool_call_schema is not None: |
121 | | - input_schema = tool_call_schema.model_json_schema() |
122 | | - |
123 | 127 | @durable_interrupt |
124 | 128 | def ask_confirmation(): |
125 | | - return UiPathConversationToolCallConfirmationValue( |
126 | | - tool_call_id=tool_call_id, |
127 | | - tool_name=tool.name, |
128 | | - input_schema=input_schema, |
129 | | - input_value=tool_args, |
130 | | - ) |
| 129 | + return { |
| 130 | + "tool_call_id": tool_call_id, |
| 131 | + "tool_name": tool.name, |
| 132 | + "input": tool_args, |
| 133 | + } |
131 | 134 |
|
132 | 135 | response = ask_confirmation() |
133 | 136 |
|
134 | | - # The resume payload from CAS has shape: |
135 | | - # {"type": "uipath_cas_tool_call_confirmation", |
136 | | - # "value": {"approved": bool, "input": <edited args | None>}} |
137 | 137 | if not isinstance(response, dict): |
138 | 138 | return tool_args |
139 | 139 |
|
140 | | - confirmation = response.get("value", response) |
141 | | - if not confirmation.get("approved", True): |
| 140 | + if not response.get("approved", True): |
142 | 141 | return None |
143 | 142 |
|
144 | | - return ( |
145 | | - confirmation.get("input") |
146 | | - if confirmation.get("input") is not None |
147 | | - else tool_args |
148 | | - ) |
| 143 | + return response.get("input") if response.get("input") is not None else tool_args |
149 | 144 |
|
150 | 145 |
|
151 | 146 | # for conversational low code agents |
@@ -200,8 +195,15 @@ def wrapper(**tool_args: Any) -> Any: |
200 | 195 | if approved_args is None: |
201 | 196 | return json.dumps({"meta": CANCELLED_MESSAGE}) |
202 | 197 |
|
| 198 | + args_modified = approved_args != tool_args |
| 199 | + |
203 | 200 | _patch_span_input(approved_args) |
204 | | - return fn(**approved_args) |
| 201 | + result = fn(**approved_args) |
| 202 | + |
| 203 | + if args_modified: |
| 204 | + return _wrap_with_args_modified_meta(result, approved_args) |
| 205 | + |
| 206 | + return result |
205 | 207 |
|
206 | 208 | # rewrite the signature: e.g. (query: str) -> (query: str, *, tool_call_id: str) |
207 | 209 | original_sig = inspect.signature(fn) |
@@ -234,6 +236,10 @@ def wrapper(**tool_args: Any) -> Any: |
234 | 236 | return_direct=return_direct, |
235 | 237 | ) |
236 | 238 |
|
| 239 | + if result.metadata is None: |
| 240 | + result.metadata = {} |
| 241 | + result.metadata[REQUIRE_CONVERSATIONAL_CONFIRMATION] = True |
| 242 | + |
237 | 243 | _created_tool.append(result) |
238 | 244 | return result |
239 | 245 |
|
|
0 commit comments