Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions ext/dapr-ext-workflow/dapr/ext/workflow/aio/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,20 @@

from __future__ import annotations

import asyncio
import logging
import time
import uuid
from typing import Optional, Set

from dapr.ext.workflow.aio.dapr_workflow_client import DaprWorkflowClient
from dapr.ext.workflow.mcp import _MCP_METHOD_LIST_TOOLS, MCP_WORKFLOW_PREFIX, _DaprMCPClientBase
from dapr.ext.workflow.mcp import (
_MCP_METHOD_LIST_TOOLS,
_SCHEDULE_RETRY_INTERVAL_SECONDS,
MCP_WORKFLOW_PREFIX,
_DaprMCPClientBase,
_is_transient_schedule_error,
)
from dapr.ext.workflow.workflow_state import WorkflowStatus

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -84,15 +92,27 @@ async def connect(self, mcpserver_name: str) -> None:

logger.debug('Scheduling %s (instance=%s)', workflow_name, instance_id)

await self._wf_client.schedule_new_workflow(
workflow=workflow_name,
input={'mcpServerName': mcpserver_name},
instance_id=instance_id,
)

deadline = time.monotonic() + self._timeout
while True:
try:
await self._wf_client.schedule_new_workflow(
workflow=workflow_name,
input={'mcpServerName': mcpserver_name},
instance_id=instance_id,
)
break
except Exception as exc: # noqa: BLE001 — classified by helper
if not _is_transient_schedule_error(exc) or time.monotonic() >= deadline:
raise
logger.debug(
'schedule_new_workflow returned transient error %s; retrying', exc
)
await asyncio.sleep(_SCHEDULE_RETRY_INTERVAL_SECONDS)
Comment thread
sicoyle marked this conversation as resolved.
Outdated

remaining = max(deadline - time.monotonic(), 1.0)
state = await self._wf_client.wait_for_workflow_completion(
instance_id=instance_id,
timeout_in_seconds=self._timeout,
timeout_in_seconds=int(remaining),
fetch_payloads=True,
Comment thread
sicoyle marked this conversation as resolved.
Outdated
)

Expand Down
49 changes: 43 additions & 6 deletions ext/dapr-ext-workflow/dapr/ext/workflow/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,13 @@

import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set

import grpc

from dapr.ext.workflow.dapr_workflow_client import DaprWorkflowClient
from dapr.ext.workflow.workflow_state import WorkflowStatus

Expand All @@ -50,6 +53,28 @@
_MCP_METHOD_LIST_TOOLS = '.ListTools'
_MCP_METHOD_CALL_TOOL = '.CallTool'

_TRANSIENT_GRPC_CODES = frozenset({
grpc.StatusCode.CANCELLED,
grpc.StatusCode.UNAVAILABLE,
})
_SCHEDULE_RETRY_INTERVAL_SECONDS = 0.5


def _is_transient_schedule_error(exc: BaseException) -> bool:
"""True if a schedule_new_workflow failure should be retried.

Walks ``__cause__`` so we catch both raw ``grpc.RpcError`` and any
durabletask-layer wrapping.
"""
if isinstance(exc, grpc.RpcError):
code = getattr(exc, 'code', None)
if callable(code) and code() in _TRANSIENT_GRPC_CODES:
return True
cause = getattr(exc, '__cause__', None)
if cause is not None and cause is not exc:
return _is_transient_schedule_error(cause)
return False


# TODO(@sicoyle): see if I can use the mcp pkg class instead for this?
@dataclass(frozen=True)
Expand Down Expand Up @@ -210,15 +235,27 @@ def connect(self, mcpserver_name: str) -> None:

logger.debug('Scheduling %s (instance=%s)', workflow_name, instance_id)

self._wf_client.schedule_new_workflow(
workflow=workflow_name,
input={'mcpServerName': mcpserver_name},
instance_id=instance_id,
)
deadline = time.monotonic() + self._timeout
while True:
try:
self._wf_client.schedule_new_workflow(
workflow=workflow_name,
input={'mcpServerName': mcpserver_name},
instance_id=instance_id,
)
break
except Exception as exc: # noqa: BLE001 — classified by helper
if not _is_transient_schedule_error(exc) or time.monotonic() >= deadline:
raise
logger.debug(
'schedule_new_workflow returned transient error %s; retrying', exc
)
time.sleep(_SCHEDULE_RETRY_INTERVAL_SECONDS)
Comment thread
sicoyle marked this conversation as resolved.
Outdated

remaining = max(deadline - time.monotonic(), 1.0)
state = self._wf_client.wait_for_workflow_completion(
instance_id=instance_id,
timeout_in_seconds=self._timeout,
timeout_in_seconds=int(remaining),
fetch_payloads=True,
Comment thread
sicoyle marked this conversation as resolved.
Outdated
)

Expand Down
122 changes: 121 additions & 1 deletion ext/dapr-ext-workflow/tests/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,27 @@
import json
import unittest
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch

import grpc

from dapr.ext.workflow._durabletask import client
from dapr.ext.workflow.aio.mcp import DaprMCPClient as AioDaprMCPClient
from dapr.ext.workflow.mcp import MCP_WORKFLOW_PREFIX, DaprMCPClient, MCPToolDef
from dapr.ext.workflow.workflow_state import WorkflowState


