Skip to content

Commit 050977d

Browse files
committed
test: validation and tests
1 parent 0f6e711 commit 050977d

4 files changed

Lines changed: 285 additions & 11 deletions

File tree

src/uipath_langchain/agent/react/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Callable, Sequence, Type, TypeVar
1+
from typing import Callable, Sequence, Type, TypeVar
22

33
from langchain_core.language_models import BaseChatModel
44
from langchain_core.messages import HumanMessage, SystemMessage

src/uipath_langchain/agent/react/init_node.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,15 @@ def graph_state_init(state: Any) -> Any:
7474
# Validate client-side tool declarations from the exchange input
7575
if client_side_tools:
7676
client_tools_input = getattr(state, UIPATH_CLIENT_SIDE_TOOLS_INPUT_KEY, None)
77-
if client_tools_input is not None and isinstance(client_tools_input, list):
78-
validate_and_apply_tool_filter(client_tools_input, client_side_tools)
79-
else:
77+
if client_tools_input is None:
8078
available_client_side_tools.set(None)
79+
elif not isinstance(client_tools_input, list):
80+
raise ValueError(
81+
f"'{UIPATH_CLIENT_SIDE_TOOLS_INPUT_KEY}' must be a list of tool declarations, "
82+
f"got {type(client_tools_input).__name__}."
83+
)
84+
else:
85+
validate_and_apply_tool_filter(client_tools_input, client_side_tools)
8186

8287
# Calculate initial message count for tracking new messages
8388
initial_message_count = (

src/uipath_langchain/runtime/messages.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
)
4141
from uipath.runtime import UiPathRuntimeStorageProtocol
4242

43+
from uipath_langchain.agent.tools.client_side_tool import ClientSideToolInfo
4344
from uipath_langchain.chat.hitl import IS_CONVERSATIONAL_CLIENT_SIDE_TOOL
4445

