Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 79 additions & 23 deletions ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
import time
import uuid
from datetime import datetime
from typing import Any, Optional, Sequence, Union
Expand All @@ -33,6 +35,7 @@
TOutput,
WorkflowIdReusePolicy,
WorkflowState,
_TransientTimeout,
new_orchestration_state,
)
from google.protobuf import wrappers_pb2
async def wait_for_orchestration_start(
    self, instance_id: str, *, fetch_payloads: bool = False, timeout: int = 0
) -> Optional[WorkflowState]:
    """Call ``WaitForInstanceStart`` for ``instance_id`` and return its state.

    Transient gRPC errors are retried by ``_call_with_transient_retry``
    within the ``timeout`` budget.

    Args:
        instance_id: ID of the workflow instance to wait on.
        fetch_payloads: Forwarded as ``getInputsAndOutputs`` on the request.
        timeout: Seconds to wait; ``0`` or ``None`` waits indefinitely.

    Returns:
        The value built by ``new_orchestration_state`` from the response.

    Raises:
        TimeoutError: If the wait does not finish within ``timeout``.
        grpc.RpcError: For any non-transient gRPC failure.
    """
    req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
    budget = 'indefinitely' if timeout in (0, None) else f'up to {timeout}s'
    self._logger.info(f"Waiting {budget} for instance '{instance_id}' to start.")

    async def _call(grpc_timeout):
        res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(
            req, timeout=grpc_timeout
        )
        return new_orchestration_state(req.instanceId, res)

    try:
        return await self._call_with_transient_retry(instance_id, timeout, _call)
    except _TransientTimeout as exc:
        # Replace the internal sentinel with the built-in TimeoutError, keeping
        # the underlying cause attached for debugging.
        raise TimeoutError('Timed-out waiting for the orchestration to start') from exc

async def wait_for_orchestration_completion(
self, instance_id: str, *, fetch_payloads: bool = True, timeout: int = 0
) -> Optional[WorkflowState]:
Comment thread
javier-aliaga marked this conversation as resolved.
Comment thread
javier-aliaga marked this conversation as resolved.
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
try:
grpc_timeout = None if timeout == 0 else timeout
self._logger.info(
f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete."
)
self._logger.info(
f"Waiting {'indefinitely' if timeout in (0, None) else f'up to {timeout}s'} for instance '{instance_id}' to complete."
)

