|
16 | 16 | import json |
17 | 17 | import unittest |
18 | 18 | from datetime import datetime |
19 | | -from unittest.mock import AsyncMock, MagicMock |
| 19 | +from unittest.mock import AsyncMock, MagicMock, patch |
| 20 | + |
| 21 | +import grpc |
20 | 22 |
|
21 | 23 | from dapr.ext.workflow._durabletask import client |
22 | 24 | from dapr.ext.workflow.aio.mcp import DaprMCPClient as AioDaprMCPClient |
23 | 25 | from dapr.ext.workflow.mcp import MCP_WORKFLOW_PREFIX, DaprMCPClient, MCPToolDef |
24 | 26 | from dapr.ext.workflow.workflow_state import WorkflowState |
25 | 27 |
|
26 | 28 |
|
| 29 | +class _StubRpcError(grpc.RpcError): |
| 30 | + """Test double for grpc.RpcError with a configurable status code.""" |
| 31 | + |
| 32 | + def __init__(self, status_code: grpc.StatusCode): |
| 33 | + super().__init__() |
| 34 | + self._status_code = status_code |
| 35 | + |
| 36 | + def code(self) -> grpc.StatusCode: |
| 37 | + return self._status_code |
| 38 | + |
| 39 | + |
27 | 40 | def _make_completed_state(output_json: dict) -> WorkflowState: |
28 | 41 | """Create a WorkflowState that simulates a COMPLETED workflow.""" |
29 | 42 | inner = client.WorkflowState( |
@@ -385,6 +398,113 @@ async def test_connect_caches_tools(self): |
385 | 398 | self.assertEqual(tools[1].name, 'get_forecast') |
386 | 399 |
|
387 | 400 |
|
| 401 | +class TestDaprMCPClientConnectRetry(unittest.TestCase): |
| 402 | + """Tests for connect()'s retry-on-transient-gRPC-error path.""" |
| 403 | + |
| 404 | + def test_retries_then_succeeds_on_cancelled(self): |
| 405 | + """A CANCELLED schedule failure should be retried within the timeout budget.""" |
| 406 | + mock_wf = MagicMock() |
| 407 | + mock_wf.schedule_new_workflow.side_effect = [ |
| 408 | + _StubRpcError(grpc.StatusCode.CANCELLED), |
| 409 | + _StubRpcError(grpc.StatusCode.CANCELLED), |
| 410 | + 'inst-1', |
| 411 | + ] |
| 412 | + mock_wf.wait_for_workflow_completion.return_value = _make_completed_state( |
| 413 | + SAMPLE_LIST_TOOLS_RESPONSE |
| 414 | + ) |
| 415 | + |
| 416 | + mcp_client = DaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf) |
| 417 | + with patch('dapr.ext.workflow.mcp.time.sleep'): |
| 418 | + mcp_client.connect('weather') |
| 419 | + |
| 420 | + self.assertEqual(mock_wf.schedule_new_workflow.call_count, 3) |
| 421 | + self.assertEqual(len(mcp_client.get_all_tools()), 2) |
| 422 | + |
| 423 | + def test_retries_on_unavailable(self): |
| 424 | + """UNAVAILABLE should also be treated as transient.""" |
| 425 | + mock_wf = MagicMock() |
| 426 | + mock_wf.schedule_new_workflow.side_effect = [ |
| 427 | + _StubRpcError(grpc.StatusCode.UNAVAILABLE), |
| 428 | + 'inst-1', |
| 429 | + ] |
| 430 | + mock_wf.wait_for_workflow_completion.return_value = _make_completed_state( |
| 431 | + SAMPLE_LIST_TOOLS_RESPONSE |
| 432 | + ) |
| 433 | + |
| 434 | + mcp_client = DaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf) |
| 435 | + with patch('dapr.ext.workflow.mcp.time.sleep'): |
| 436 | + mcp_client.connect('weather') |
| 437 | + |
| 438 | + self.assertEqual(mock_wf.schedule_new_workflow.call_count, 2) |
| 439 | + |
| 440 | + def test_non_transient_propagates_immediately(self): |
| 441 | + """A non-CANCELLED/UNAVAILABLE error must not be retried.""" |
| 442 | + mock_wf = MagicMock() |
| 443 | + mock_wf.schedule_new_workflow.side_effect = _StubRpcError( |
| 444 | + grpc.StatusCode.PERMISSION_DENIED |
| 445 | + ) |
| 446 | + |
| 447 | + mcp_client = DaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf) |
| 448 | + with patch('dapr.ext.workflow.mcp.time.sleep') as sleep_mock: |
| 449 | + with self.assertRaises(grpc.RpcError): |
| 450 | + mcp_client.connect('weather') |
| 451 | + |
| 452 | + self.assertEqual(mock_wf.schedule_new_workflow.call_count, 1) |
| 453 | + sleep_mock.assert_not_called() |
| 454 | + |
| 455 | + def test_deadline_exhausted_raises_last_error(self): |
| 456 | + """When the timeout budget runs out mid-retry, propagate the last error.""" |
| 457 | + mock_wf = MagicMock() |
| 458 | + mock_wf.schedule_new_workflow.side_effect = _StubRpcError( |
| 459 | + grpc.StatusCode.CANCELLED |
| 460 | + ) |
| 461 | + |
| 462 | + mcp_client = DaprMCPClient(timeout_in_seconds=1, wf_client=mock_wf) |
| 463 | + # Patch monotonic to advance past the deadline immediately so we don't |
| 464 | + # actually sleep for a second in tests. |
| 465 | + with patch('dapr.ext.workflow.mcp.time.sleep'), patch( |
| 466 | + 'dapr.ext.workflow.mcp.time.monotonic', |
| 467 | + side_effect=[0.0, 2.0], |
| 468 | + ): |
| 469 | + with self.assertRaises(grpc.RpcError): |
| 470 | + mcp_client.connect('weather') |
| 471 | + |
| 472 | + |
| 473 | +class TestAioDaprMCPClientConnectRetry(unittest.IsolatedAsyncioTestCase): |
| 474 | + """Async counterpart of TestDaprMCPClientConnectRetry.""" |
| 475 | + |
| 476 | + async def test_retries_then_succeeds_on_cancelled(self): |
| 477 | + mock_wf = AsyncMock() |
| 478 | + mock_wf.schedule_new_workflow.side_effect = [ |
| 479 | + _StubRpcError(grpc.StatusCode.CANCELLED), |
| 480 | + 'inst-1', |
| 481 | + ] |
| 482 | + mock_wf.wait_for_workflow_completion.return_value = _make_completed_state( |
| 483 | + SAMPLE_LIST_TOOLS_RESPONSE |
| 484 | + ) |
| 485 | + |
| 486 | + mcp_client = AioDaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf) |
| 487 | + with patch('dapr.ext.workflow.aio.mcp.asyncio.sleep', new=AsyncMock()): |
| 488 | + await mcp_client.connect('weather') |
| 489 | + |
| 490 | + self.assertEqual(mock_wf.schedule_new_workflow.await_count, 2) |
| 491 | + self.assertEqual(len(mcp_client.get_all_tools()), 2) |
| 492 | + |
| 493 | + async def test_deadline_exhausted_raises(self): |
| 494 | + mock_wf = AsyncMock() |
| 495 | + mock_wf.schedule_new_workflow.side_effect = _StubRpcError( |
| 496 | + grpc.StatusCode.CANCELLED |
| 497 | + ) |
| 498 | + |
| 499 | + mcp_client = AioDaprMCPClient(timeout_in_seconds=1, wf_client=mock_wf) |
| 500 | + with patch('dapr.ext.workflow.aio.mcp.asyncio.sleep', new=AsyncMock()), patch( |
| 501 | + 'dapr.ext.workflow.aio.mcp.time.monotonic', |
| 502 | + side_effect=[0.0, 2.0], |
| 503 | + ): |
| 504 | + with self.assertRaises(grpc.RpcError): |
| 505 | + await mcp_client.connect('weather') |
| 506 | + |
| 507 | + |
388 | 508 | class TestMCPWorkflowPrefix(unittest.TestCase): |
389 | 509 | """Tests for the workflow naming constant.""" |
390 | 510 |
|
|
0 commit comments