Skip to content

Commit 4b4a2cb

Browse files
committed
fix(mcp): retry connect calls on transient grpc errs
Signed-off-by: Samantha Coyle <sam@diagrid.io>
1 parent 2fd3237 commit 4b4a2cb

3 files changed

Lines changed: 197 additions & 15 deletions

File tree

ext/dapr-ext-workflow/dapr/ext/workflow/aio/mcp.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,20 @@
1515

1616
from __future__ import annotations
1717

18+
import asyncio
1819
import logging
20+
import time
1921
import uuid
2022
from typing import Optional, Set
2123

2224
from dapr.ext.workflow.aio.dapr_workflow_client import DaprWorkflowClient
23-
from dapr.ext.workflow.mcp import _MCP_METHOD_LIST_TOOLS, MCP_WORKFLOW_PREFIX, _DaprMCPClientBase
25+
from dapr.ext.workflow.mcp import (
26+
_MCP_METHOD_LIST_TOOLS,
27+
_SCHEDULE_RETRY_INTERVAL_SECONDS,
28+
MCP_WORKFLOW_PREFIX,
29+
_DaprMCPClientBase,
30+
_is_transient_schedule_error,
31+
)
2432
from dapr.ext.workflow.workflow_state import WorkflowStatus
2533

2634
logger = logging.getLogger(__name__)
@@ -84,15 +92,27 @@ async def connect(self, mcpserver_name: str) -> None:
8492

8593
logger.debug('Scheduling %s (instance=%s)', workflow_name, instance_id)
8694

87-
await self._wf_client.schedule_new_workflow(
88-
workflow=workflow_name,
89-
input={'mcpServerName': mcpserver_name},
90-
instance_id=instance_id,
91-
)
92-
95+
deadline = time.monotonic() + self._timeout
96+
while True:
97+
try:
98+
await self._wf_client.schedule_new_workflow(
99+
workflow=workflow_name,
100+
input={'mcpServerName': mcpserver_name},
101+
instance_id=instance_id,
102+
)
103+
break
104+
except Exception as exc: # noqa: BLE001 — classified by helper
105+
if not _is_transient_schedule_error(exc) or time.monotonic() >= deadline:
106+
raise
107+
logger.debug(
108+
'schedule_new_workflow returned transient error %s; retrying', exc
109+
)
110+
await asyncio.sleep(_SCHEDULE_RETRY_INTERVAL_SECONDS)
111+
112+
remaining = max(deadline - time.monotonic(), 1.0)
93113
state = await self._wf_client.wait_for_workflow_completion(
94114
instance_id=instance_id,
95-
timeout_in_seconds=self._timeout,
115+
timeout_in_seconds=int(remaining),
96116
fetch_payloads=True,
97117
)
98118

ext/dapr-ext-workflow/dapr/ext/workflow/mcp.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,13 @@
3333

3434
import json
3535
import logging
36+
import time
3637
import uuid
3738
from dataclasses import dataclass, field
3839
from typing import Any, Dict, List, Optional, Set
3940

41+
import grpc
42+
4043
from dapr.ext.workflow.dapr_workflow_client import DaprWorkflowClient
4144
from dapr.ext.workflow.workflow_state import WorkflowStatus
4245

@@ -50,6 +53,33 @@
5053
_MCP_METHOD_LIST_TOOLS = '.ListTools'
5154
_MCP_METHOD_CALL_TOOL = '.CallTool'
5255

56+
# `dapr run` reports the sidecar ready when its HTTP port responds, but
57+
# MCPServer-derived workflows aren't registered until daprd finishes its
58+
# loadMCPServers init step. A schedule_new_workflow call inside that window
59+
# comes back as CANCELLED or UNAVAILABLE. Retry such failures within the
60+
# caller's timeout budget instead of surfacing them as hard failures.
61+
_TRANSIENT_GRPC_CODES = frozenset({
62+
grpc.StatusCode.CANCELLED,
63+
grpc.StatusCode.UNAVAILABLE,
64+
})
65+
_SCHEDULE_RETRY_INTERVAL_SECONDS = 0.5
66+
67+
68+
def _is_transient_schedule_error(exc: BaseException) -> bool:
    """Return True if a ``schedule_new_workflow`` failure should be retried.

    The exception itself and every exception reachable through its
    ``__cause__`` chain are inspected, so both a raw ``grpc.RpcError`` and
    one wrapped by a higher layer via ``raise ... from`` are recognised.

    Args:
        exc: The exception raised by the schedule call.

    Returns:
        True when any exception in the chain is a ``grpc.RpcError`` whose
        ``code()`` is one of ``_TRANSIENT_GRPC_CODES``; False otherwise.
    """
    # Walk the chain iteratively and remember visited exceptions by identity
    # so that a cyclic ``__cause__`` chain (a -> b -> a) terminates instead
    # of recursing until RecursionError.
    seen = set()
    current = exc
    while current is not None and id(current) not in seen:
        seen.add(id(current))
        if isinstance(current, grpc.RpcError):
            # Look up ``code`` defensively: call it only when the attribute
            # exists and is callable.
            code = getattr(current, 'code', None)
            if callable(code) and code() in _TRANSIENT_GRPC_CODES:
                return True
        current = getattr(current, '__cause__', None)
    return False
