Skip to content

Commit 64b7b23

Browse files
committed
Add nexus-operation-token header to nexus callback headers for TemporalOperationHandler and WorkflowRunOperationHandler
1 parent d211027 commit 64b7b23

3 files changed

Lines changed: 97 additions & 11 deletions

File tree

temporalio/nexus/_operation_context.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
workflow_event_to_nexus_link,
4848
workflow_execution_started_event_link_from_workflow_handle,
4949
)
50-
from ._token import WorkflowHandle
50+
from ._token import OperationToken, OperationTokenType, WorkflowHandle
5151

5252
if TYPE_CHECKING:
5353
import temporalio.client
@@ -225,15 +225,14 @@ def get(cls) -> _TemporalStartOperationContext:
225225
def set(self) -> None:
226226
_temporal_start_operation_context.set(self)
227227

228-
def _get_callbacks(
229-
self,
230-
) -> list[temporalio.client.Callback]:
228+
def _get_callbacks(self, token: str) -> list[temporalio.client.Callback]:
231229
ctx = self.nexus_context
230+
callback_headers = {**ctx.callback_headers, "nexus-operation-token": token}
232231
return (
233232
[
234233
NexusCallback(
235234
url=ctx.callback_url,
236-
headers=ctx.callback_headers,
235+
headers=callback_headers,
237236
)
238237
]
239238
if ctx.callback_url
@@ -643,6 +642,11 @@ async def _start_nexus_backing_workflow(
643642
# terminal state) and inbound links to the caller workflow (attached to history events of
644643
# the workflow started in the handler namespace, and displayed in the UI).
645644
with _nexus_backing_workflow_start_context():
645+
token = OperationToken(
646+
type=OperationTokenType.WORKFLOW,
647+
namespace=temporal_context.client.namespace,
648+
workflow_id=id,
649+
).encode()
646650
wf_handle = await temporal_context.client.start_workflow( # type: ignore
647651
workflow=workflow,
648652
arg=arg,
@@ -669,7 +673,7 @@ async def _start_nexus_backing_workflow(
669673
request_eager_start=request_eager_start,
670674
priority=priority,
671675
versioning_override=versioning_override,
672-
callbacks=temporal_context._get_callbacks(),
676+
callbacks=temporal_context._get_callbacks(token),
673677
links=temporal_context._get_links(),
674678
request_id=temporal_context.nexus_context.request_id,
675679
)

tests/nexus/test_temporal_operation.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from temporalio import nexus, workflow
1313
from temporalio.client import Client, WorkflowExecutionStatus, WorkflowFailureError
1414
from temporalio.common import NexusOperationExecutionStatus, WorkflowIDConflictPolicy
15+
from temporalio.nexus._token import OperationToken, OperationTokenType
1516
from temporalio.testing import WorkflowEnvironment
1617
from temporalio.worker import Worker
1718
from tests.helpers import EventType, assert_event_subsequence, assert_eventually
@@ -685,3 +686,41 @@ async def test_temporal_operation_overloads(
685686
if op == "no_param"
686687
else TemporalOperationOverloadTestValue(value=4)
687688
)
689+
690+
691+
async def test_temporal_operation_includes_token_in_callback(
692+
client: Client, env: WorkflowEnvironment
693+
):
694+
task_queue = str(uuid.uuid4())
695+
endpoint_name = make_nexus_endpoint_name(task_queue)
696+
await env.create_nexus_endpoint(endpoint_name, task_queue)
697+
async with Worker(
698+
env.client,
699+
task_queue=task_queue,
700+
nexus_service_handlers=[TestServiceHandler()],
701+
workflows=[EchoWorkflow, EchoWorkflowCaller],
702+
):
703+
input_value = f"test-{uuid.uuid4()}"
704+
wf_handle = await client.start_workflow(
705+
EchoWorkflowCaller.run,
706+
Input(value=input_value, task_queue=task_queue),
707+
task_queue=task_queue,
708+
id=str(uuid.uuid4()),
709+
)
710+
result = await wf_handle.result()
711+
assert result == input_value
712+
713+
target_handle = client.get_workflow_handle(f"echo-{input_value}")
714+
715+
desc = await target_handle.describe()
716+
token = desc.raw_description.callbacks[0].callback.nexus.header[
717+
"nexus-operation-token"
718+
]
719+
720+
expected_token = OperationToken(
721+
type=OperationTokenType.WORKFLOW,
722+
namespace=client.namespace,
723+
workflow_id=target_handle.id,
724+
).encode()
725+
726+
assert token == expected_token

tests/nexus/test_workflow_run_operation.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from temporalio.client import Client
1919
from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation
2020
from temporalio.nexus._operation_handlers import WorkflowRunOperationHandler
21+
from temporalio.nexus._token import OperationToken, OperationTokenType
2122
from temporalio.testing import WorkflowEnvironment
2223
from temporalio.worker import Worker
2324
from tests.helpers.nexus import make_nexus_endpoint_name
@@ -48,7 +49,7 @@ async def start(
4849
handle = await tctx.start_workflow(
4950
EchoWorkflow.run,
5051
input.value,
51-
id=str(uuid.uuid4()),
52+
id=input.value,
5253
)
5354
return StartOperationResultAsync(handle.to_token())
5455

@@ -78,7 +79,7 @@ async def op(
7879
return await ctx.start_workflow(
7980
EchoWorkflow.run,
8081
input.value,
81-
id=str(uuid.uuid4()),
82+
id=input.value,
8283
)
8384

8485

@@ -146,13 +147,14 @@ async def test_workflow_run_operation(
146147
nexus_service_handlers=[service_handler_cls()],
147148
workflows=[CallerWorkflow, EchoWorkflow],
148149
):
150+
input_value = str(uuid.uuid4())
149151
result = await client.execute_workflow(
150152
CallerWorkflow.run,
151-
args=[Input(value="test"), service_defn.name, task_queue],
153+
args=[Input(value=input_value), service_defn.name, task_queue],
152154
id=str(uuid.uuid4()),
153155
task_queue=task_queue,
154156
)
155-
assert result == "test"
157+
assert result == input_value
156158

157159

158160
async def test_request_deadline_is_accessible_in_workflow_run_operation(
@@ -173,9 +175,10 @@ async def test_request_deadline_is_accessible_in_workflow_run_operation(
173175
nexus_service_handlers=[service_handler],
174176
workflows=[RequestDeadlineWorkflow, EchoWorkflow],
175177
):
178+
input_value = str(uuid.uuid4())
176179
await client.execute_workflow(
177180
RequestDeadlineWorkflow.run,
178-
args=[Input(value="test"), task_queue],
181+
args=[Input(value=input_value), task_queue],
179182
task_queue=task_queue,
180183
id=str(uuid.uuid4()),
181184
)
@@ -186,3 +189,43 @@ async def test_request_deadline_is_accessible_in_workflow_run_operation(
186189
"request_deadline should be set in WorkflowRunOperationContext"
187190
)
188191
assert deadline.tzinfo is timezone.utc, "request_deadline should be in utc"
192+
193+
194+
async def test_workflow_run_operation_includes_token_in_callback(
195+
client: Client,
196+
env: WorkflowEnvironment,
197+
):
198+
if env.supports_time_skipping:
199+
pytest.skip("Nexus tests don't work with time-skipping server")
200+
201+
task_queue = str(uuid.uuid4())
202+
await env.create_nexus_endpoint(make_nexus_endpoint_name(task_queue), task_queue)
203+
async with Worker(
204+
client,
205+
task_queue=task_queue,
206+
nexus_service_handlers=[SubclassingHappyPath()],
207+
workflows=[CallerWorkflow, EchoWorkflow],
208+
):
209+
input_value = str(uuid.uuid4())
210+
result = await client.execute_workflow(
211+
CallerWorkflow.run,
212+
args=[Input(value=input_value), "SubclassingHappyPath", task_queue],
213+
id=str(uuid.uuid4()),
214+
task_queue=task_queue,
215+
)
216+
assert result == input_value
217+
218+
target_handle = client.get_workflow_handle(input_value)
219+
220+
desc = await target_handle.describe()
221+
token = desc.raw_description.callbacks[0].callback.nexus.header[
222+
"nexus-operation-token"
223+
]
224+
225+
expected_token = OperationToken(
226+
type=OperationTokenType.WORKFLOW,
227+
namespace=client.namespace,
228+
workflow_id=target_handle.id,
229+
).encode()
230+
231+
assert token == expected_token

0 commit comments

Comments
 (0)