Skip to content

Commit 605e4d3

Browse files
committed
add ts-v2 user interface
1 parent b270da7 commit 605e4d3

5 files changed

Lines changed: 393 additions & 0 deletions

File tree

temporalio/client/_impl.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Callable,
1111
Mapping,
1212
)
13+
from contextvars import ContextVar
1314
from datetime import timedelta
1415
from typing import (
1516
TYPE_CHECKING,
@@ -26,6 +27,7 @@
2627
import temporalio.api.schedule.v1
2728
import temporalio.api.taskqueue.v1
2829
import temporalio.api.update.v1
30+
import temporalio.api.workflow.v1
2931
import temporalio.api.workflowservice.v1
3032
import temporalio.common
3133
import temporalio.converter
@@ -136,6 +138,14 @@
136138
from ._client import Client
137139

138140

141+
# Set by WorkflowTimeSkipper's outbound interceptor before super().start_workflow(input),
142+
# read in _populate_start_workflow_execution_request to stamp time_skipping_config onto
143+
# the outgoing request. Reset in the interceptor's finally block.
144+
_start_workflow_time_skipping_config: ContextVar[
145+
temporalio.api.workflow.v1.TimeSkippingConfig | None
146+
] = ContextVar("_start_workflow_time_skipping_config", default=None)
147+
148+
139149
class _ClientImpl(OutboundInterceptor): # pyright: ignore[reportUnusedClass]
140150
def __init__(self, client: Client) -> None: # type: ignore
141151
# We are intentionally not calling the base class's __init__ here
@@ -340,6 +350,9 @@ async def _populate_start_workflow_execution_request(
340350
req.priority.CopyFrom(input.priority._to_proto())
341351
if input.versioning_override is not None:
342352
req.versioning_override.CopyFrom(input.versioning_override._to_proto())
353+
ts_config = _start_workflow_time_skipping_config.get()
354+
if ts_config is not None:
355+
req.time_skipping_config.CopyFrom(ts_config)
343356

344357
async def cancel_workflow(self, input: CancelWorkflowInput) -> None:
345358
await self._client.workflow_service.request_cancel_workflow_execution(

temporalio/testing/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
"""Test framework for workflows and activities."""
22

33
from ._activity import ActivityEnvironment
4+
from ._timeskipping import WorkflowTimeSkipper, WorkflowTimeSkippingConfig
45
from ._workflow import WorkflowEnvironment
56

67
__all__ = [
78
"ActivityEnvironment",
89
"WorkflowEnvironment",
10+
"WorkflowTimeSkipper",
11+
"WorkflowTimeSkippingConfig",
912
]
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
"""Utilities for per-workflow time skipping in tests."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass
6+
from datetime import timedelta
7+
from typing import Any
8+
9+
import google.protobuf.field_mask_pb2
10+
11+
import temporalio.api.common.v1
12+
import temporalio.api.enums.v1.event_type_pb2 as _event_type
13+
import temporalio.api.workflow.v1
14+
import temporalio.api.workflowservice.v1
15+
import temporalio.client
16+
from temporalio.client._impl import _start_workflow_time_skipping_config
17+
18+
19+
@dataclass(frozen=True)
20+
class WorkflowTimeSkippingConfig:
21+
"""Per-workflow time skipping configuration."""
22+
23+
enabled: bool = True
24+
"""Whether time skipping is enabled for the workflow."""
25+
26+
max_skip_duration: timedelta | None = None
27+
"""Maximum total virtual time that can be skipped before time skipping
28+
is automatically disabled."""
29+
30+
def _to_proto(self) -> temporalio.api.workflow.v1.TimeSkippingConfig:
31+
proto = temporalio.api.workflow.v1.TimeSkippingConfig(enabled=self.enabled)
32+
if self.max_skip_duration is not None:
33+
proto.max_skipped_duration.FromTimedelta(self.max_skip_duration)
34+
return proto
35+
36+
37+
_TERMINAL_EVENT_TYPES = frozenset(
38+
{
39+
_event_type.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED,
40+
_event_type.EVENT_TYPE_WORKFLOW_EXECUTION_FAILED,
41+
_event_type.EVENT_TYPE_WORKFLOW_EXECUTION_TIMED_OUT,
42+
_event_type.EVENT_TYPE_WORKFLOW_EXECUTION_TERMINATED,
43+
_event_type.EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED,
44+
_event_type.EVENT_TYPE_WORKFLOW_EXECUTION_CONTINUED_AS_NEW,
45+
}
46+
)
47+
48+
49+
class WorkflowTimeSkipper:
50+
"""Testing utility for per-workflow time skipping.
51+
52+
Creates a cloned client that automatically enables time skipping on every
53+
workflow started through it. Once a workflow's configured bound is
54+
reached, :py:meth:`wait_for_skip_duration_reached` blocks until the
55+
transition occurs and :py:meth:`resume` re-enables skipping with an
56+
optional new delta.
57+
58+
Example::
59+
60+
ts = WorkflowTimeSkipper(env.client,
61+
config=WorkflowTimeSkippingConfig(max_skip_duration=timedelta(hours=1)))
62+
63+
handle = await ts.client.start_workflow(
64+
MyWorkflow.run, id="wf-1", task_queue="tq",
65+
)
66+
await ts.wait_for_skip_duration_reached(handle)
67+
# inspect state, signal, etc.
68+
await ts.resume(handle, delta=timedelta(hours=1))
69+
result = await handle.result()
70+
71+
Works against any client the test suite hands in (local, self-hosted, or
72+
cloud). TODO: cloud usage assumes the namespace has server-side time
73+
skipping enabled (``frontend.TimeSkippingEnabled``); add a ``cloud``
74+
fixture mode alongside ``local`` / ``time-skipping`` in ``conftest.env``
75+
so the same tests can be pointed at a cloud namespace once that lands.
76+
"""
77+
78+
def __init__(
79+
self,
80+
client: temporalio.client.Client,
81+
*,
82+
config: WorkflowTimeSkippingConfig = WorkflowTimeSkippingConfig(),
83+
) -> None:
84+
"""Create a workflow time skipper.
85+
86+
Args:
87+
client: The client to wrap. A cloned client with a time-skipping
88+
interceptor is created; the original is left untouched.
89+
config: Initial bound. Defaults to no bound — time skipping runs
90+
until the workflow completes.
91+
"""
92+
self._config = config
93+
client_config = client.config()
94+
client_config["interceptors"] = [
95+
*client_config["interceptors"],
96+
_TimeSkippingConfigInterceptor(self),
97+
]
98+
self._client = temporalio.client.Client(**client_config)
99+
# Per-workflow max_skip_duration last set on the server, keyed by
100+
# (workflow_id, run_id).
101+
self._bound_cache: dict[tuple[str, str], timedelta] = {}
102+
103+
@property
104+
def client(self) -> temporalio.client.Client:
105+
"""Client that enables time skipping on every started workflow."""
106+
return self._client
107+
108+
@property
109+
def config(self) -> WorkflowTimeSkippingConfig:
110+
"""Bound applied to future start_workflow calls."""
111+
return self._config
112+
113+
@config.setter
114+
def config(self, value: WorkflowTimeSkippingConfig) -> None:
115+
self._config = value
116+
117+
async def wait_for_skip_duration_reached(
118+
self,
119+
handle: temporalio.client.WorkflowHandle[Any, Any],
120+
) -> bool:
121+
"""Block until the workflow's configured skip duration is reached.
122+
123+
Returns ``True`` once a time-skipping-disabled transition is observed.
124+
Returns ``False`` if the workflow terminates before any bound is
125+
reached.
126+
"""
127+
# TODO: Replace with a dedicated long-poll RPC once the server adds
128+
# one for time-skipping transitions. The current path streams every
129+
# history event since the workflow started, which is correct but not
130+
# the most efficient if event volume is high.
131+
async for event in handle.fetch_history_events(wait_new_event=True):
132+
if (
133+
event.event_type
134+
== _event_type.EVENT_TYPE_WORKFLOW_EXECUTION_TIME_SKIPPING_TRANSITIONED
135+
):
136+
attrs = (
137+
event.workflow_execution_time_skipping_transitioned_event_attributes
138+
)
139+
if attrs.disabled_after_bound:
140+
return True
141+
elif event.event_type in _TERMINAL_EVENT_TYPES:
142+
return False
143+
return False
144+
145+
async def resume(
146+
self,
147+
handle: temporalio.client.WorkflowHandle[Any, Any],
148+
delta: timedelta | None = None,
149+
) -> None:
150+
"""Re-enable time skipping after a bound was reached.
151+
152+
With ``delta``, sets a new bound equal to (previously-set bound +
153+
delta). Without ``delta``, resumes skipping with no bound — the
154+
workflow auto-skips until completion.
155+
"""
156+
proto = temporalio.api.workflow.v1.TimeSkippingConfig(enabled=True)
157+
if delta is not None:
158+
cache_key = (handle.id, handle.run_id or "")
159+
if cache_key not in self._bound_cache:
160+
if self._config.max_skip_duration is None:
161+
raise ValueError(
162+
"resume(delta=...) requires an initial bound to have been "
163+
"configured on the WorkflowTimeSkipper, or call resume() "
164+
"with no delta to resume unbounded."
165+
)
166+
self._bound_cache[cache_key] = self._config.max_skip_duration
167+
new_value = self._bound_cache[cache_key] + delta
168+
proto.max_skipped_duration.FromTimedelta(new_value)
169+
self._bound_cache[cache_key] = new_value
170+
171+
await self._client.workflow_service.update_workflow_execution_options(
172+
temporalio.api.workflowservice.v1.UpdateWorkflowExecutionOptionsRequest(
173+
namespace=self._client.namespace,
174+
workflow_execution=temporalio.api.common.v1.WorkflowExecution(
175+
workflow_id=handle.id,
176+
run_id=handle.run_id or "",
177+
),
178+
workflow_execution_options=temporalio.api.workflow.v1.WorkflowExecutionOptions(
179+
time_skipping_config=proto,
180+
),
181+
update_mask=google.protobuf.field_mask_pb2.FieldMask(
182+
paths=["time_skipping_config"],
183+
),
184+
identity=self._client.identity,
185+
),
186+
retry=True,
187+
)
188+
189+
190+
class _TimeSkippingConfigInterceptor(temporalio.client.Interceptor):
191+
def __init__(self, skipper: WorkflowTimeSkipper) -> None:
192+
super().__init__()
193+
self._skipper = skipper
194+
195+
def intercept_client(
196+
self, next: temporalio.client.OutboundInterceptor
197+
) -> temporalio.client.OutboundInterceptor:
198+
return _TimeSkippingConfigOutbound(next, self._skipper)
199+
200+
201+
class _TimeSkippingConfigOutbound(temporalio.client.OutboundInterceptor):
202+
def __init__(
203+
self,
204+
next: temporalio.client.OutboundInterceptor,
205+
skipper: WorkflowTimeSkipper,
206+
) -> None:
207+
super().__init__(next)
208+
self._skipper = skipper
209+
210+
async def start_workflow(
211+
self, input: temporalio.client.StartWorkflowInput
212+
) -> temporalio.client.WorkflowHandle[Any, Any]:
213+
proto = self._skipper.config._to_proto()
214+
token = _start_workflow_time_skipping_config.set(proto)
215+
try:
216+
handle = await super().start_workflow(input)
217+
finally:
218+
_start_workflow_time_skipping_config.reset(token)
219+
# Seed the bound cache so future resume(delta=...) calls have a
220+
# baseline to add to. Captures the config at start time, even if the
221+
# user mutates self._skipper.config afterwards.
222+
cfg = self._skipper.config
223+
if cfg.max_skip_duration is not None:
224+
cache_key = (handle.id, handle.run_id or "")
225+
self._skipper._bound_cache[cache_key] = cfg.max_skip_duration
226+
return handle

tests/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]:
134134
"nexusoperation.enableStandalone=true",
135135
"--dynamic-config-value",
136136
'system.system.refreshNexusEndpointsMinWait="0s"',
137+
"--dynamic-config-value",
138+
"frontend.TimeSkippingEnabled=true",
137139
],
138140
dev_server_download_version=DEV_SERVER_DOWNLOAD_VERSION,
139141
)

0 commit comments

Comments
 (0)