async def _call(grpc_timeout):
res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(
req, timeout=grpc_timeout
)
Expand All @@ -167,14 +169,68 @@ async def wait_for_orchestration_completion(
self._logger.info(f"Instance '{instance_id}' was terminated.")
elif state.runtime_status == OrchestrationStatus.COMPLETED:
self._logger.info(f"Instance '{instance_id}' completed.")

return state
except grpc.RpcError as rpc_error:
if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore
# Replace gRPC error with the built-in TimeoutError
raise TimeoutError('Timed-out waiting for the orchestration to complete')
else:
raise

try:
return await self._call_with_transient_retry(instance_id, timeout, _call)
except _TransientTimeout:
raise TimeoutError('Timed-out waiting for the orchestration to complete')

# gRPC status codes that `_call_with_transient_retry` treats as transient:
# they are retried with backoff instead of being propagated to the caller.
#
# They indicate the workflow runtime is temporarily unable to locate the
# workflow actor — typically immediately after a Dapr sidecar restart (e.g.
# recovery from chaos). The placement service has the actor registration, but
# the local daprd hasn't received the dissemination yet. Without a retry,
# every poll fails permanently with FAILED_PRECONDITION even though the
# workflow runtime state is intact.
_TRANSIENT_RPC_CODES = (
grpc.StatusCode.FAILED_PRECONDITION,
grpc.StatusCode.UNAVAILABLE,
)

async def _call_with_transient_retry(self, instance_id, timeout, call_fn):
"""Async mirror of TaskHubGrpcClient._call_with_transient_retry.
Retries FAILED_PRECONDITION/UNAVAILABLE with capped exponential
backoff while clamping sleep and per-call gRPC timeout to the
remaining budget. The first call passes ``timeout`` verbatim so
callers observe identical behavior on a healthy runtime.
"""
Comment thread
javier-aliaga marked this conversation as resolved.
unbounded = timeout in (0, None)
deadline = None if unbounded else time.monotonic() + timeout
grpc_timeout = None if unbounded else timeout
backoff = 0.5
while True:
try:
return await call_fn(grpc_timeout)
except grpc.RpcError as rpc_error:
code = rpc_error.code() # type: ignore
if code == grpc.StatusCode.DEADLINE_EXCEEDED:
raise _TransientTimeout()
if code not in self._TRANSIENT_RPC_CODES:
raise

if deadline is None:
remaining = None
else:
remaining = deadline - time.monotonic()
if remaining <= 0:
raise _TransientTimeout()

sleep_for = min(backoff, 5.0)
if remaining is not None:
sleep_for = min(sleep_for, remaining)
self._logger.warning(
f"Transient gRPC error {code.name} waiting on instance '{instance_id}'; "
f'retrying in {sleep_for:.2f}s'
)
await asyncio.sleep(sleep_for)
backoff = min(backoff * 2, 5.0)

if deadline is None:
grpc_timeout = None
else:
grpc_timeout = deadline - time.monotonic()
if grpc_timeout <= 0:
raise _TransientTimeout()

async def raise_orchestration_event(
self, instance_id: str, event_name: str, *, data: Optional[Any] = None
Expand Down
114 changes: 91 additions & 23 deletions ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

import logging
import time
import uuid
from dataclasses import dataclass
from datetime import datetime
Expand All @@ -25,6 +26,12 @@
from dapr.ext.workflow._durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
from google.protobuf import wrappers_pb2


class _TransientTimeout(Exception):
"""Internal sentinel: the retry loop exhausted the user-provided timeout
budget. Callers convert this to a public ``TimeoutError``."""


TInput = TypeVar('TInput')
TOutput = TypeVar('TOutput')

def wait_for_orchestration_start(
    self, instance_id: str, *, fetch_payloads: bool = False, timeout: int = 0
) -> Optional[WorkflowState]:
    """Call ``WaitForInstanceStart`` for ``instance_id`` and return its state.

    Transient gRPC errors are retried by ``_call_with_transient_retry``
    within the ``timeout`` budget.

    Args:
        instance_id: ID of the workflow instance to wait on.
        fetch_payloads: Forwarded as ``getInputsAndOutputs`` on the request.
        timeout: Seconds to wait; ``0`` or ``None`` waits indefinitely.

    Returns:
        The value built by ``new_orchestration_state`` from the response.

    Raises:
        TimeoutError: If the wait does not finish within ``timeout``.
        grpc.RpcError: For any non-transient gRPC failure.
    """
    req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
    # The retry helper treats both 0 and None as "no deadline"; the log line
    # must use the same test, otherwise timeout=None is reported as "up to Nones".
    budget = 'indefinitely' if timeout in (0, None) else f'up to {timeout}s'
    self._logger.info(f"Waiting {budget} for instance '{instance_id}' to start.")

    def _call(grpc_timeout):
        res: pb.GetInstanceResponse = self._stub.WaitForInstanceStart(req, timeout=grpc_timeout)
        return new_orchestration_state(req.instanceId, res)

    try:
        return self._call_with_transient_retry(instance_id, timeout, _call)
    except _TransientTimeout as exc:
        # Replace the internal sentinel with the built-in TimeoutError, keeping
        # the underlying cause attached for debugging.
        raise TimeoutError('Timed-out waiting for the orchestration to start') from exc

def wait_for_orchestration_completion(
self, instance_id: str, *, fetch_payloads: bool = True, timeout: int = 0
) -> Optional[WorkflowState]:
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
try:
grpc_timeout = None if timeout == 0 else timeout
self._logger.info(
f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete."
)
self._logger.info(
Comment thread
javier-aliaga marked this conversation as resolved.
Comment thread
javier-aliaga marked this conversation as resolved.
f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete."
)

def _call(grpc_timeout):
res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(
req, timeout=grpc_timeout
)
Expand All @@ -262,14 +268,76 @@ def wait_for_orchestration_completion(
self._logger.info(f"Instance '{instance_id}' was terminated.")
elif state.runtime_status == OrchestrationStatus.COMPLETED:
self._logger.info(f"Instance '{instance_id}' completed.")

return state
except grpc.RpcError as rpc_error:
if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore
# Replace gRPC error with the built-in TimeoutError
raise TimeoutError('Timed-out waiting for the orchestration to complete')
else:
raise

try:
return self._call_with_transient_retry(instance_id, timeout, _call)
except _TransientTimeout:
raise TimeoutError('Timed-out waiting for the orchestration to complete')
Comment thread
javier-aliaga marked this conversation as resolved.
Comment thread
javier-aliaga marked this conversation as resolved.

# gRPC status codes that `_call_with_transient_retry` treats as transient:
# they are retried with backoff instead of being propagated to the caller.
#
# They indicate the workflow runtime is temporarily unable to locate the
# workflow actor — typically immediately after a Dapr sidecar restart (e.g.
# recovery from chaos). The placement service has the actor registration, but
# the local daprd hasn't received the dissemination yet. Without a retry,
# every poll fails permanently with FAILED_PRECONDITION even though the
# workflow runtime state is intact.
_TRANSIENT_RPC_CODES = (
grpc.StatusCode.FAILED_PRECONDITION,
grpc.StatusCode.UNAVAILABLE,
)
Comment thread
javier-aliaga marked this conversation as resolved.

def _call_with_transient_retry(self, instance_id, timeout, call_fn):
"""Run a gRPC wait call, retrying transient errors until the user
timeout deadline. Re-raises non-transient errors immediately.
timeout in (0, None) means unbounded; we still retry transients with
backoff.
Comment thread
javier-aliaga marked this conversation as resolved.
Outdated

The first call passes ``timeout`` verbatim to ``call_fn`` so callers
observe identical behavior to a non-retrying client when no transient
occurs (preserves prior public behavior). On a retry, both the sleep
and the per-call gRPC deadline are clamped to the remaining budget so
the helper never sleeps past ``timeout`` or starts a gRPC call with
no time left.
Comment thread
javier-aliaga marked this conversation as resolved.
Outdated
Comment thread
javier-aliaga marked this conversation as resolved.
Outdated
"""
unbounded = timeout in (0, None)
deadline = None if unbounded else time.monotonic() + timeout
grpc_timeout = None if unbounded else timeout
backoff = 0.5
while True:
try:
return call_fn(grpc_timeout)
except grpc.RpcError as rpc_error:
code = rpc_error.code() # type: ignore
if code == grpc.StatusCode.DEADLINE_EXCEEDED:
raise _TransientTimeout()
if code not in self._TRANSIENT_RPC_CODES:
raise

# Compute remaining budget once and reuse so the sleep and the
# next per-call grpc_timeout agree on "how much time is left".
if deadline is None:
remaining = None
else:
remaining = deadline - time.monotonic()
if remaining <= 0:
raise _TransientTimeout()

sleep_for = min(backoff, 5.0)
if remaining is not None:
sleep_for = min(sleep_for, remaining)
self._logger.warning(
f"Transient gRPC error {code.name} waiting on instance '{instance_id}'; "
f'retrying in {sleep_for:.2f}s'
)
time.sleep(sleep_for)
backoff = min(backoff * 2, 5.0)

if deadline is None:
grpc_timeout = None
else:
grpc_timeout = deadline - time.monotonic()
if grpc_timeout <= 0:
raise _TransientTimeout()

def raise_orchestration_event(
self, instance_id: str, event_name: str, *, data: Optional[Any] = None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import time
from unittest.mock import Mock

import grpc
import pytest
from dapr.ext.workflow._durabletask.client import TaskHubGrpcClient

Expand Down Expand Up @@ -66,3 +68,91 @@ def test_wait_for_orchestration_completion_timeout(timeout):
assert kwargs.get('timeout') is None
else:
assert kwargs.get('timeout') == timeout


def _make_rpc_error(code: grpc.StatusCode) -> grpc.RpcError:
    """Build a ``grpc.RpcError`` whose ``code()`` returns ``code``.

    ``code()`` and ``details()`` are attached to the instance so the client's
    ``rpc_error.code()`` lookup works without a real gRPC call.
    """
    err = grpc.RpcError()
    err.code = lambda: code  # type: ignore[method-assign]
    err.details = lambda: f'simulated {code.name}'  # type: ignore[method-assign]
    return err


@pytest.mark.parametrize(
    'transient_code', [grpc.StatusCode.FAILED_PRECONDITION, grpc.StatusCode.UNAVAILABLE]
)
def test_wait_for_orchestration_start_retries_transient_then_succeeds(transient_code, monkeypatch):
    """Transient gRPC error on the first call → backoff → next call succeeds."""
    from dapr.ext.workflow._durabletask.internal.protos import (
        ORCHESTRATION_STATUS_RUNNING,
        GetInstanceResponse,
        WorkflowState,
    )

    instance_id = 'test-instance'

    state = WorkflowState()
    state.instanceId = instance_id
    state.workflowStatus = ORCHESTRATION_STATUS_RUNNING
    response = GetInstanceResponse()
    response.workflowState.CopyFrom(state)

    # Record backoff sleeps instead of actually sleeping.
    sleeps = []
    monkeypatch.setattr('dapr.ext.workflow._durabletask.client.time.sleep', sleeps.append)

    # Per-call gRPC timeouts, in call order; its length doubles as the call count.
    timeouts = []

    def fake_call(*args, **kwargs):
        timeouts.append(kwargs.get('timeout'))
        if len(timeouts) == 1:
            raise _make_rpc_error(transient_code)
        return response

    c = TaskHubGrpcClient()
    c._stub = Mock()
    c._stub.WaitForInstanceStart.side_effect = fake_call

    # The point of this test is the retry behavior, not the response payload —
    # the second call returns successfully (no exception), the first transient
    # is absorbed, and exactly one backoff sleep happens between them.
    c.wait_for_orchestration_start(instance_id, timeout=10)

    assert len(timeouts) == 2
    assert len(sleeps) == 1 and sleeps[0] > 0
    # First attempt gets the timeout verbatim; the retry only gets what is left.
    assert timeouts[0] == 10
    assert 0 < timeouts[1] <= 10
Comment thread
javier-aliaga marked this conversation as resolved.


def test_wait_for_orchestration_start_transient_exhaustion_raises_timeout(monkeypatch):
    """Transient gRPC errors keep returning until the user budget runs out
    → public TimeoutError, not the raw RpcError."""
    instance_id = 'test-instance'

    # Advance monotonic time on every call so the deadline is reached quickly.
    fake_time = [0.0]

    def fake_monotonic():
        # Increment happens before the read, so callers see 0.6, 1.2, 1.8, ...
        fake_time[0] += 0.6
        return fake_time[0]

    monkeypatch.setattr('dapr.ext.workflow._durabletask.client.time.monotonic', fake_monotonic)
    monkeypatch.setattr('dapr.ext.workflow._durabletask.client.time.sleep', lambda s: None)

    c = TaskHubGrpcClient()
    c._stub = Mock()
    c._stub.WaitForInstanceStart.side_effect = _make_rpc_error(grpc.StatusCode.UNAVAILABLE)

    with pytest.raises(TimeoutError):
        c.wait_for_orchestration_start(instance_id, timeout=1)

    # The timeout must come from exhausting retries, not from skipping the call.
    assert c._stub.WaitForInstanceStart.call_count >= 1


def test_wait_for_orchestration_start_non_transient_propagates(monkeypatch):
    """Non-transient gRPC errors must NOT be retried — propagate directly."""
    instance_id = 'test-instance'
    # Guard against a real sleep if the client ever (wrongly) decides to retry.
    monkeypatch.setattr('dapr.ext.workflow._durabletask.client.time.sleep', lambda s: None)

    c = TaskHubGrpcClient()
    c._stub = Mock()
    c._stub.WaitForInstanceStart.side_effect = _make_rpc_error(grpc.StatusCode.PERMISSION_DENIED)

    with pytest.raises(grpc.RpcError):
        c.wait_for_orchestration_start(instance_id, timeout=10)

    # Must sit outside the ``pytest.raises`` block: statements after the raising
    # call inside the block would never execute.
    assert c._stub.WaitForInstanceStart.call_count == 1
Loading