|
| 1 | +import io |
1 | 2 | import json |
2 | 3 | import os |
| 4 | +from unittest.mock import AsyncMock, MagicMock, patch |
3 | 5 |
|
4 | 6 | import pytest |
5 | 7 | from haystack.components.agents import Agent |
|
13 | 15 | MCPTool, |
14 | 16 | StdioServerInfo, |
15 | 17 | ) |
| 18 | +from haystack_integrations.tools.mcp.mcp_tool import StdioClient |
16 | 19 |
|
17 | 20 | from .mcp_memory_transport import InMemoryServerInfo |
18 | 21 | from .mcp_servers_fixtures import calculator_mcp, echo_mcp |
@@ -197,6 +200,42 @@ def test_mcp_tool_serde_with_state_mapping(self, mcp_tool_cleanup): |
197 | 200 | assert new_tool._inputs_from_state == {"state_a": "a"} |
198 | 201 | assert new_tool._outputs_to_state == {"result": {"source": "output"}} |
199 | 202 |
|
| 203 | + @pytest.mark.asyncio |
| 204 | + @pytest.mark.parametrize( |
| 205 | + "fileno_side_effect,fileno_return_value,notebook_environment", |
| 206 | + [ |
| 207 | + (io.UnsupportedOperation("fileno"), None, True), |
| 208 | + (None, 2, False), |
| 209 | + ], |
| 210 | + ) |
| 211 | + async def test_stdio_client_stderr_handling(self, fileno_side_effect, fileno_return_value, notebook_environment): |
| 212 | + """Test that StdioClient uses sys.stderr in terminals and falls back to a file in notebooks.""" |
| 213 | + client = StdioClient(command="echo", args=["hello"]) |
| 214 | + |
| 215 | + mock_stderr = MagicMock() |
| 216 | + mock_stderr.fileno.side_effect = fileno_side_effect |
| 217 | + mock_stderr.fileno.return_value = fileno_return_value |
| 218 | + |
| 219 | + with ( |
| 220 | + patch.object(client, "exit_stack") as mock_stack, |
| 221 | + patch("haystack_integrations.tools.mcp.mcp_tool.stdio_client") as mock_stdio_client, |
| 222 | + patch("haystack_integrations.tools.mcp.mcp_tool.sys") as mock_sys, |
| 223 | + patch.object(client, "_initialize_session_with_transport", new_callable=AsyncMock) as mock_init, |
| 224 | + ): |
| 225 | + mock_sys.stderr = mock_stderr |
| 226 | + mock_stack.enter_async_context = AsyncMock(return_value=(MagicMock(), MagicMock())) |
| 227 | + mock_init.return_value = [] |
| 228 | + |
| 229 | + await client.connect() |
| 230 | + |
| 231 | + _, kwargs = mock_stdio_client.call_args |
| 232 | + errlog = kwargs["errlog"] |
| 233 | + if notebook_environment: |
| 234 | + assert errlog is not mock_stderr |
| 235 | + assert hasattr(errlog, "write") |
| 236 | + else: |
| 237 | + assert errlog is mock_stderr |
| 238 | + |
200 | 239 | @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") |
201 | 240 | @pytest.mark.integration |
202 | 241 | def test_pipeline_warmup_with_mcp_tool(self): |
|
0 commit comments