|
16 | 16 | import json |
17 | 17 | import unittest |
18 | 18 | from datetime import datetime |
19 | | -from unittest.mock import AsyncMock, MagicMock |
| 19 | +from unittest.mock import AsyncMock, MagicMock, patch |
20 | 20 |
|
| 21 | +import grpc |
21 | 22 | from dapr.ext.workflow._durabletask import client |
22 | 23 | from dapr.ext.workflow.aio.mcp import DaprMCPClient as AioDaprMCPClient |
23 | 24 | from dapr.ext.workflow.mcp import MCP_WORKFLOW_PREFIX, DaprMCPClient, MCPToolDef |
24 | 25 | from dapr.ext.workflow.workflow_state import WorkflowState |
25 | 26 |
|
26 | 27 |
|
| 28 | +class _StubRpcError(grpc.RpcError): |
| 29 | + """Test double for grpc.RpcError with a configurable status code.""" |
| 30 | + |
| 31 | + def __init__(self, status_code: grpc.StatusCode): |
| 32 | + super().__init__() |
| 33 | + self._status_code = status_code |
| 34 | + |
| 35 | + def code(self) -> grpc.StatusCode: |
| 36 | + return self._status_code |
| 37 | + |
| 38 | + |
27 | 39 | def _make_completed_state(output_json: dict) -> WorkflowState: |
28 | 40 | """Create a WorkflowState that simulates a COMPLETED workflow.""" |
29 | 41 | inner = client.WorkflowState( |
@@ -385,6 +397,159 @@ async def test_connect_caches_tools(self): |
385 | 397 | self.assertEqual(tools[1].name, 'get_forecast') |
386 | 398 |
|
387 | 399 |
|
| 400 | +class TestDaprMCPClientConnectRetry(unittest.TestCase): |
| 401 | + """Tests for connect()'s retry-on-transient-gRPC-error path.""" |
| 402 | + |
| 403 | + def test_retries_then_succeeds_on_cancelled(self): |
| 404 | + """A CANCELLED schedule failure should be retried within the timeout budget.""" |
| 405 | + mock_wf = MagicMock() |
| 406 | + mock_wf.schedule_new_workflow.side_effect = [ |
| 407 | + _StubRpcError(grpc.StatusCode.CANCELLED), |
| 408 | + _StubRpcError(grpc.StatusCode.CANCELLED), |
| 409 | + 'inst-1', |
| 410 | + ] |
| 411 | + mock_wf.wait_for_workflow_completion.return_value = _make_completed_state( |
| 412 | + SAMPLE_LIST_TOOLS_RESPONSE |
| 413 | + ) |
| 414 | + |
| 415 | + mcp_client = DaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf) |
| 416 | + with patch('dapr.ext.workflow.mcp.time.sleep'): |
| 417 | + mcp_client.connect('weather') |
| 418 | + |
| 419 | + self.assertEqual(mock_wf.schedule_new_workflow.call_count, 3) |
| 420 | + self.assertEqual(len(mcp_client.get_all_tools()), 2) |
| 421 | + |
| 422 | + def test_retries_on_unavailable(self): |
| 423 | + """UNAVAILABLE should also be treated as transient.""" |
| 424 | + mock_wf = MagicMock() |
| 425 | + mock_wf.schedule_new_workflow.side_effect = [ |
| 426 | + _StubRpcError(grpc.StatusCode.UNAVAILABLE), |
| 427 | + 'inst-1', |
| 428 | + ] |
| 429 | + mock_wf.wait_for_workflow_completion.return_value = _make_completed_state( |
| 430 | + SAMPLE_LIST_TOOLS_RESPONSE |
| 431 | + ) |
| 432 | + |
| 433 | + mcp_client = DaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf) |
| 434 | + with patch('dapr.ext.workflow.mcp.time.sleep'): |
| 435 | + mcp_client.connect('weather') |
| 436 | + |
| 437 | + self.assertEqual(mock_wf.schedule_new_workflow.call_count, 2) |
| 438 | + |
| 439 | + def test_non_transient_propagates_immediately(self): |
| 440 | + """A non-CANCELLED/UNAVAILABLE error must not be retried.""" |
| 441 | + mock_wf = MagicMock() |
| 442 | + mock_wf.schedule_new_workflow.side_effect = _StubRpcError(grpc.StatusCode.PERMISSION_DENIED) |
| 443 | + |
| 444 | + mcp_client = DaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf) |
| 445 | + with patch('dapr.ext.workflow.mcp.time.sleep') as sleep_mock: |
| 446 | + with self.assertRaises(grpc.RpcError): |
| 447 | + mcp_client.connect('weather') |
| 448 | + |
| 449 | + self.assertEqual(mock_wf.schedule_new_workflow.call_count, 1) |
| 450 | + sleep_mock.assert_not_called() |
| 451 | + |
| 452 | + def test_deadline_exhausted_raises_last_error(self): |
| 453 | + """When the timeout budget runs out mid-retry, propagate the last error.""" |
| 454 | + mock_wf = MagicMock() |
| 455 | + mock_wf.schedule_new_workflow.side_effect = _StubRpcError(grpc.StatusCode.CANCELLED) |
| 456 | + |
| 457 | + mcp_client = DaprMCPClient(timeout_in_seconds=1, wf_client=mock_wf) |
| 458 | + # Patch monotonic to advance past the deadline immediately so we don't |
| 459 | + # actually sleep for a second in tests. |
| 460 | + with ( |
| 461 | + patch('dapr.ext.workflow.mcp.time.sleep'), |
| 462 | + patch( |
| 463 | + 'dapr.ext.workflow.mcp.time.monotonic', |
| 464 | + side_effect=[0.0, 2.0], |
| 465 | + ), |
| 466 | + ): |
| 467 | + with self.assertRaises(grpc.RpcError): |
| 468 | + mcp_client.connect('weather') |
| 469 | + |
| 470 | + def test_budget_exhausted_after_schedule_succeeds(self): |
| 471 | + """If retries burn the budget but schedule eventually succeeds, raise |
| 472 | + without calling wait_for_workflow_completion (timeout=0 means |
| 473 | + 'wait forever' in the underlying client).""" |
| 474 | + mock_wf = MagicMock() |
| 475 | + mock_wf.schedule_new_workflow.side_effect = [ |
| 476 | + _StubRpcError(grpc.StatusCode.CANCELLED), |
| 477 | + 'inst-1', |
| 478 | + ] |
| 479 | + |
| 480 | + mcp_client = DaprMCPClient(timeout_in_seconds=1, wf_client=mock_wf) |
| 481 | + # monotonic: 0.0 → deadline = 1.0; 0.4 → sleep_for = 0.5 (still in budget); |
| 482 | + # 2.0 → post-loop remaining = -1.0 → raise. |
| 483 | + with ( |
| 484 | + patch('dapr.ext.workflow.mcp.time.sleep'), |
| 485 | + patch( |
| 486 | + 'dapr.ext.workflow.mcp.time.monotonic', |
| 487 | + side_effect=[0.0, 0.4, 2.0], |
| 488 | + ), |
| 489 | + ): |
| 490 | + with self.assertRaises(RuntimeError) as ctx: |
| 491 | + mcp_client.connect('weather') |
| 492 | + self.assertIn('timed out', str(ctx.exception)) |
| 493 | + mock_wf.wait_for_workflow_completion.assert_not_called() |
| 494 | + |
| 495 | + |
| 496 | +class TestAioDaprMCPClientConnectRetry(unittest.IsolatedAsyncioTestCase): |
| 497 | + """Async counterpart of TestDaprMCPClientConnectRetry.""" |
| 498 | + |
| 499 | + async def test_retries_then_succeeds_on_cancelled(self): |
| 500 | + mock_wf = AsyncMock() |
| 501 | + mock_wf.schedule_new_workflow.side_effect = [ |
| 502 | + _StubRpcError(grpc.StatusCode.CANCELLED), |
| 503 | + 'inst-1', |
| 504 | + ] |
| 505 | + mock_wf.wait_for_workflow_completion.return_value = _make_completed_state( |
| 506 | + SAMPLE_LIST_TOOLS_RESPONSE |
| 507 | + ) |
| 508 | + |
| 509 | + mcp_client = AioDaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf) |
| 510 | + with patch('dapr.ext.workflow.aio.mcp.asyncio.sleep', new=AsyncMock()): |
| 511 | + await mcp_client.connect('weather') |
| 512 | + |
| 513 | + self.assertEqual(mock_wf.schedule_new_workflow.await_count, 2) |
| 514 | + self.assertEqual(len(mcp_client.get_all_tools()), 2) |
| 515 | + |
| 516 | + async def test_deadline_exhausted_raises(self): |
| 517 | + mock_wf = AsyncMock() |
| 518 | + mock_wf.schedule_new_workflow.side_effect = _StubRpcError(grpc.StatusCode.CANCELLED) |
| 519 | + |
| 520 | + mcp_client = AioDaprMCPClient(timeout_in_seconds=1, wf_client=mock_wf) |
| 521 | + with ( |
| 522 | + patch('dapr.ext.workflow.aio.mcp.asyncio.sleep', new=AsyncMock()), |
| 523 | + patch( |
| 524 | + 'dapr.ext.workflow.aio.mcp.time.monotonic', |
| 525 | + side_effect=[0.0, 2.0], |
| 526 | + ), |
| 527 | + ): |
| 528 | + with self.assertRaises(grpc.RpcError): |
| 529 | + await mcp_client.connect('weather') |
| 530 | + |
| 531 | + async def test_budget_exhausted_after_schedule_succeeds(self): |
| 532 | + """Async mirror of the fail-fast-after-schedule-success guard.""" |
| 533 | + mock_wf = AsyncMock() |
| 534 | + mock_wf.schedule_new_workflow.side_effect = [ |
| 535 | + _StubRpcError(grpc.StatusCode.CANCELLED), |
| 536 | + 'inst-1', |
| 537 | + ] |
| 538 | + |
| 539 | + mcp_client = AioDaprMCPClient(timeout_in_seconds=1, wf_client=mock_wf) |
| 540 | + with ( |
| 541 | + patch('dapr.ext.workflow.aio.mcp.asyncio.sleep', new=AsyncMock()), |
| 542 | + patch( |
| 543 | + 'dapr.ext.workflow.aio.mcp.time.monotonic', |
| 544 | + side_effect=[0.0, 0.4, 2.0], |
| 545 | + ), |
| 546 | + ): |
| 547 | + with self.assertRaises(RuntimeError) as ctx: |
| 548 | + await mcp_client.connect('weather') |
| 549 | + self.assertIn('timed out', str(ctx.exception)) |
| 550 | + mock_wf.wait_for_workflow_completion.assert_not_awaited() |
| 551 | + |
| 552 | + |
388 | 553 | class TestMCPWorkflowPrefix(unittest.TestCase): |
389 | 554 | """Tests for the workflow naming constant.""" |
390 | 555 |
|
|
0 commit comments