|
19 | 19 | import contextlib |
20 | 20 | import hashlib |
21 | 21 | import logging |
| 22 | +import math |
22 | 23 | import random |
23 | 24 | import uuid |
24 | | -from collections.abc import Callable, Iterable |
| 25 | +from collections.abc import Callable, Iterable, Mapping |
25 | 26 | from dataclasses import dataclass, field |
26 | 27 | from datetime import datetime, timezone |
27 | 28 | from typing import Any |
@@ -49,23 +50,97 @@ def registry() -> dict[str, type]: |
49 | 50 |
|
50 | 51 |
|
51 | 52 | # ── Commands yielded from a workflow ────────────────────────────────── |
| 53 | +@dataclass |
| 54 | +class ActivityRetryPolicy: |
| 55 | + """Retry policy applied to one scheduled activity call. |
| 56 | +
|
| 57 | + The policy is snapped onto the durable activity execution when the |
| 58 | + workflow task completes, so later code deploys do not change the retry |
| 59 | + budget for an already-scheduled activity. |
| 60 | + """ |
| 61 | + |
| 62 | + max_attempts: int = 3 |
| 63 | + initial_interval_seconds: float = 1.0 |
| 64 | + backoff_coefficient: float = 2.0 |
| 65 | + maximum_interval_seconds: float | None = None |
| 66 | + non_retryable_error_types: list[str] = field(default_factory=list) |
| 67 | + backoff_seconds: list[int] | None = None |
| 68 | + |
| 69 | + def to_dict(self) -> dict[str, Any]: |
| 70 | + """Return the server command shape for this activity retry policy.""" |
| 71 | + if self.max_attempts < 1: |
| 72 | + raise ValueError("max_attempts must be >= 1") |
| 73 | + if self.initial_interval_seconds < 0: |
| 74 | + raise ValueError("initial_interval_seconds must be >= 0") |
| 75 | + if self.backoff_coefficient < 1: |
| 76 | + raise ValueError("backoff_coefficient must be >= 1") |
| 77 | + if self.maximum_interval_seconds is not None and self.maximum_interval_seconds < 0: |
| 78 | + raise ValueError("maximum_interval_seconds must be >= 0") |
| 79 | + |
| 80 | + return { |
| 81 | + "max_attempts": self.max_attempts, |
| 82 | + "backoff_seconds": self._backoff_seconds(), |
| 83 | + "non_retryable_error_types": [ |
| 84 | + value.strip() |
| 85 | + for value in self.non_retryable_error_types |
| 86 | + if isinstance(value, str) and value.strip() |
| 87 | + ], |
| 88 | + } |
| 89 | + |
| 90 | + def _backoff_seconds(self) -> list[int]: |
| 91 | + if self.backoff_seconds is not None: |
| 92 | + return [max(0, int(seconds)) for seconds in self.backoff_seconds] |
| 93 | + |
| 94 | + seconds: list[int] = [] |
| 95 | + current = self.initial_interval_seconds |
| 96 | + maximum = self.maximum_interval_seconds |
| 97 | + for _ in range(max(0, self.max_attempts - 1)): |
| 98 | + value = current if maximum is None else min(current, maximum) |
| 99 | + seconds.append(max(0, int(math.ceil(value)))) |
| 100 | + current *= self.backoff_coefficient |
| 101 | + return seconds |
| 102 | + |
| 103 | + |
| 104 | +ActivityRetryPolicyInput = ActivityRetryPolicy | Mapping[str, Any] |
| 105 | + |
| 106 | + |
@dataclass
class ScheduleActivity:
    """Workflow command that asks for an activity task to be scheduled."""

    activity_type: str
    arguments: list[Any]
    # Falls back to the task queue passed to to_server_command() when unset.
    queue: str | None = None
    # Policy object or mapping; omitted from the command when None.
    retry_policy: ActivityRetryPolicyInput | None = None
    # Each timeout is copied into the command only when it is not None.
    start_to_close_timeout: int | None = None
    schedule_to_start_timeout: int | None = None
    schedule_to_close_timeout: int | None = None
    heartbeat_timeout: int | None = None

    def to_server_command(
        self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
    ) -> dict[str, Any]:
        """Build the ``schedule_activity`` command dict for the server.

        Args:
            task_queue: Queue used when this command has no ``queue`` of its own.
            payload_codec: Codec handed to ``serializer.envelope`` for the
                arguments.

        Returns:
            The command dict. ``retry_policy`` and the timeout keys appear
            only for fields that are not None.
        """
        payload = serializer.envelope(self.arguments, codec=payload_codec)
        command: dict[str, Any] = {
            "type": "schedule_activity",
            "activity_type": self.activity_type,
            "arguments": payload,
            "queue": self.queue or task_queue,
        }

        policy = self.retry_policy
        if policy is not None:
            if isinstance(policy, ActivityRetryPolicy):
                command["retry_policy"] = policy.to_dict()
            else:
                # Mappings are forwarded as a plain dict copy.
                command["retry_policy"] = dict(policy)

        # Attribute names double as the command keys, in emission order.
        for key in (
            "start_to_close_timeout",
            "schedule_to_start_timeout",
            "schedule_to_close_timeout",
            "heartbeat_timeout",
        ):
            timeout = getattr(self, key)
            if timeout is not None:
                command[key] = timeout
        return command
69 | 144 |
|
70 | 145 |
|
71 | 146 | @dataclass |
@@ -266,9 +341,27 @@ def __init__(self, *, run_id: str = "", current_time: datetime | None = None) -> |
266 | 341 | self.logger = _ReplayLogger(_REPLAY_LOGGER) |
267 | 342 |
|
268 | 343 | def schedule_activity( |
269 | | - self, activity_type: str, arguments: list[Any], *, queue: str | None = None |
| 344 | + self, |
| 345 | + activity_type: str, |
| 346 | + arguments: list[Any], |
| 347 | + *, |
| 348 | + queue: str | None = None, |
| 349 | + retry_policy: ActivityRetryPolicyInput | None = None, |
| 350 | + start_to_close_timeout: int | None = None, |
| 351 | + schedule_to_start_timeout: int | None = None, |
| 352 | + schedule_to_close_timeout: int | None = None, |
| 353 | + heartbeat_timeout: int | None = None, |
270 | 354 | ) -> ScheduleActivity: |
271 | | - return ScheduleActivity(activity_type=activity_type, arguments=list(arguments), queue=queue) |
| 355 | + return ScheduleActivity( |
| 356 | + activity_type=activity_type, |
| 357 | + arguments=list(arguments), |
| 358 | + queue=queue, |
| 359 | + retry_policy=retry_policy, |
| 360 | + start_to_close_timeout=start_to_close_timeout, |
| 361 | + schedule_to_start_timeout=schedule_to_start_timeout, |
| 362 | + schedule_to_close_timeout=schedule_to_close_timeout, |
| 363 | + heartbeat_timeout=heartbeat_timeout, |
| 364 | + ) |
272 | 365 |
|
273 | 366 | def start_timer(self, seconds: int) -> StartTimer: |
274 | 367 | return StartTimer(delay_seconds=seconds) |
|
0 commit comments