Skip to content

Commit 91e26f1

Browse files
committed
Fix interceptor contract inconsistency for start_update_with_start_workflow
Add top-level rpc_metadata and rpc_timeout fields to StartWorkflowUpdateWithStartInput, making it consistent with every other OutboundInterceptor input dataclass. Previously this composite input lacked these fields, forcing interceptors to special-case it. Also fix the _ClientImpl to actually pass rpc_metadata and rpc_timeout to the execute_multi_operation gRPC call, which were previously silently dropped. Add a test verifying that rpc_metadata set by an interceptor on StartWorkflowUpdateWithStartInput is forwarded to the gRPC call. Fixes #1582
1 parent 8d66724 commit 91e26f1

4 files changed

Lines changed: 93 additions & 2 deletions

File tree

temporalio/client/_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,6 +1216,8 @@ def on_start_error(
12161216
input = StartWorkflowUpdateWithStartInput(
12171217
start_workflow_input=start_workflow_operation._start_workflow_input,
12181218
update_workflow_input=update_input,
1219+
rpc_metadata=rpc_metadata,
1220+
rpc_timeout=rpc_timeout,
12191221
_on_start=on_start,
12201222
_on_start_error=on_start_error,
12211223
)

temporalio/client/_impl.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,11 @@ def on_start(
852852

853853
try:
854854
return await self._start_workflow_update_with_start(
855-
input.start_workflow_input, input.update_workflow_input, on_start
855+
input.start_workflow_input,
856+
input.update_workflow_input,
857+
input.rpc_metadata,
858+
input.rpc_timeout,
859+
on_start,
856860
)
857861
except asyncio.CancelledError as _err:
858862
err = _err
@@ -914,6 +918,8 @@ async def _start_workflow_update_with_start(
914918
self,
915919
start_input: UpdateWithStartStartWorkflowInput,
916920
update_input: UpdateWithStartUpdateWorkflowInput,
921+
rpc_metadata: Mapping[str, str | bytes],
922+
rpc_timeout: timedelta | None,
917923
on_start: Callable[
918924
[temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse], None
919925
],
@@ -941,7 +947,12 @@ async def _start_workflow_update_with_start(
941947
# Repeatedly try to invoke ExecuteMultiOperation until the update is durable
942948
while True:
943949
multiop_response = (
944-
await self._client.workflow_service.execute_multi_operation(multiop_req)
950+
await self._client.workflow_service.execute_multi_operation(
951+
multiop_req,
952+
retry=True,
953+
metadata=rpc_metadata,
954+
timeout=rpc_timeout,
955+
)
945956
)
946957
start_response = multiop_response.responses[0].start_workflow
947958
update_response = multiop_response.responses[1].update_workflow

temporalio/client/_interceptor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,8 @@ class StartWorkflowUpdateWithStartInput:
378378

379379
start_workflow_input: UpdateWithStartStartWorkflowInput
380380
update_workflow_input: UpdateWithStartUpdateWorkflowInput
381+
rpc_metadata: Mapping[str, str | bytes]
382+
rpc_timeout: timedelta | None
381383
_on_start: Callable[
382384
[temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse], None
383385
]

tests/worker/test_update_with_start.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,3 +1104,79 @@ async def _do_update() -> Any:
11041104
elif id_reuse_policy == WorkflowIDReusePolicy.REJECT_DUPLICATE:
11051105
with pytest.raises(WorkflowAlreadyStartedError):
11061106
await _do_update()
1107+
1108+
1109+
class MetadataCapturingInterceptor(Interceptor):
1110+
"""Interceptor that sets rpc_metadata on update-with-start calls."""
1111+
1112+
def intercept_client(self, next: OutboundInterceptor) -> OutboundInterceptor:
1113+
return MetadataCapturingOutboundInterceptor(super().intercept_client(next))
1114+
1115+
1116+
class MetadataCapturingOutboundInterceptor(OutboundInterceptor):
1117+
def __init__(self, next: OutboundInterceptor) -> None:
1118+
super().__init__(next)
1119+
1120+
async def start_update_with_start_workflow(
1121+
self, input: StartWorkflowUpdateWithStartInput
1122+
) -> WorkflowUpdateHandle[Any]:
1123+
input.rpc_metadata = {
1124+
**input.rpc_metadata,
1125+
"test-header-key": "test-header-value",
1126+
}
1127+
return await super().start_update_with_start_workflow(input)
1128+
1129+
1130+
# Verify fix for https://github.com/temporalio/sdk-python/issues/1582
1131+
async def test_update_with_start_rpc_metadata_forwarded(client: Client):
1132+
"""Test that rpc_metadata on StartWorkflowUpdateWithStartInput is forwarded
1133+
to the execute_multi_operation gRPC call."""
1134+
captured_metadata: dict[str, str | bytes] = {}
1135+
1136+
class execute_multi_operation:
1137+
err = RPCError("intentional", RPCStatusCode.INTERNAL, b"")
1138+
err._grpc_status = temporalio.api.common.v1.GrpcStatus(details=[])
1139+
1140+
def __init__(self) -> None: # type: ignore[reportMissingSuperCall]
1141+
pass
1142+
1143+
async def __call__(
1144+
self,
1145+
req: temporalio.api.workflowservice.v1.ExecuteMultiOperationRequest,
1146+
*,
1147+
retry: bool = False,
1148+
metadata: Mapping[str, str | bytes] = {},
1149+
timeout: timedelta | None = None,
1150+
) -> temporalio.api.workflowservice.v1.ExecuteMultiOperationResponse:
1151+
captured_metadata.update(metadata)
1152+
raise self.err
1153+
1154+
interceptor = MetadataCapturingInterceptor()
1155+
intercepted_client = Client(
1156+
**{**client.config(), "interceptors": [interceptor]} # type: ignore
1157+
)
1158+
1159+
with patch.object(
1160+
intercepted_client.workflow_service,
1161+
"execute_multi_operation",
1162+
execute_multi_operation(),
1163+
):
1164+
start_workflow_operation = WithStartWorkflowOperation(
1165+
UpdateWithStartInterceptorWorkflow.run,
1166+
"wf-arg",
1167+
id=f"wf-{uuid.uuid4()}",
1168+
task_queue="tq",
1169+
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
1170+
)
1171+
with pytest.raises(RPCError):
1172+
await intercepted_client.start_update_with_start_workflow(
1173+
UpdateWithStartInterceptorWorkflow.my_update,
1174+
"update-arg",
1175+
start_workflow_operation=start_workflow_operation,
1176+
wait_for_stage=WorkflowUpdateStage.ACCEPTED,
1177+
rpc_metadata={"original-key": "original-value"},
1178+
)
1179+
1180+
# The interceptor should have added its metadata on top of the caller's
1181+
assert captured_metadata.get("test-header-key") == "test-header-value"
1182+
assert captured_metadata.get("original-key") == "original-value"

0 commit comments

Comments
 (0)