82+
5383

5484
# TODO(@sicoyle): see if I can use the mcp pkg class instead for this?
5585
@dataclass(frozen=True)
@@ -210,15 +240,27 @@ def connect(self, mcpserver_name: str) -> None:
210240

211241
logger.debug('Scheduling %s (instance=%s)', workflow_name, instance_id)
212242

213-
self._wf_client.schedule_new_workflow(
214-
workflow=workflow_name,
215-
input={'mcpServerName': mcpserver_name},
216-
instance_id=instance_id,
217-
)
243+
deadline = time.monotonic() + self._timeout
244+
while True:
245+
try:
246+
self._wf_client.schedule_new_workflow(
247+
workflow=workflow_name,
248+
input={'mcpServerName': mcpserver_name},
249+
instance_id=instance_id,
250+
)
251+
break
252+
except Exception as exc: # noqa: BLE001 — classified by helper
253+
if not _is_transient_schedule_error(exc) or time.monotonic() >= deadline:
254+
raise
255+
logger.debug(
256+
'schedule_new_workflow returned transient error %s; retrying', exc
257+
)
258+
time.sleep(_SCHEDULE_RETRY_INTERVAL_SECONDS)
218259

260+
remaining = max(deadline - time.monotonic(), 1.0)
219261
state = self._wf_client.wait_for_workflow_completion(
220262
instance_id=instance_id,
221-
timeout_in_seconds=self._timeout,
263+
timeout_in_seconds=int(remaining),
222264
fetch_payloads=True,
223265
)
224266

ext/dapr-ext-workflow/tests/test_mcp_client.py

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,27 @@
1616
import json
1717
import unittest
1818
from datetime import datetime
19-
from unittest.mock import AsyncMock, MagicMock
19+
from unittest.mock import AsyncMock, MagicMock, patch
20+
21+
import grpc
2022

2123
from dapr.ext.workflow._durabletask import client
2224
from dapr.ext.workflow.aio.mcp import DaprMCPClient as AioDaprMCPClient
2325
from dapr.ext.workflow.mcp import MCP_WORKFLOW_PREFIX, DaprMCPClient, MCPToolDef
2426
from dapr.ext.workflow.workflow_state import WorkflowState
2527

2628

29+
class _StubRpcError(grpc.RpcError):
    """Minimal ``grpc.RpcError`` stand-in that reports a caller-chosen status code."""

    def __init__(self, status_code: grpc.StatusCode):
        """Remember ``status_code`` so ``code()`` can hand it back."""
        self._status_code = status_code
        super().__init__()

    def code(self) -> grpc.StatusCode:
        """Return the status code supplied at construction time."""
        return self._status_code
