Skip to content

Commit c9f2f54

Browse files
author
Valentina Bojan
committed
fix: fix tests
1 parent 1a170df commit c9f2f54

2 files changed

Lines changed: 27 additions & 19 deletions

File tree

src/uipath_langchain/agent/guardrails/actions/escalate_action.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -331,11 +331,15 @@ def _process_llm_escalation_response(
331331
return {}
332332

333333
reviewed_tool_calls_obj = json.loads(reviewed_outputs_json)
334-
reviewed_tool_calls_list = reviewed_tool_calls_obj.get("tool_calls")
335-
336-
if not reviewed_tool_calls_list:
334+
if not reviewed_tool_calls_obj:
337335
return {}
338336

337+
reviewed_tool_calls_list = (
338+
reviewed_tool_calls_obj.get("tool_calls")
339+
if "tool_calls" in reviewed_tool_calls_obj
340+
else None
341+
)
342+
339343
# Track if tool calls were successfully processed
340344
tool_calls_processed = False
341345

tests/agent/guardrails/actions/test_escalate_action.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,11 @@ async def test_node_interrupts_with_correct_message_data(
159159
assert call_args.data["GuardrailResult"] == "Validation failed"
160160

161161
if stage == ExecutionStage.PRE_EXECUTION:
162-
assert call_args.data["Inputs"] == "Test message"
162+
assert call_args.data["Inputs"] == '"Test message"'
163163
assert "Outputs" not in call_args.data
164164
else:
165-
assert call_args.data["Inputs"] == "Test message"
166-
assert call_args.data["Outputs"] == "Output message"
165+
assert call_args.data["Inputs"] == '"Test message"'
166+
assert call_args.data["Outputs"] == '"Output message"'
167167

168168
@pytest.mark.asyncio
169169
@patch("uipath_langchain.agent.guardrails.actions.escalate_action.interrupt")
@@ -222,7 +222,7 @@ async def test_node_post_agent_interrupts_with_correct_agent_result_data(
222222
assert call_args.data["ExecutionStage"] == "PostExecution"
223223
assert call_args.data["GuardrailResult"] == "Validation failed"
224224

225-
assert call_args.data["Inputs"] == "User prompt message"
225+
assert call_args.data["Inputs"] == '"User prompt message"'
226226
assert call_args.data["Outputs"] == '{"ok": true}'
227227

228228
@pytest.mark.asyncio
@@ -489,12 +489,13 @@ async def test_post_execution_ai_message_with_tool_calls_extraction(
489489

490490
# Verify interrupt was called with tool calls (name and args) in Outputs and Inputs
491491
call_args = mock_interrupt.call_args[0][0]
492-
assert call_args.data["Inputs"] == "Input message"
492+
assert call_args.data["Inputs"] == '"Input message"'
493493
tool_outputs = call_args.data["Outputs"]
494-
parsed = json.loads(tool_outputs)
495-
assert len(parsed) == 1 # Tool call data with name and args
496-
assert parsed[0]["name"] == "test_tool"
497-
assert parsed[0]["args"] == {"content": {"input": "test"}}
494+
parsed_obj = json.loads(tool_outputs)
495+
parsed_list = parsed_obj["tool_calls"]
496+
assert len(parsed_list) == 1 # Tool call data with name and args
497+
assert parsed_list[0]["name"] == "test_tool"
498+
assert parsed_list[0]["args"] == {"content": {"input": "test"}}
498499

499500
@pytest.mark.asyncio
500501
@pytest.mark.parametrize(
@@ -614,7 +615,9 @@ async def test_post_execution_ai_message_with_reviewed_outputs_and_tool_calls(
614615
guardrail.description = "Test description"
615616

616617
reviewed_tool_args = {"updated": "tool_content"}
617-
reviewed_outputs = [{"name": "test_tool", "args": reviewed_tool_args}]
618+
reviewed_outputs = {
619+
"tool_calls": [{"name": "test_tool", "args": reviewed_tool_args}]
620+
}
618621
mock_escalation_result = MagicMock()
619622
mock_escalation_result.action = "Approve"
620623
mock_escalation_result.data = {"ReviewedOutputs": json.dumps(reviewed_outputs)}
@@ -822,7 +825,7 @@ async def test_node_interrupts_with_correct_data_pre_tool(self, mock_interrupt):
822825
call_args = mock_interrupt.call_args[0][0]
823826

824827
assert call_args.data["GuardrailName"] == "Test Guardrail"
825-
assert call_args.data["Component"] == "tool"
828+
assert call_args.data["Component"] == "test_tool"
826829
assert call_args.data["ExecutionStage"] == "PreExecution"
827830
assert call_args.data["Inputs"] == '{"input": "test"}'
828831

@@ -1422,7 +1425,7 @@ async def test_extract_llm_content_pre_execution_empty_content(self):
14221425
ai_message, ExecutionStage.PRE_EXECUTION
14231426
)
14241427

1425-
assert result == ""
1428+
assert result == '""'
14261429

14271430
@pytest.mark.asyncio
14281431
async def test_extract_llm_content_post_execution_tool_calls_no_content_field(self):
@@ -1447,11 +1450,12 @@ async def test_extract_llm_content_post_execution_tool_calls_no_content_field(se
14471450
)
14481451

14491452
assert isinstance(result, str)
1450-
parsed = json.loads(result)
1453+
parsed_obj = json.loads(result)
1454+
parsed_list = parsed_obj["tool_calls"]
14511455
# Should extract tool call data with name and args
1452-
assert len(parsed) == 1
1453-
assert parsed[0]["name"] == "tool_without_content"
1454-
assert parsed[0]["args"] == {"param": "value"}
1456+
assert len(parsed_list) == 1
1457+
assert parsed_list[0]["name"] == "tool_without_content"
1458+
assert parsed_list[0]["args"] == {"param": "value"}
14551459

14561460
@pytest.mark.asyncio
14571461
async def test_validate_message_count_empty_messages_raises_exception(self):

0 commit comments

Comments
 (0)