Skip to content

Commit 8571c3e

Browse files
authored
fix(mcp): retry connect calls on transient grpc errs (#1062)
* fix(mcp): retry connect calls on transient grpc errs
  Signed-off-by: Samantha Coyle <sam@diagrid.io>
* style: comment cleanup
  Signed-off-by: Samantha Coyle <sam@diagrid.io>
* fix: address copilot feedback
  Signed-off-by: Samantha Coyle <sam@diagrid.io>
* style: appease linter
  Signed-off-by: Samantha Coyle <sam@diagrid.io>
---------
Signed-off-by: Samantha Coyle <sam@diagrid.io>
1 parent 8840b9a commit 8571c3e

3 files changed

Lines changed: 255 additions & 16 deletions

File tree

ext/dapr-ext-workflow/dapr/ext/workflow/aio/mcp.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,20 @@
1515

1616
from __future__ import annotations
1717

18+
import asyncio
1819
import logging
20+
import time
1921
import uuid
2022
from typing import Optional, Set
2123

2224
from dapr.ext.workflow.aio.dapr_workflow_client import DaprWorkflowClient
23-
from dapr.ext.workflow.mcp import _MCP_METHOD_LIST_TOOLS, MCP_WORKFLOW_PREFIX, _DaprMCPClientBase
25+
from dapr.ext.workflow.mcp import (
26+
_MCP_METHOD_LIST_TOOLS,
27+
_SCHEDULE_RETRY_INTERVAL_SECONDS,
28+
MCP_WORKFLOW_PREFIX,
29+
_DaprMCPClientBase,
30+
_is_transient_schedule_error,
31+
)
2432
from dapr.ext.workflow.workflow_state import WorkflowStatus
2533

2634
logger = logging.getLogger(__name__)
@@ -84,15 +92,35 @@ async def connect(self, mcpserver_name: str) -> None:
8492

8593
logger.debug('Scheduling %s (instance=%s)', workflow_name, instance_id)
8694

87-
await self._wf_client.schedule_new_workflow(
88-
workflow=workflow_name,
89-
input={'mcpServerName': mcpserver_name},
90-
instance_id=instance_id,
91-
)
92-
95+
deadline = time.monotonic() + self._timeout
96+
while True:
97+
try:
98+
await self._wf_client.schedule_new_workflow(
99+
workflow=workflow_name,
100+
input={'mcpServerName': mcpserver_name},
101+
instance_id=instance_id,
102+
)
103+
break
104+
except Exception as exc: # noqa: BLE001 — classified by helper
105+
if not _is_transient_schedule_error(exc):
106+
raise
107+
sleep_for = min(_SCHEDULE_RETRY_INTERVAL_SECONDS, deadline - time.monotonic())
108+
if sleep_for <= 0:
109+
raise
110+
logger.debug('schedule_new_workflow returned transient error %s; retrying', exc)
111+
await asyncio.sleep(sleep_for)
112+
113+
remaining = deadline - time.monotonic()
114+
if remaining <= 0:
115+
raise RuntimeError(
116+
f"ListTools workflow for MCPServer '{mcpserver_name}' "
117+
f'timed out after {self._timeout}s'
118+
)
119+
# wait_for_workflow_completion treats timeout=0 as "wait forever",
120+
# so floor the gRPC timeout at 1s when sub-second remaining survives.
93121
state = await self._wf_client.wait_for_workflow_completion(
94122
instance_id=instance_id,
95-
timeout_in_seconds=self._timeout,
123+
timeout_in_seconds=max(int(remaining), 1),
96124
fetch_payloads=True,
97125
)
98126

ext/dapr-ext-workflow/dapr/ext/workflow/mcp.py

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,12 @@
3333

3434
import json
3535
import logging
36+
import time
3637
import uuid
3738
from dataclasses import dataclass, field
3839
from typing import Any, Dict, List, Optional, Set
3940

41+
import grpc
4042
from dapr.ext.workflow.dapr_workflow_client import DaprWorkflowClient
4143
from dapr.ext.workflow.workflow_state import WorkflowStatus
4244

@@ -50,6 +52,30 @@
5052
_MCP_METHOD_LIST_TOOLS = '.ListTools'
5153
_MCP_METHOD_CALL_TOOL = '.CallTool'
5254

55+
# gRPC status codes that _is_transient_schedule_error() treats as retryable
# when schedule_new_workflow fails.
_TRANSIENT_GRPC_CODES = frozenset((grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE))

# Upper bound, in seconds, on the pause between schedule retries in connect();
# the actual sleep is capped by whatever remains of the caller's timeout.
_SCHEDULE_RETRY_INTERVAL_SECONDS = 0.5
62+
63+
64+
def _is_transient_schedule_error(exc: BaseException) -> bool:
    """Return True if a schedule_new_workflow failure should be retried.

    Follows the ``__cause__`` chain starting at ``exc`` so that both a raw
    ``grpc.RpcError`` and one wrapped via ``raise ... from`` are recognised.
    The walk is iterative and remembers every exception already visited, so
    a cyclic ``__cause__`` chain ends with ``False`` instead of recursing
    until ``RecursionError``.

    Args:
        exc: The exception raised by the schedule call.

    Returns:
        True if any exception on the chain is a ``grpc.RpcError`` whose
        ``code()`` is in ``_TRANSIENT_GRPC_CODES``; False otherwise.
    """
    visited = set()
    current = exc
    while current is not None and id(current) not in visited:
        visited.add(id(current))
        if isinstance(current, grpc.RpcError):
            # Not every RpcError subclass exposes a callable ``code``; probe
            # defensively rather than assuming the attribute exists.
            code = getattr(current, 'code', None)
            if callable(code) and code() in _TRANSIENT_GRPC_CODES:
                return True
        current = getattr(current, '__cause__', None)
    return False
78+
5379

5480
# TODO(@sicoyle): see if I can use the mcp pkg class instead for this?
5581
@dataclass(frozen=True)
@@ -210,15 +236,35 @@ def connect(self, mcpserver_name: str) -> None:
210236

211237
logger.debug('Scheduling %s (instance=%s)', workflow_name, instance_id)
212238

213-
self._wf_client.schedule_new_workflow(
214-
workflow=workflow_name,
215-
input={'mcpServerName': mcpserver_name},
216-
instance_id=instance_id,
217-
)
218-
239+
deadline = time.monotonic() + self._timeout
240+
while True:
241+
try:
242+
self._wf_client.schedule_new_workflow(
243+
workflow=workflow_name,
244+
input={'mcpServerName': mcpserver_name},
245+
instance_id=instance_id,
246+
)
247+
break
248+
except Exception as exc: # noqa: BLE001 — classified by helper
249+
if not _is_transient_schedule_error(exc):
250+
raise
251+
sleep_for = min(_SCHEDULE_RETRY_INTERVAL_SECONDS, deadline - time.monotonic())
252+
if sleep_for <= 0:
253+
raise
254+
logger.debug('schedule_new_workflow returned transient error %s; retrying', exc)
255+
time.sleep(sleep_for)
256+
257+
remaining = deadline - time.monotonic()
258+
if remaining <= 0:
259+
raise RuntimeError(
260+
f"ListTools workflow for MCPServer '{mcpserver_name}' "
261+
f'timed out after {self._timeout}s'
262+
)
263+
# wait_for_workflow_completion treats timeout=0 as "wait forever",
264+
# so floor the gRPC timeout at 1s when sub-second remaining survives.
219265
state = self._wf_client.wait_for_workflow_completion(
220266
instance_id=instance_id,
221-
timeout_in_seconds=self._timeout,
267+
timeout_in_seconds=max(int(remaining), 1),
222268
fetch_payloads=True,
223269
)
224270

ext/dapr-ext-workflow/tests/test_mcp_client.py

Lines changed: 166 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,26 @@
1616
import json
1717
import unittest
1818
from datetime import datetime
19-
from unittest.mock import AsyncMock, MagicMock
19+
from unittest.mock import AsyncMock, MagicMock, patch
2020

21+
import grpc
2122
from dapr.ext.workflow._durabletask import client
2223
from dapr.ext.workflow.aio.mcp import DaprMCPClient as AioDaprMCPClient
2324
from dapr.ext.workflow.mcp import MCP_WORKFLOW_PREFIX, DaprMCPClient, MCPToolDef
2425
from dapr.ext.workflow.workflow_state import WorkflowState
2526

2627

28+
class _StubRpcError(grpc.RpcError):
    """Minimal ``grpc.RpcError`` stand-in whose ``code()`` reports a preset status."""

    def __init__(self, status_code: grpc.StatusCode):
        # Remember the status the test wants this error to report.
        self._status_code = status_code
        super().__init__()

    def code(self) -> grpc.StatusCode:
        """Return the status code supplied at construction time."""
        return self._status_code
37+
38+
2739
def _make_completed_state(output_json: dict) -> WorkflowState:
2840
"""Create a WorkflowState that simulates a COMPLETED workflow."""
2941
inner = client.WorkflowState(
@@ -385,6 +397,159 @@ async def test_connect_caches_tools(self):
385397
self.assertEqual(tools[1].name, 'get_forecast')
386398

387399

400+
class TestDaprMCPClientConnectRetry(unittest.TestCase):
    """Covers the sync ``connect()`` retry loop for transient gRPC failures."""

    def test_retries_then_succeeds_on_cancelled(self):
        """Two CANCELLED failures followed by a success are retried in-budget."""
        wf_stub = MagicMock()
        wf_stub.schedule_new_workflow.side_effect = [
            _StubRpcError(grpc.StatusCode.CANCELLED),
            _StubRpcError(grpc.StatusCode.CANCELLED),
            'inst-1',
        ]
        completed = _make_completed_state(SAMPLE_LIST_TOOLS_RESPONSE)
        wf_stub.wait_for_workflow_completion.return_value = completed
        sut = DaprMCPClient(timeout_in_seconds=30, wf_client=wf_stub)

        with patch('dapr.ext.workflow.mcp.time.sleep'):
            sut.connect('weather')

        self.assertEqual(wf_stub.schedule_new_workflow.call_count, 3)
        self.assertEqual(len(sut.get_all_tools()), 2)

    def test_retries_on_unavailable(self):
        """UNAVAILABLE is retried just like CANCELLED."""
        wf_stub = MagicMock()
        wf_stub.schedule_new_workflow.side_effect = [
            _StubRpcError(grpc.StatusCode.UNAVAILABLE),
            'inst-1',
        ]
        completed = _make_completed_state(SAMPLE_LIST_TOOLS_RESPONSE)
        wf_stub.wait_for_workflow_completion.return_value = completed
        sut = DaprMCPClient(timeout_in_seconds=30, wf_client=wf_stub)

        with patch('dapr.ext.workflow.mcp.time.sleep'):
            sut.connect('weather')

        self.assertEqual(wf_stub.schedule_new_workflow.call_count, 2)

    def test_non_transient_propagates_immediately(self):
        """Errors outside CANCELLED/UNAVAILABLE surface without any retry."""
        wf_stub = MagicMock()
        wf_stub.schedule_new_workflow.side_effect = _StubRpcError(grpc.StatusCode.PERMISSION_DENIED)
        sut = DaprMCPClient(timeout_in_seconds=30, wf_client=wf_stub)

        with patch('dapr.ext.workflow.mcp.time.sleep') as fake_sleep:
            with self.assertRaises(grpc.RpcError):
                sut.connect('weather')

        self.assertEqual(wf_stub.schedule_new_workflow.call_count, 1)
        fake_sleep.assert_not_called()

    def test_deadline_exhausted_raises_last_error(self):
        """Running out of budget while retrying re-raises the last gRPC error."""
        wf_stub = MagicMock()
        wf_stub.schedule_new_workflow.side_effect = _StubRpcError(grpc.StatusCode.CANCELLED)
        sut = DaprMCPClient(timeout_in_seconds=1, wf_client=wf_stub)

        # Fake clock: the first reading fixes the deadline, the second is
        # already past it, so the test never really waits out the second.
        clock_readings = [0.0, 2.0]
        with patch('dapr.ext.workflow.mcp.time.sleep'):
            with patch('dapr.ext.workflow.mcp.time.monotonic', side_effect=clock_readings):
                with self.assertRaises(grpc.RpcError):
                    sut.connect('weather')

    def test_budget_exhausted_after_schedule_succeeds(self):
        """Scheduling succeeds only after the budget is spent: expect a timeout.

        ``wait_for_workflow_completion`` must not be called in that case,
        because the underlying client treats timeout=0 as 'wait forever'.
        """
        wf_stub = MagicMock()
        wf_stub.schedule_new_workflow.side_effect = [
            _StubRpcError(grpc.StatusCode.CANCELLED),
            'inst-1',
        ]
        sut = DaprMCPClient(timeout_in_seconds=1, wf_client=wf_stub)

        # Fake clock readings, in order of consumption:
        #   0.0 -> deadline becomes 1.0
        #   0.4 -> sleep_for = 0.5, still inside the budget
        #   2.0 -> remaining after the loop is -1.0, so connect() raises
        clock_readings = [0.0, 0.4, 2.0]
        with patch('dapr.ext.workflow.mcp.time.sleep'):
            with patch('dapr.ext.workflow.mcp.time.monotonic', side_effect=clock_readings):
                with self.assertRaises(RuntimeError) as raised:
                    sut.connect('weather')
                self.assertIn('timed out', str(raised.exception))
        wf_stub.wait_for_workflow_completion.assert_not_called()
494+
495+
496+
class TestAioDaprMCPClientConnectRetry(unittest.IsolatedAsyncioTestCase):
    """Async mirror of the sync connect-retry tests, run against the aio client."""

    async def test_retries_then_succeeds_on_cancelled(self):
        """One CANCELLED failure followed by a success results in two awaits."""
        wf_stub = AsyncMock()
        wf_stub.schedule_new_workflow.side_effect = [
            _StubRpcError(grpc.StatusCode.CANCELLED),
            'inst-1',
        ]
        completed = _make_completed_state(SAMPLE_LIST_TOOLS_RESPONSE)
        wf_stub.wait_for_workflow_completion.return_value = completed
        sut = AioDaprMCPClient(timeout_in_seconds=30, wf_client=wf_stub)

        with patch('dapr.ext.workflow.aio.mcp.asyncio.sleep', new=AsyncMock()):
            await sut.connect('weather')

        self.assertEqual(wf_stub.schedule_new_workflow.await_count, 2)
        self.assertEqual(len(sut.get_all_tools()), 2)

    async def test_deadline_exhausted_raises(self):
        """Running out of budget while retrying re-raises the gRPC error."""
        wf_stub = AsyncMock()
        wf_stub.schedule_new_workflow.side_effect = _StubRpcError(grpc.StatusCode.CANCELLED)
        sut = AioDaprMCPClient(timeout_in_seconds=1, wf_client=wf_stub)

        # Fake clock: the first reading fixes the deadline, the second is past it.
        clock_readings = [0.0, 2.0]
        with patch('dapr.ext.workflow.aio.mcp.asyncio.sleep', new=AsyncMock()):
            with patch('dapr.ext.workflow.aio.mcp.time.monotonic', side_effect=clock_readings):
                with self.assertRaises(grpc.RpcError):
                    await sut.connect('weather')

    async def test_budget_exhausted_after_schedule_succeeds(self):
        """Scheduling succeeds after the budget is spent: fail fast, never wait."""
        wf_stub = AsyncMock()
        wf_stub.schedule_new_workflow.side_effect = [
            _StubRpcError(grpc.StatusCode.CANCELLED),
            'inst-1',
        ]
        sut = AioDaprMCPClient(timeout_in_seconds=1, wf_client=wf_stub)

        # Fake clock readings: deadline set at 0.0, retry sleep computed at 0.4,
        # post-loop check at 2.0 finds the budget exhausted.
        clock_readings = [0.0, 0.4, 2.0]
        with patch('dapr.ext.workflow.aio.mcp.asyncio.sleep', new=AsyncMock()):
            with patch('dapr.ext.workflow.aio.mcp.time.monotonic', side_effect=clock_readings):
                with self.assertRaises(RuntimeError) as raised:
                    await sut.connect('weather')
                self.assertIn('timed out', str(raised.exception))
        wf_stub.wait_for_workflow_completion.assert_not_awaited()
551+
552+
388553
class TestMCPWorkflowPrefix(unittest.TestCase):
389554
"""Tests for the workflow naming constant."""
390555

0 commit comments

Comments
 (0)