class _StubRpcError(grpc.RpcError):
"""Test double for grpc.RpcError with a configurable status code."""

def __init__(self, status_code: grpc.StatusCode):
super().__init__()
self._status_code = status_code

def code(self) -> grpc.StatusCode:
return self._status_code


def _make_completed_state(output_json: dict) -> WorkflowState:
"""Create a WorkflowState that simulates a COMPLETED workflow."""
inner = client.WorkflowState(
Expand Down Expand Up @@ -385,6 +398,113 @@ async def test_connect_caches_tools(self):
self.assertEqual(tools[1].name, 'get_forecast')


class TestDaprMCPClientConnectRetry(unittest.TestCase):
"""Tests for connect()'s retry-on-transient-gRPC-error path."""

def test_retries_then_succeeds_on_cancelled(self):
"""A CANCELLED schedule failure should be retried within the timeout budget."""
mock_wf = MagicMock()
mock_wf.schedule_new_workflow.side_effect = [
_StubRpcError(grpc.StatusCode.CANCELLED),
_StubRpcError(grpc.StatusCode.CANCELLED),
'inst-1',
]
mock_wf.wait_for_workflow_completion.return_value = _make_completed_state(
SAMPLE_LIST_TOOLS_RESPONSE
)

mcp_client = DaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf)
with patch('dapr.ext.workflow.mcp.time.sleep'):
mcp_client.connect('weather')

self.assertEqual(mock_wf.schedule_new_workflow.call_count, 3)
self.assertEqual(len(mcp_client.get_all_tools()), 2)

def test_retries_on_unavailable(self):
"""UNAVAILABLE should also be treated as transient."""
mock_wf = MagicMock()
mock_wf.schedule_new_workflow.side_effect = [
_StubRpcError(grpc.StatusCode.UNAVAILABLE),
'inst-1',
]
mock_wf.wait_for_workflow_completion.return_value = _make_completed_state(
SAMPLE_LIST_TOOLS_RESPONSE
)

mcp_client = DaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf)
with patch('dapr.ext.workflow.mcp.time.sleep'):
mcp_client.connect('weather')

self.assertEqual(mock_wf.schedule_new_workflow.call_count, 2)

def test_non_transient_propagates_immediately(self):
"""A non-CANCELLED/UNAVAILABLE error must not be retried."""
mock_wf = MagicMock()
mock_wf.schedule_new_workflow.side_effect = _StubRpcError(
grpc.StatusCode.PERMISSION_DENIED
)

mcp_client = DaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf)
with patch('dapr.ext.workflow.mcp.time.sleep') as sleep_mock:
with self.assertRaises(grpc.RpcError):
mcp_client.connect('weather')

self.assertEqual(mock_wf.schedule_new_workflow.call_count, 1)
sleep_mock.assert_not_called()

def test_deadline_exhausted_raises_last_error(self):
"""When the timeout budget runs out mid-retry, propagate the last error."""
mock_wf = MagicMock()
mock_wf.schedule_new_workflow.side_effect = _StubRpcError(
grpc.StatusCode.CANCELLED
)

mcp_client = DaprMCPClient(timeout_in_seconds=1, wf_client=mock_wf)
# Patch monotonic to advance past the deadline immediately so we don't
# actually sleep for a second in tests.
with patch('dapr.ext.workflow.mcp.time.sleep'), patch(
'dapr.ext.workflow.mcp.time.monotonic',
side_effect=[0.0, 2.0],
):
with self.assertRaises(grpc.RpcError):
mcp_client.connect('weather')


class TestAioDaprMCPClientConnectRetry(unittest.IsolatedAsyncioTestCase):
"""Async counterpart of TestDaprMCPClientConnectRetry."""

async def test_retries_then_succeeds_on_cancelled(self):
mock_wf = AsyncMock()
mock_wf.schedule_new_workflow.side_effect = [
_StubRpcError(grpc.StatusCode.CANCELLED),
'inst-1',
]
mock_wf.wait_for_workflow_completion.return_value = _make_completed_state(
SAMPLE_LIST_TOOLS_RESPONSE
)

mcp_client = AioDaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf)
with patch('dapr.ext.workflow.aio.mcp.asyncio.sleep', new=AsyncMock()):
await mcp_client.connect('weather')

self.assertEqual(mock_wf.schedule_new_workflow.await_count, 2)
self.assertEqual(len(mcp_client.get_all_tools()), 2)

async def test_deadline_exhausted_raises(self):
mock_wf = AsyncMock()
mock_wf.schedule_new_workflow.side_effect = _StubRpcError(
grpc.StatusCode.CANCELLED
)

mcp_client = AioDaprMCPClient(timeout_in_seconds=1, wf_client=mock_wf)
with patch('dapr.ext.workflow.aio.mcp.asyncio.sleep', new=AsyncMock()), patch(
'dapr.ext.workflow.aio.mcp.time.monotonic',
side_effect=[0.0, 2.0],
):
with self.assertRaises(grpc.RpcError):
await mcp_client.connect('weather')


class TestMCPWorkflowPrefix(unittest.TestCase):
"""Tests for the workflow naming constant."""

Expand Down
Loading