4546
from ._citations import (
@@ -67,7 +68,7 @@ def __init__(self, runtime_id: str, storage: UiPathRuntimeStorageProtocol | None
6768
self.storage = storage
6869
self.current_message: AIMessageChunk | AIMessage
6970
self.tools_requiring_confirmation: dict[str, Any] = {}
70-
self.client_side_tools: dict[str, Any] = {}
71+
self.client_side_tools: dict[str, ClientSideToolInfo] = {}
7172
self.seen_message_ids: set[str] = set()
7273
self._storage_lock = asyncio.Lock()
7374
self._citation_stream_processor = CitationStreamProcessor()
@@ -448,13 +449,8 @@ async def map_current_message_to_start_tool_call_events(self):
448449
)
449450
input_schema = self.tools_requiring_confirmation.get(tool_name)
450451
is_client_side = tool_name in self.client_side_tools
451-
client_tool_info = self.client_side_tools.get(tool_name)
452452
output_schema = (
453-
(
454-
client_tool_info.get("output_schema")
455-
if isinstance(client_tool_info, dict)
456-
else client_tool_info
457-
)
453+
self.client_side_tools[tool_name].get("output_schema")
458454
if is_client_side
459455
else None
460456
)
Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
"""Tests for client-side tool validation and filtering logic."""
2+
3+
import pytest
4+
5+
from uipath_langchain.agent.tools.client_side_tool import (
6+
ClientSideToolInfo,
7+
available_client_side_tools,
8+
validate_and_apply_tool_filter,
9+
)
10+
11+
AGENT_TOOLS: dict[str, ClientSideToolInfo] = {
12+
"get_weather": {
13+
"input_schema": {
14+
"type": "object",
15+
"properties": {"city": {"type": "string"}},
16+
},
17+
"output_schema": {
18+
"type": "object",
19+
"properties": {"temp": {"type": "number"}},
20+
},
21+
},
22+
"show_map": {
23+
"input_schema": None,
24+
"output_schema": None,
25+
},
26+
}
27+
28+
29+
class TestValidateAndApplyToolFilter:
30+
"""Tests for validate_and_apply_tool_filter."""
31+
32+
def test_valid_declarations_set_filter(self):
33+
declared = [
34+
{"name": "get_weather"},
35+
{"name": "show_map"},
36+
]
37+
validate_and_apply_tool_filter(declared, AGENT_TOOLS)
38+
39+
allowed = available_client_side_tools.get()
40+
assert allowed == {"get_weather", "show_map"}
41+
42+
def test_missing_required_tool_raises(self):
43+
declared = [{"name": "get_weather"}] # missing show_map
44+
45+
with pytest.raises(ValueError, match="Missing required client-side tools"):
46+
validate_and_apply_tool_filter(declared, AGENT_TOOLS)
47+
48+
def test_input_schema_mismatch_raises(self):
49+
declared = [
50+
{
51+
"name": "get_weather",
52+
"inputSchema": {
53+
"type": "object",
54+
"properties": {"location": {"type": "string"}},
55+
},
56+
},
57+
{"name": "show_map"},
58+
]
59+
60+
with pytest.raises(ValueError, match="inputSchema does not match"):
61+
validate_and_apply_tool_filter(declared, AGENT_TOOLS)
62+
63+
def test_output_schema_mismatch_raises(self):
64+
declared = [
65+
{
66+
"name": "get_weather",
67+
"outputSchema": {
68+
"type": "object",
69+
"properties": {"temperature": {"type": "string"}},
70+
},
71+
},
72+
{"name": "show_map"},
73+
]
74+
75+
with pytest.raises(ValueError, match="outputSchema does not match"):
76+
validate_and_apply_tool_filter(declared, AGENT_TOOLS)
77+
78+
def test_unknown_extra_tools_are_ignored(self):
79+
declared = [
80+
{"name": "get_weather"},
81+
{"name": "show_map"},
82+
{"name": "unknown_tool"},
83+
]
84+
validate_and_apply_tool_filter(declared, AGENT_TOOLS)
85+
86+
allowed = available_client_side_tools.get()
87+
assert allowed is not None
88+
assert "unknown_tool" in allowed
89+
assert "get_weather" in allowed
90+
91+
def test_string_declarations_accepted(self):
92+
declared = ["get_weather", "show_map"]
93+
validate_and_apply_tool_filter(declared, AGENT_TOOLS)
94+
95+
allowed = available_client_side_tools.get()
96+
assert allowed == {"get_weather", "show_map"}
97+
98+
def test_missing_name_field_raises(self):
99+
declared = [{"inputSchema": {}}]
100+
101+
with pytest.raises(ValueError, match="missing required 'name' field"):
102+
validate_and_apply_tool_filter(declared, AGENT_TOOLS)
103+
104+
def test_invalid_type_raises(self):
105+
declared = [123]
106+
107+
with pytest.raises(ValueError, match="must be a dict or string"):
108+
validate_and_apply_tool_filter(declared, AGENT_TOOLS)
109+
110+
def test_duplicate_name_raises(self):
111+
declared = [
112+
{"name": "get_weather"},
113+
{"name": "get_weather"},
114+
{"name": "show_map"},
115+
]
116+
117+
with pytest.raises(ValueError, match="Duplicate client-side tool"):
118+
validate_and_apply_tool_filter(declared, AGENT_TOOLS)
119+
120+
def test_matching_schemas_pass(self):
121+
declared = [
122+
{
123+
"name": "get_weather",
124+
"inputSchema": {
125+
"type": "object",
126+
"properties": {"city": {"type": "string"}},
127+
},
128+
"outputSchema": {
129+
"type": "object",
130+
"properties": {"temp": {"type": "number"}},
131+
},
132+
},
133+
{"name": "show_map"},
134+
]
135+
validate_and_apply_tool_filter(declared, AGENT_TOOLS)
136+
137+
allowed = available_client_side_tools.get()
138+
assert allowed is not None
139+
assert "get_weather" in allowed
140+
141+
142+
class TestToolNotAvailableEnforcement:
143+
"""Tests that client_side_tool_fn returns error ToolMessage when tool is filtered out."""
144+
145+
def test_tool_not_in_allowed_set_returns_error(self):
146+
token = available_client_side_tools.set({"other_tool"})
147+
try:
148+
from unittest.mock import AsyncMock, patch
149+
150+
from uipath.agent.models.agent import AgentClientSideToolResourceConfig
151+
152+
resource = AgentClientSideToolResourceConfig(
153+
name="my_tool",
154+
description="A test tool",
155+
input_schema={
156+
"type": "object",
157+
"properties": {"query": {"type": "string"}},
158+
},
159+
output_schema=None,
160+
)
161+
162+
from uipath_langchain.agent.tools.client_side_tool import (
163+
create_client_side_tool,
164+
)
165+
166+
tool = create_client_side_tool(resource)
167+
168+
import asyncio
169+
170+
result = asyncio.get_event_loop().run_until_complete(
171+
tool.coroutine(tool_call_id="tc-1", query="test")
172+
)
173+
174+
assert result.status == "error"
175+
assert "not available" in result.content
176+
finally:
177+
available_client_side_tools.reset(token)
178+
179+
def test_tool_in_allowed_set_proceeds(self):
180+
"""When tool IS in the allowed set, it should NOT return an error.
181+
182+
We can't fully test execution (it would hit durable_interrupt),
183+
but we verify the availability check passes by patching the interrupt.
184+
"""
185+
token = available_client_side_tools.set({"my_tool"})
186+
try:
187+
from unittest.mock import AsyncMock, patch
188+
189+
from uipath.agent.models.agent import AgentClientSideToolResourceConfig
190+
191+
resource = AgentClientSideToolResourceConfig(
192+
name="my_tool",
193+
description="A test tool",
194+
input_schema={
195+
"type": "object",
196+
"properties": {"query": {"type": "string"}},
197+
},
198+
output_schema=None,
199+
)
200+
201+
from uipath_langchain.agent.tools.client_side_tool import (
202+
create_client_side_tool,
203+
)
204+
205+
tool = create_client_side_tool(resource)
206+
207+
import asyncio
208+
209+
# Patch durable_interrupt to avoid GraphInterrupt
210+
with (
211+
patch(
212+
"uipath_langchain.agent.tools.client_side_tool.durable_interrupt",
213+
side_effect=lambda fn: fn,
214+
),
215+
patch(
216+
"uipath_langchain.agent.tools.client_side_tool.mockable",
217+
side_effect=lambda **kw: lambda fn: fn,
218+
),
219+
):
220+
# Re-create tool after patching
221+
tool = create_client_side_tool(resource)
222+
result = asyncio.get_event_loop().run_until_complete(
223+
tool.coroutine(tool_call_id="tc-1", query="test")
224+
)
225+
# Should NOT be an error ToolMessage — it proceeded past the availability check
226+
if hasattr(result, "status"):
227+
assert result.status != "error"
228+
finally:
229+
available_client_side_tools.reset(token)
230+
231+
def test_none_allowed_set_permits_all(self):
232+
"""When available_client_side_tools is None (CAS default), all tools proceed."""
233+
token = available_client_side_tools.set(None)
234+
try:
235+
from uipath.agent.models.agent import AgentClientSideToolResourceConfig
236+
237+
resource = AgentClientSideToolResourceConfig(
238+
name="any_tool",
239+
description="A test tool",
240+
input_schema={
241+
"type": "object",
242+
"properties": {"q": {"type": "string"}},
243+
},
244+
output_schema=None,
245+
)
246+
247+
from unittest.mock import patch
248+
249+
from uipath_langchain.agent.tools.client_side_tool import (
250+
create_client_side_tool,
251+
)
252+
253+
with (
254+
patch(
255+
"uipath_langchain.agent.tools.client_side_tool.durable_interrupt",
256+
side_effect=lambda fn: fn,
257+
),
258+
patch(
259+
"uipath_langchain.agent.tools.client_side_tool.mockable",
260+
side_effect=lambda **kw: lambda fn: fn,
261+
),
262+
):
263+
tool = create_client_side_tool(resource)
264+
265+
import asyncio
266+
267+
result = asyncio.get_event_loop().run_until_complete(
268+
tool.coroutine(tool_call_id="tc-1", q="test")
269+
)
270+
if hasattr(result, "status"):
271+
assert result.status != "error"
272+
finally:
273+
available_client_side_tools.reset(token)

0 commit comments

Comments
 (0)