Skip to content

Commit b9b2cc3

Browse files
authored
Fix interceptor contract inconsistency for start_update_with_start_workflow (#1588)
* Fix interceptor contract inconsistency for start_update_with_start_workflow Add top-level rpc_metadata and rpc_timeout fields to StartWorkflowUpdateWithStartInput, making it consistent with every other OutboundInterceptor input dataclass. Previously this composite input lacked these fields, forcing interceptors to special-case it. Also fix the _ClientImpl to actually pass rpc_metadata and rpc_timeout to the execute_multi_operation gRPC call, which were previously silently dropped. Add a test verifying that rpc_metadata set by an interceptor on StartWorkflowUpdateWithStartInput is forwarded to the gRPC call. Fixes #1582 * Remove unused rpc_metadata/rpc_timeout from child interceptor inputs Remove rpc_metadata and rpc_timeout fields from UpdateWithStartUpdateWorkflowInput and UpdateWithStartStartWorkflowInput. These fields were never forwarded to the underlying execute_multi_operation gRPC call — only the top-level StartWorkflowUpdateWithStartInput fields are authoritative. Also remove the corresponding parameters from WithStartWorkflowOperation since they only served to populate the (now-removed) child input fields. This is a breaking change for interceptors that accessed rpc_metadata or rpc_timeout on the child input objects.
1 parent 24badcf commit b9b2cc3

5 files changed

Lines changed: 104 additions & 21 deletions

File tree

temporalio/client/_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1190,8 +1190,6 @@ async def _start_update_with_start(
11901190
args=temporalio.common._arg_or_args(arg, args),
11911191
headers={},
11921192
ret_type=result_type or result_type_from_type_hint,
1193-
rpc_metadata=rpc_metadata,
1194-
rpc_timeout=rpc_timeout,
11951193
wait_for_stage=wait_for_stage,
11961194
)
11971195

@@ -1216,6 +1214,8 @@ def on_start_error(
12161214
input = StartWorkflowUpdateWithStartInput(
12171215
start_workflow_input=start_workflow_operation._start_workflow_input,
12181216
update_workflow_input=update_input,
1217+
rpc_metadata=rpc_metadata,
1218+
rpc_timeout=rpc_timeout,
12191219
_on_start=on_start,
12201220
_on_start_error=on_start_error,
12211221
)

temporalio/client/_impl.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,11 @@ def on_start(
852852

853853
try:
854854
return await self._start_workflow_update_with_start(
855-
input.start_workflow_input, input.update_workflow_input, on_start
855+
input.start_workflow_input,
856+
input.update_workflow_input,
857+
input.rpc_metadata,
858+
input.rpc_timeout,
859+
on_start,
856860
)
857861
except asyncio.CancelledError as _err:
858862
err = _err
@@ -914,6 +918,8 @@ async def _start_workflow_update_with_start(
914918
self,
915919
start_input: UpdateWithStartStartWorkflowInput,
916920
update_input: UpdateWithStartUpdateWorkflowInput,
921+
rpc_metadata: Mapping[str, str | bytes],
922+
rpc_timeout: timedelta | None,
917923
on_start: Callable[
918924
[temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse], None
919925
],
@@ -941,7 +947,12 @@ async def _start_workflow_update_with_start(
941947
# Repeatedly try to invoke ExecuteMultiOperation until the update is durable
942948
while True:
943949
multiop_response = (
944-
await self._client.workflow_service.execute_multi_operation(multiop_req)
950+
await self._client.workflow_service.execute_multi_operation(
951+
multiop_req,
952+
retry=True,
953+
metadata=rpc_metadata,
954+
timeout=rpc_timeout,
955+
)
945956
)
946957
start_response = multiop_response.responses[0].start_workflow
947958
update_response = multiop_response.responses[1].update_workflow

temporalio/client/_interceptor.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,6 @@ class UpdateWithStartUpdateWorkflowInput:
334334
wait_for_stage: WorkflowUpdateStage
335335
headers: Mapping[str, temporalio.api.common.v1.Payload]
336336
ret_type: type | None
337-
rpc_metadata: Mapping[str, str | bytes]
338-
rpc_timeout: timedelta | None
339337

340338

341339
@dataclass
@@ -366,18 +364,23 @@ class UpdateWithStartStartWorkflowInput:
366364
static_details: str | None
367365
# Type may be absent
368366
ret_type: type | None
369-
rpc_metadata: Mapping[str, str | bytes]
370-
rpc_timeout: timedelta | None
371367
priority: temporalio.common.Priority
372368
versioning_override: temporalio.common.VersioningOverride | None = None
373369

374370

375371
@dataclass
376372
class StartWorkflowUpdateWithStartInput:
377-
"""Input for :py:meth:`OutboundInterceptor.start_update_with_start_workflow`."""
373+
"""Input for :py:meth:`OutboundInterceptor.start_update_with_start_workflow`.
374+
375+
The ``rpc_metadata`` and ``rpc_timeout`` fields are authoritative for the
376+
``execute_multi_operation`` gRPC call. Interceptors that wish to set RPC
377+
metadata should modify :py:attr:`rpc_metadata` on this object.
378+
"""
378379

379380
start_workflow_input: UpdateWithStartStartWorkflowInput
380381
update_workflow_input: UpdateWithStartUpdateWorkflowInput
382+
rpc_metadata: Mapping[str, str | bytes]
383+
rpc_timeout: timedelta | None
381384
_on_start: Callable[
382385
[temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse], None
383386
]

temporalio/client/_workflow.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,8 +1065,6 @@ def __init__(
10651065
static_summary: str | None = None,
10661066
static_details: str | None = None,
10671067
start_delay: timedelta | None = None,
1068-
rpc_metadata: Mapping[str, str | bytes] = {},
1069-
rpc_timeout: timedelta | None = None,
10701068
priority: temporalio.common.Priority = temporalio.common.Priority.default,
10711069
versioning_override: temporalio.common.VersioningOverride | None = None,
10721070
) -> None: ...
@@ -1095,8 +1093,6 @@ def __init__(
10951093
static_summary: str | None = None,
10961094
static_details: str | None = None,
10971095
start_delay: timedelta | None = None,
1098-
rpc_metadata: Mapping[str, str | bytes] = {},
1099-
rpc_timeout: timedelta | None = None,
11001096
priority: temporalio.common.Priority = temporalio.common.Priority.default,
11011097
versioning_override: temporalio.common.VersioningOverride | None = None,
11021098
) -> None: ...
@@ -1127,8 +1123,6 @@ def __init__(
11271123
static_summary: str | None = None,
11281124
static_details: str | None = None,
11291125
start_delay: timedelta | None = None,
1130-
rpc_metadata: Mapping[str, str | bytes] = {},
1131-
rpc_timeout: timedelta | None = None,
11321126
priority: temporalio.common.Priority = temporalio.common.Priority.default,
11331127
versioning_override: temporalio.common.VersioningOverride | None = None,
11341128
) -> None: ...
@@ -1159,8 +1153,6 @@ def __init__(
11591153
static_summary: str | None = None,
11601154
static_details: str | None = None,
11611155
start_delay: timedelta | None = None,
1162-
rpc_metadata: Mapping[str, str | bytes] = {},
1163-
rpc_timeout: timedelta | None = None,
11641156
priority: temporalio.common.Priority = temporalio.common.Priority.default,
11651157
versioning_override: temporalio.common.VersioningOverride | None = None,
11661158
) -> None: ...
@@ -1189,8 +1181,6 @@ def __init__(
11891181
static_summary: str | None = None,
11901182
static_details: str | None = None,
11911183
start_delay: timedelta | None = None,
1192-
rpc_metadata: Mapping[str, str | bytes] = {},
1193-
rpc_timeout: timedelta | None = None,
11941184
priority: temporalio.common.Priority = temporalio.common.Priority.default,
11951185
versioning_override: temporalio.common.VersioningOverride | None = None,
11961186
stack_level: int = 2,
@@ -1228,8 +1218,6 @@ def __init__(
12281218
start_delay=start_delay,
12291219
headers={},
12301220
ret_type=result_type or result_type_from_run_fn,
1231-
rpc_metadata=rpc_metadata,
1232-
rpc_timeout=rpc_timeout,
12331221
priority=priority,
12341222
versioning_override=versioning_override,
12351223
)

tests/worker/test_update_with_start.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,3 +1104,84 @@ async def _do_update() -> Any:
11041104
elif id_reuse_policy == WorkflowIDReusePolicy.REJECT_DUPLICATE:
11051105
with pytest.raises(WorkflowAlreadyStartedError):
11061106
await _do_update()
1107+
1108+
1109+
class MetadataCapturingInterceptor(Interceptor):
1110+
"""Interceptor that sets rpc_metadata on update-with-start calls."""
1111+
1112+
def intercept_client(self, next: OutboundInterceptor) -> OutboundInterceptor:
1113+
return MetadataCapturingOutboundInterceptor(super().intercept_client(next))
1114+
1115+
1116+
class MetadataCapturingOutboundInterceptor(OutboundInterceptor):
1117+
def __init__(self, next: OutboundInterceptor) -> None:
1118+
super().__init__(next)
1119+
1120+
async def start_update_with_start_workflow(
1121+
self, input: StartWorkflowUpdateWithStartInput
1122+
) -> WorkflowUpdateHandle[Any]:
1123+
input.rpc_metadata = {
1124+
**input.rpc_metadata,
1125+
"test-header-key": "test-header-value",
1126+
}
1127+
return await super().start_update_with_start_workflow(input)
1128+
1129+
1130+
# Verify fix for https://github.com/temporalio/sdk-python/issues/1582
1131+
async def test_update_with_start_rpc_metadata_and_timeout_forwarded(client: Client):
1132+
"""Test that rpc_metadata and rpc_timeout on StartWorkflowUpdateWithStartInput
1133+
are forwarded to the execute_multi_operation gRPC call."""
1134+
captured_metadata: dict[str, str | bytes] = {}
1135+
captured_timeout: list[timedelta | None] = []
1136+
1137+
class execute_multi_operation:
1138+
err = RPCError("intentional", RPCStatusCode.INTERNAL, b"")
1139+
err._grpc_status = temporalio.api.common.v1.GrpcStatus(details=[])
1140+
1141+
def __init__(self) -> None: # type: ignore[reportMissingSuperCall]
1142+
pass
1143+
1144+
async def __call__(
1145+
self,
1146+
req: temporalio.api.workflowservice.v1.ExecuteMultiOperationRequest,
1147+
*,
1148+
retry: bool = False,
1149+
metadata: Mapping[str, str | bytes] = {},
1150+
timeout: timedelta | None = None,
1151+
) -> temporalio.api.workflowservice.v1.ExecuteMultiOperationResponse:
1152+
captured_metadata.update(metadata)
1153+
captured_timeout.append(timeout)
1154+
raise self.err
1155+
1156+
interceptor = MetadataCapturingInterceptor()
1157+
intercepted_client = Client(
1158+
**{**client.config(), "interceptors": [interceptor]} # type: ignore
1159+
)
1160+
1161+
with patch.object(
1162+
intercepted_client.workflow_service,
1163+
"execute_multi_operation",
1164+
execute_multi_operation(),
1165+
):
1166+
start_workflow_operation = WithStartWorkflowOperation(
1167+
UpdateWithStartInterceptorWorkflow.run,
1168+
"wf-arg",
1169+
id=f"wf-{uuid.uuid4()}",
1170+
task_queue="tq",
1171+
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
1172+
)
1173+
with pytest.raises(RPCError):
1174+
await intercepted_client.start_update_with_start_workflow(
1175+
UpdateWithStartInterceptorWorkflow.my_update,
1176+
"update-arg",
1177+
start_workflow_operation=start_workflow_operation,
1178+
wait_for_stage=WorkflowUpdateStage.ACCEPTED,
1179+
rpc_metadata={"original-key": "original-value"},
1180+
rpc_timeout=timedelta(seconds=42),
1181+
)
1182+
1183+
# The interceptor should have added its metadata on top of the caller's
1184+
assert captured_metadata.get("test-header-key") == "test-header-value"
1185+
assert captured_metadata.get("original-key") == "original-value"
1186+
# The caller's timeout should have been forwarded
1187+
assert captured_timeout == [timedelta(seconds=42)]

0 commit comments

Comments
 (0)