Skip to content

Commit e687fc6

Browse files
committed
Swap type aliases back to classes to avoid printing or runtime type checking concerns
1 parent ff7c0b8 commit e687fc6

3 files changed

Lines changed: 42 additions & 26 deletions

File tree

temporalio/nexus/_operation_context.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
Any,
2020
Concatenate,
2121
Generic,
22-
TypeAlias,
2322
TypeVar,
2423
overload,
2524
)
@@ -29,6 +28,7 @@
2928
OperationContext,
3029
StartOperationContext,
3130
)
31+
from typing_extensions import Self
3232

3333
import temporalio.api.common.v1
3434
import temporalio.api.workflowservice.v1
@@ -549,8 +549,32 @@ def set(self) -> None:
549549
_temporal_cancel_operation_context.set(self)
550550

551551

552-
TemporalNexusStartOperationContext: TypeAlias = StartOperationContext
553-
TemporalNexusCancelOperationContext: TypeAlias = CancelOperationContext
552+
class TemporalNexusStartOperationContext(StartOperationContext):
553+
"""Context received by a Temporal Nexus operation when it is started.
554+
555+
.. warning::
556+
This API is experimental and unstable.
557+
"""
558+
559+
@classmethod
560+
def _from_start_operation_context(cls, ctx: StartOperationContext) -> Self:
561+
return cls(
562+
**{f.name: getattr(ctx, f.name) for f in dataclasses.fields(ctx)},
563+
)
564+
565+
566+
class TemporalNexusCancelOperationContext(CancelOperationContext):
567+
"""Context received by a Temporal Nexus operation when it is canceled.
568+
569+
.. warning::
570+
This API is experimental and unstable.
571+
"""
572+
573+
@classmethod
574+
def _from_cancel_operation_context(cls, ctx: CancelOperationContext) -> Self:
575+
return cls(
576+
**{f.name: getattr(ctx, f.name) for f in dataclasses.fields(ctx)},
577+
)
554578

555579

556580
class LoggerAdapter(logging.LoggerAdapter):

temporalio/nexus/_operation_handlers.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,10 @@ async def start(
150150
This API is experimental and unstable.
151151
"""
152152
nexus_client = _TemporalNexusClient()
153-
result = await self.start_operation(ctx, nexus_client, input)
153+
start_ctx = TemporalNexusStartOperationContext._from_start_operation_context(
154+
ctx
155+
)
156+
result = await self.start_operation(start_ctx, nexus_client, input)
154157
return result._to_nexus_result()
155158

156159
async def cancel(self, ctx: CancelOperationContext, token: str) -> None:
@@ -167,9 +170,12 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None:
167170
type=HandlerErrorType.INTERNAL,
168171
) from err
169172

173+
cancel_ctx = TemporalNexusCancelOperationContext._from_cancel_operation_context(
174+
ctx
175+
)
170176
match operation_token.type:
171177
case OperationTokenType.WORKFLOW:
172-
await self.cancel_workflow_run(ctx, operation_token.workflow_id)
178+
await self.cancel_workflow_run(cancel_ctx, operation_token.workflow_id)
173179

174180
async def cancel_workflow_run(
175181
self,

temporalio/nexus/_util.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import typing
66
import warnings
77
from collections.abc import Awaitable, Callable
8-
from dataclasses import dataclass
98
from typing import (
109
Any,
1110
)
@@ -31,16 +30,6 @@
3130
)
3231

3332

34-
@dataclass(frozen=True)
35-
class _ExpectedParamType:
36-
runtime_type: type[Any]
37-
display_name: str | None = None
38-
39-
@property
40-
def name(self) -> str:
41-
return self.display_name or self.runtime_type.__name__
42-
43-
4433
def get_workflow_run_start_method_input_and_output_type_annotations(
4534
start: Callable[
4635
[NexusServiceType, WorkflowRunOperationContext, InputT],
@@ -57,7 +46,7 @@ def get_workflow_run_start_method_input_and_output_type_annotations(
5746
"""
5847
return _get_wrapped_start_method_input_and_output_type_annotations(
5948
start,
60-
expected_param_types=(_ExpectedParamType(WorkflowRunOperationContext),),
49+
expected_param_types=(WorkflowRunOperationContext,),
6150
expected_return_origin=WorkflowHandle,
6251
)
6352

@@ -84,11 +73,8 @@ def get_temporal_operation_start_method_input_and_output_type_annotations(
8473
return _get_wrapped_start_method_input_and_output_type_annotations(
8574
start,
8675
expected_param_types=(
87-
_ExpectedParamType(
88-
TemporalNexusStartOperationContext,
89-
"TemporalNexusStartOperationContext",
90-
),
91-
_ExpectedParamType(TemporalNexusClient),
76+
TemporalNexusStartOperationContext,
77+
TemporalNexusClient,
9278
),
9379
expected_return_origin=TemporalOperationResult,
9480
)
@@ -97,7 +83,7 @@ def get_temporal_operation_start_method_input_and_output_type_annotations(
9783
def _get_wrapped_start_method_input_and_output_type_annotations(
9884
start: Callable[..., Any],
9985
*,
100-
expected_param_types: tuple[_ExpectedParamType, ...],
86+
expected_param_types: tuple[type[Any], ...],
10187
expected_return_origin: type[Any],
10288
) -> tuple[
10389
type[Any] | None,
@@ -135,7 +121,7 @@ def _get_wrapped_start_method_input_and_output_type_annotations(
135121
def _get_start_method_input_and_output_type_annotations(
136122
start: Callable[..., Any],
137123
*,
138-
expected_param_types: tuple[_ExpectedParamType, ...],
124+
expected_param_types: tuple[type[Any], ...],
139125
) -> tuple[
140126
type[Any] | None,
141127
type[Any] | None,
@@ -164,10 +150,10 @@ def _get_start_method_input_and_output_type_annotations(
164150
for index, (param_type, expected_param_type) in enumerate(
165151
zip(param_types, expected_param_types), start=1
166152
):
167-
if not issubclass(expected_param_type.runtime_type, param_type):
153+
if not issubclass(expected_param_type, param_type):
168154
warnings.warn(
169155
f"Expected parameter {index} of {start} to be an instance of "
170-
f"{expected_param_type.name}, but is {param_type}."
156+
f"{expected_param_type.__name__}, but is {param_type}."
171157
)
172158
input_type = None
173159

0 commit comments

Comments
 (0)