38+
39+
2740
def _make_completed_state(output_json: dict) -> WorkflowState:
2841
"""Create a WorkflowState that simulates a COMPLETED workflow."""
2942
inner = client.WorkflowState(
@@ -385,6 +398,113 @@ async def test_connect_caches_tools(self):
385398
self.assertEqual(tools[1].name, 'get_forecast')
386399

387400

401+
class TestDaprMCPClientConnectRetry(unittest.TestCase):
    """Tests for connect()'s retry-on-transient-gRPC-error path."""

    def test_retries_then_succeeds_on_cancelled(self):
        """A CANCELLED schedule failure should be retried within the timeout budget."""
        mock_wf = MagicMock()
        # Two transient failures followed by a successful schedule call.
        mock_wf.schedule_new_workflow.side_effect = [
            _StubRpcError(grpc.StatusCode.CANCELLED),
            _StubRpcError(grpc.StatusCode.CANCELLED),
            'inst-1',
        ]
        mock_wf.wait_for_workflow_completion.return_value = _make_completed_state(
            SAMPLE_LIST_TOOLS_RESPONSE
        )

        mcp_client = DaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf)
        # Stub out sleep so the retry interval does not slow the test down.
        with patch('dapr.ext.workflow.mcp.time.sleep'):
            mcp_client.connect('weather')

        self.assertEqual(mock_wf.schedule_new_workflow.call_count, 3)
        self.assertEqual(len(mcp_client.get_all_tools()), 2)

    def test_retries_on_unavailable(self):
        """UNAVAILABLE should also be treated as transient."""
        mock_wf = MagicMock()
        mock_wf.schedule_new_workflow.side_effect = [
            _StubRpcError(grpc.StatusCode.UNAVAILABLE),
            'inst-1',
        ]
        mock_wf.wait_for_workflow_completion.return_value = _make_completed_state(
            SAMPLE_LIST_TOOLS_RESPONSE
        )

        mcp_client = DaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf)
        with patch('dapr.ext.workflow.mcp.time.sleep'):
            mcp_client.connect('weather')

        self.assertEqual(mock_wf.schedule_new_workflow.call_count, 2)

    def test_non_transient_propagates_immediately(self):
        """A non-CANCELLED/UNAVAILABLE error must not be retried."""
        mock_wf = MagicMock()
        mock_wf.schedule_new_workflow.side_effect = _StubRpcError(
            grpc.StatusCode.PERMISSION_DENIED
        )

        mcp_client = DaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf)
        with patch('dapr.ext.workflow.mcp.time.sleep') as sleep_mock:
            with self.assertRaises(grpc.RpcError):
                mcp_client.connect('weather')

        self.assertEqual(mock_wf.schedule_new_workflow.call_count, 1)
        sleep_mock.assert_not_called()

    def test_deadline_exhausted_raises_last_error(self):
        """When the timeout budget runs out mid-retry, propagate the last error."""
        mock_wf = MagicMock()
        mock_wf.schedule_new_workflow.side_effect = _StubRpcError(
            grpc.StatusCode.CANCELLED
        )

        mcp_client = DaprMCPClient(timeout_in_seconds=1, wf_client=mock_wf)
        # Swap out only the ``time`` name inside the module under test rather
        # than patching ``time.monotonic`` itself: the latter replaces the
        # process-wide clock, and a two-value side effect would raise
        # StopIteration for any unrelated caller while the patch is active.
        with patch('dapr.ext.workflow.mcp.time') as fake_time:
            # First read fixes the deadline at 1.0; the second read (2.0) is
            # already past it, so the first transient failure must propagate.
            fake_time.monotonic.side_effect = [0.0, 2.0]
            with self.assertRaises(grpc.RpcError) as ctx:
                mcp_client.connect('weather')

        self.assertEqual(ctx.exception.code(), grpc.StatusCode.CANCELLED)
        # Exactly one attempt, no back-off sleep, and no wait on a workflow
        # that was never scheduled.
        self.assertEqual(mock_wf.schedule_new_workflow.call_count, 1)
        fake_time.sleep.assert_not_called()
        mock_wf.wait_for_workflow_completion.assert_not_called()
471+
472+
473+
class TestAioDaprMCPClientConnectRetry(unittest.IsolatedAsyncioTestCase):
    """Async counterpart of TestDaprMCPClientConnectRetry."""

    async def test_retries_then_succeeds_on_cancelled(self):
        """A CANCELLED schedule failure is retried and connect() then succeeds."""
        mock_wf = AsyncMock()
        # One transient failure followed by a successful schedule call.
        mock_wf.schedule_new_workflow.side_effect = [
            _StubRpcError(grpc.StatusCode.CANCELLED),
            'inst-1',
        ]
        mock_wf.wait_for_workflow_completion.return_value = _make_completed_state(
            SAMPLE_LIST_TOOLS_RESPONSE
        )

        mcp_client = AioDaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf)
        # Stub out the back-off sleep so the retry does not slow the test down.
        with patch('dapr.ext.workflow.aio.mcp.asyncio.sleep', new=AsyncMock()):
            await mcp_client.connect('weather')

        self.assertEqual(mock_wf.schedule_new_workflow.await_count, 2)
        self.assertEqual(len(mcp_client.get_all_tools()), 2)

    async def test_deadline_exhausted_raises(self):
        """Once the timeout budget is spent, the transient error propagates."""
        mock_wf = AsyncMock()
        mock_wf.schedule_new_workflow.side_effect = _StubRpcError(
            grpc.StatusCode.CANCELLED
        )

        mcp_client = AioDaprMCPClient(timeout_in_seconds=1, wf_client=mock_wf)
        sleep_mock = AsyncMock()
        # Swap out only the ``time`` name inside the module under test.
        # Patching ``time.monotonic`` itself would replace the process-wide
        # clock, which the running asyncio event loop also uses, and a
        # two-value side effect would raise StopIteration if the loop read it.
        with patch('dapr.ext.workflow.aio.mcp.asyncio.sleep', new=sleep_mock), patch(
            'dapr.ext.workflow.aio.mcp.time'
        ) as fake_time:
            # First read fixes the deadline at 1.0; the second read (2.0) is
            # already past it, so the first transient failure must propagate.
            fake_time.monotonic.side_effect = [0.0, 2.0]
            with self.assertRaises(grpc.RpcError):
                await mcp_client.connect('weather')

        # Exactly one attempt and no back-off sleep once the deadline is hit.
        self.assertEqual(mock_wf.schedule_new_workflow.await_count, 1)
        sleep_mock.assert_not_awaited()
506+
507+
388508
class TestMCPWorkflowPrefix(unittest.TestCase):
389509
"""Tests for the workflow naming constant."""
390510

0 commit comments

Comments
 (0)