Skip to content

Commit f130e35

Browse files
committed
centralize and simplify start workflow response link handling. Move private start context accessor to nexus package. Update tests to reflect that response links are created when using plain start workflow
1 parent af78a9c commit f130e35

3 files changed

Lines changed: 65 additions & 63 deletions

File tree

temporalio/client/_impl.py

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -206,19 +206,9 @@ async def start_workflow(
206206
# returned so the caller workflow's Nexus history event links to the signaled event. A
207207
# plain start does not capture a response link: it only forwards the inbound request links
208208
# onto the start request.
209-
nexus_ctx = self._try_nexus_start_operation_context()
210-
if (
211-
nexus_ctx is not None
212-
and not temporalio.nexus._operation_context._in_nexus_backing_workflow_start_context()
213-
and isinstance(
214-
resp,
215-
temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse,
216-
)
217-
):
218-
# Server >= 1.31 with EnableCHASMSignalBacklinks returns signal_link pointing at
219-
# the WorkflowExecutionSignaled event; older servers leave it unset.
220-
if resp.HasField("signal_link"):
221-
nexus_ctx._add_response_link(resp.signal_link)
209+
nexus_ctx = temporalio.nexus._operation_context._try_start_operation_context()
210+
if nexus_ctx is not None:
211+
nexus_ctx._add_start_workflow_response_link(handle)
222212
return handle
223213

224214
async def _build_start_workflow_execution_request(
@@ -256,7 +246,7 @@ async def _build_start_workflow_execution_request(
256246
# Links are duplicated on request for compatibility with older server versions.
257247
req.links.extend(links)
258248

259-
nexus_ctx = self._try_nexus_start_operation_context()
249+
nexus_ctx = temporalio.nexus._operation_context._try_start_operation_context()
260250
if nexus_ctx is not None:
261251
# This start was issued from inside a Nexus operation handler. If the workflow ID
262252
# conflict policy is WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING and a conflict is
@@ -304,7 +294,9 @@ async def _build_signal_with_start_workflow_execution_request(
304294
# nexus-backing workflow), forward the inbound Nexus task links so both the callee's
305295
# WorkflowExecutionStarted and WorkflowExecutionSignaled events link back to the caller.
306296
if not temporalio.nexus._operation_context._in_nexus_backing_workflow_start_context():
307-
nexus_ctx = self._try_nexus_start_operation_context()
297+
nexus_ctx = (
298+
temporalio.nexus._operation_context._try_start_operation_context()
299+
)
308300
if nexus_ctx is not None:
309301
req.links.extend(nexus_ctx._get_request_links())
310302
return req
@@ -542,7 +534,7 @@ async def signal_workflow(self, input: SignalWorkflowInput) -> None:
542534
await self._apply_headers(input.headers, req.header.fields)
543535
# If this signal is issued from inside a Nexus operation handler, forward the inbound
544536
# Nexus task links so the WorkflowExecutionSignaled event links back to the caller.
545-
nexus_ctx = self._try_nexus_start_operation_context()
537+
nexus_ctx = temporalio.nexus._operation_context._try_start_operation_context()
546538
if nexus_ctx is not None:
547539
req.links.extend(nexus_ctx._get_request_links())
548540
resp = await self._client.workflow_service.signal_workflow_execution(
@@ -1685,17 +1677,6 @@ async def count_nexus_operations(
16851677
)
16861678
)
16871679

1688-
@staticmethod
1689-
def _try_nexus_start_operation_context() -> (
1690-
temporalio.nexus._operation_context._TemporalStartOperationContext | None
1691-
):
1692-
"""The Nexus start-operation context if a handler is currently running, else None."""
1693-
return (
1694-
temporalio.nexus._operation_context._temporal_start_operation_context.get(
1695-
None
1696-
)
1697-
)
1698-
16991680
async def _apply_headers(
17001681
self,
17011682
source: Mapping[str, temporalio.api.common.v1.Payload] | None,

temporalio/nexus/_operation_context.py

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
overload,
2424
)
2525

26+
import nexusrpc
2627
from nexusrpc.handler import (
2728
CancelOperationContext,
2829
OperationContext,
@@ -42,8 +43,11 @@
4243
SelfType,
4344
)
4445

46+
import temporalio.api.enums.v1
47+
4548
from ._link_conversion import (
4649
nexus_link_to_temporal_link,
50+
temporal_link_to_nexus_link,
4751
workflow_event_to_nexus_link,
4852
workflow_execution_started_event_link_from_workflow_handle,
4953
)
@@ -167,6 +171,11 @@ def _try_temporal_context() -> (
167171
return start_ctx or cancel_ctx
168172

169173

174+
def _try_start_operation_context() -> _TemporalStartOperationContext | None: # pyright: ignore[reportUnusedFunction]
175+
"""The Nexus start-operation context if a handler is currently running, else None."""
176+
return _temporal_start_operation_context.get(None)
177+
178+
170179
@contextmanager
171180
def _nexus_backing_workflow_start_context() -> Generator[None]:
172181
token = _temporal_nexus_backing_workflow_start_context.set(True)
@@ -253,35 +262,43 @@ def _get_request_links(self) -> list[temporalio.api.common.v1.Link]:
253262
event_links.append(link)
254263
return event_links
255264

256-
def _add_backing_workflow_response_link(
265+
def _add_start_workflow_response_link(
257266
self, workflow_handle: temporalio.client.WorkflowHandle[Any, Any]
258267
):
259-
# If links were not sent in StartWorkflowExecutionResponse then construct them.
260-
wf_event_links: list[temporalio.api.common.v1.Link.WorkflowEvent] = []
261-
try:
262-
if isinstance(
263-
workflow_handle._start_workflow_response,
264-
temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse,
265-
):
266-
if workflow_handle._start_workflow_response.HasField("link"):
267-
if link := workflow_handle._start_workflow_response.link:
268-
if link.HasField("workflow_event"):
269-
wf_event_links.append(link.workflow_event)
270-
if not wf_event_links:
271-
wf_event_links = [
272-
workflow_execution_started_event_link_from_workflow_handle(
268+
response = workflow_handle._start_workflow_response
269+
270+
nexus_link: nexusrpc.Link | None = None
271+
if isinstance(
272+
response, temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse
273+
):
274+
if response.HasField("link"):
275+
nexus_link = temporal_link_to_nexus_link(response.link)
276+
else:
277+
# If a link was not sent in response then construct it.
278+
link = temporalio.api.common.v1.Link(
279+
workflow_event=workflow_execution_started_event_link_from_workflow_handle(
273280
workflow_handle,
274281
self.nexus_context.request_id,
275282
)
276-
]
277-
self.nexus_context.outbound_links.extend(
278-
workflow_event_to_nexus_link(link) for link in wf_event_links
279-
)
283+
)
284+
nexus_link = temporal_link_to_nexus_link(link)
285+
286+
elif isinstance(
287+
response,
288+
temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse,
289+
):
290+
# Server >= 1.31 with EnableCHASMSignalBacklinks returns signal_link pointing at
291+
# the WorkflowExecutionSignaled event; older servers leave it unset.
292+
if response.HasField("signal_link"):
293+
nexus_link = temporal_link_to_nexus_link(response.signal_link)
294+
295+
try:
296+
if nexus_link is not None:
297+
self.nexus_context.outbound_links.append(nexus_link)
280298
except Exception as e:
281299
logger.warning(
282-
f"Failed to create WorkflowExecutionStarted event links for workflow {workflow_handle}: {e}"
300+
f"Failed to create event links for workflow {workflow_handle}: {e}"
283301
)
284-
return workflow_handle
285302

286303
def _add_response_link(self, link: temporalio.api.common.v1.Link | None) -> None:
287304
"""Append a response link returned by an RPC the operation handler issued.
@@ -699,6 +716,4 @@ async def _start_nexus_backing_workflow(
699716
request_id=temporal_context.nexus_context.request_id,
700717
)
701718

702-
temporal_context._add_backing_workflow_response_link(wf_handle)
703-
704719
return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle)

tests/nexus/test_signal_link_propagation.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from __future__ import annotations
1010

11+
from collections.abc import Generator
1112
from typing import Any
1213
from unittest import mock
1314

@@ -31,6 +32,7 @@
3132
import temporalio.converter
3233
import temporalio.nexus._link_conversion
3334
import temporalio.nexus._operation_context
35+
from temporalio.nexus._operation_context import _TemporalStartOperationContext
3436
from temporalio.client._impl import _ClientImpl
3537
from temporalio.client._interceptor import (
3638
SignalWorkflowInput,
@@ -68,7 +70,7 @@ def _inbound_nexus_link() -> temporalio.api.common.v1.Link:
6870

6971

7072
@pytest.fixture
71-
def nexus_ctx() -> Any:
73+
def nexus_ctx() -> Generator[_TemporalStartOperationContext]:
7274
"""Install a Nexus start-operation context with a single inbound link.
7375
7476
The inbound link is provided in nexusrpc.Link form, exactly as the worker populates it from
@@ -169,7 +171,7 @@ def _outbound_link_urls(ctx: Any) -> list[str]:
169171

170172

171173
async def test_signal_forwards_inbound_links_and_captures_response_backlink(
172-
nexus_ctx: Any,
174+
nexus_ctx: _TemporalStartOperationContext,
173175
) -> None:
174176
response_link = _workflow_event_link(
175177
WORKFLOW_ID,
@@ -197,7 +199,7 @@ async def test_signal_forwards_inbound_links_and_captures_response_backlink(
197199

198200

199201
async def test_signal_against_older_server_captures_no_backlink(
200-
nexus_ctx: Any,
202+
nexus_ctx: _TemporalStartOperationContext,
201203
) -> None:
202204
workflow_service = mock.MagicMock()
203205
workflow_service.signal_workflow_execution = mock.AsyncMock(
@@ -215,7 +217,9 @@ async def test_signal_against_older_server_captures_no_backlink(
215217
assert nexus_ctx.nexus_context.outbound_links == []
216218

217219

218-
async def test_multiple_signals_accumulate_all_backlinks(nexus_ctx: Any) -> None:
220+
async def test_multiple_signals_accumulate_all_backlinks(
221+
nexus_ctx: _TemporalStartOperationContext,
222+
) -> None:
219223
first = _workflow_event_link(
220224
"callee-a",
221225
"run-a",
@@ -265,7 +269,7 @@ async def test_signal_outside_nexus_context_does_not_touch_links() -> None:
265269

266270

267271
async def test_signal_with_start_forwards_inbound_links_and_captures_backlink(
268-
nexus_ctx: Any,
272+
nexus_ctx: _TemporalStartOperationContext,
269273
) -> None:
270274
response_link = _workflow_event_link(
271275
WORKFLOW_ID,
@@ -294,7 +298,7 @@ async def test_signal_with_start_forwards_inbound_links_and_captures_backlink(
294298

295299

296300
async def test_signal_with_start_against_older_server_captures_no_backlink(
297-
nexus_ctx: Any,
301+
nexus_ctx: _TemporalStartOperationContext,
298302
) -> None:
299303
workflow_service = mock.MagicMock()
300304
workflow_service.signal_with_start_workflow_execution = mock.AsyncMock(
@@ -314,8 +318,8 @@ async def test_signal_with_start_against_older_server_captures_no_backlink(
314318
# ── start ─────────────────────────────────────────────────────────────────────────────────
315319

316320

317-
async def test_start_forwards_inbound_links_and_captures_no_backlink(
318-
nexus_ctx: Any,
321+
async def test_start_forwards_inbound_links_and_captures_backlink(
322+
nexus_ctx: _TemporalStartOperationContext,
319323
) -> None:
320324
server_link = _workflow_event_link(
321325
WORKFLOW_ID,
@@ -338,12 +342,13 @@ async def test_start_forwards_inbound_links_and_captures_no_backlink(
338342
assert len(sent.links) == 1
339343
assert sent.links[0] == _inbound_nexus_link()
340344

341-
# Backward: a plain start does not capture a backlink, even when the server returns one.
342-
assert nexus_ctx.nexus_context.outbound_links == []
345+
# Backward: a plain start captures a backlink
346+
assert len(nexus_ctx.nexus_context.outbound_links) == 1
347+
assert "wf-target" in _outbound_link_urls(nexus_ctx)[0]
343348

344349

345350
async def test_start_against_older_server_captures_no_backlink(
346-
nexus_ctx: Any,
351+
nexus_ctx: _TemporalStartOperationContext,
347352
) -> None:
348353
workflow_service = mock.MagicMock()
349354
workflow_service.start_workflow_execution = mock.AsyncMock(
@@ -360,8 +365,9 @@ async def test_start_against_older_server_captures_no_backlink(
360365
assert len(sent.links) == 1
361366
assert sent.links[0] == _inbound_nexus_link()
362367

363-
# Backward: a plain start never fabricates a backlink.
364-
assert nexus_ctx.nexus_context.outbound_links == []
368+
# Backward: a plain start fabricates a backlink when the server doesn't return one.
369+
assert len(nexus_ctx.nexus_context.outbound_links) == 1
370+
assert "wf-target" in _outbound_link_urls(nexus_ctx)[0]
365371

366372

367373
async def test_start_outside_nexus_context_does_not_touch_links() -> None:

0 commit comments

Comments
 (0)