22
33from collections .abc import Awaitable , Callable
44from typing import (
5- TypeVar ,
5+ TypeAlias ,
66 overload ,
77)
88
1212 OperationHandler ,
1313 StartOperationContext ,
1414)
15+ from typing_extensions import override
1516
16- from ._operation_context import WorkflowRunOperationContext
17- from ._operation_handlers import WorkflowRunOperationHandler
17+ from temporalio .nexus ._temporal_client import (
18+ TemporalNexusClient ,
19+ TemporalOperationResult ,
20+ )
21+ from temporalio .types import NexusServiceType
22+
23+ from ._operation_context import (
24+ TemporalNexusStartOperationContext ,
25+ WorkflowRunOperationContext ,
26+ )
27+ from ._operation_handlers import (
28+ TemporalNexusOperationHandler ,
29+ WorkflowRunOperationHandler ,
30+ )
1831from ._token import WorkflowHandle
1932from ._util import (
2033 get_callable_name ,
34+ get_temporal_operation_start_method_input_and_output_type_annotations ,
2135 get_workflow_run_start_method_input_and_output_type_annotations ,
36+ is_async_callable ,
2237 set_operation_factory ,
2338)
2439
25- ServiceHandlerT = TypeVar ("ServiceHandlerT" )
26-
2740
2841@overload
2942def workflow_run_operation (
3043 start : Callable [
31- [ServiceHandlerT , WorkflowRunOperationContext , InputT ],
44+ [NexusServiceType , WorkflowRunOperationContext , InputT ],
3245 Awaitable [WorkflowHandle [OutputT ]],
3346 ],
3447) -> Callable [
35- [ServiceHandlerT , WorkflowRunOperationContext , InputT ],
48+ [NexusServiceType , WorkflowRunOperationContext , InputT ],
3649 Awaitable [WorkflowHandle [OutputT ]],
3750]: ...
3851
@@ -44,12 +57,12 @@ def workflow_run_operation(
4457) -> Callable [
4558 [
4659 Callable [
47- [ServiceHandlerT , WorkflowRunOperationContext , InputT ],
60+ [NexusServiceType , WorkflowRunOperationContext , InputT ],
4861 Awaitable [WorkflowHandle [OutputT ]],
4962 ]
5063 ],
5164 Callable [
52- [ServiceHandlerT , WorkflowRunOperationContext , InputT ],
65+ [NexusServiceType , WorkflowRunOperationContext , InputT ],
5366 Awaitable [WorkflowHandle [OutputT ]],
5467 ],
5568]: ...
@@ -59,26 +72,26 @@ def workflow_run_operation(
5972 start : None
6073 | (
6174 Callable [
62- [ServiceHandlerT , WorkflowRunOperationContext , InputT ],
75+ [NexusServiceType , WorkflowRunOperationContext , InputT ],
6376 Awaitable [WorkflowHandle [OutputT ]],
6477 ]
6578 ) = None ,
6679 * ,
6780 name : str | None = None ,
6881) -> (
6982 Callable [
70- [ServiceHandlerT , WorkflowRunOperationContext , InputT ],
83+ [NexusServiceType , WorkflowRunOperationContext , InputT ],
7184 Awaitable [WorkflowHandle [OutputT ]],
7285 ]
7386 | Callable [
7487 [
7588 Callable [
76- [ServiceHandlerT , WorkflowRunOperationContext , InputT ],
89+ [NexusServiceType , WorkflowRunOperationContext , InputT ],
7790 Awaitable [WorkflowHandle [OutputT ]],
7891 ]
7992 ],
8093 Callable [
81- [ServiceHandlerT , WorkflowRunOperationContext , InputT ],
94+ [NexusServiceType , WorkflowRunOperationContext , InputT ],
8295 Awaitable [WorkflowHandle [OutputT ]],
8396 ],
8497 ]
@@ -87,11 +100,11 @@ def workflow_run_operation(
87100
88101 def decorator (
89102 start : Callable [
90- [ServiceHandlerT , WorkflowRunOperationContext , InputT ],
103+ [NexusServiceType , WorkflowRunOperationContext , InputT ],
91104 Awaitable [WorkflowHandle [OutputT ]],
92105 ],
93106 ) -> Callable [
94- [ServiceHandlerT , WorkflowRunOperationContext , InputT ],
107+ [NexusServiceType , WorkflowRunOperationContext , InputT ],
95108 Awaitable [WorkflowHandle [OutputT ]],
96109 ]:
97110 (
@@ -100,7 +113,7 @@ def decorator(
100113 ) = get_workflow_run_start_method_input_and_output_type_annotations (start )
101114
102115 def operation_handler_factory (
103- self : ServiceHandlerT ,
116+ self : NexusServiceType ,
104117 ) -> OperationHandler [InputT , OutputT ]:
105118 async def _start (
106119 ctx : StartOperationContext , input : InputT
@@ -130,3 +143,109 @@ async def _start(
130143 return decorator
131144
132145 return decorator (start )
146+
147+
148+ TemporalNexusOperationStartHandlerFunc : TypeAlias = Callable [
149+ [
150+ NexusServiceType ,
151+ TemporalNexusStartOperationContext ,
152+ TemporalNexusClient ,
153+ InputT ,
154+ ],
155+ Awaitable [TemporalOperationResult [OutputT ]],
156+ ]
157+
158+
159+ @overload
160+ def temporal_operation (
161+ start : TemporalNexusOperationStartHandlerFunc [NexusServiceType , InputT , OutputT ],
162+ ) -> TemporalNexusOperationStartHandlerFunc [NexusServiceType , InputT , OutputT ]: ...
163+
164+
165+ @overload
166+ def temporal_operation (
167+ * ,
168+ name : str | None = None ,
169+ ) -> Callable [
170+ [TemporalNexusOperationStartHandlerFunc [NexusServiceType , InputT , OutputT ]],
171+ TemporalNexusOperationStartHandlerFunc [NexusServiceType , InputT , OutputT ],
172+ ]: ...
173+
174+
175+ def temporal_operation (
176+ start : None
177+ | TemporalNexusOperationStartHandlerFunc [NexusServiceType , InputT , OutputT ] = None ,
178+ * ,
179+ name : str | None = None ,
180+ ) -> (
181+ TemporalNexusOperationStartHandlerFunc [NexusServiceType , InputT , OutputT ]
182+ | Callable [
183+ [TemporalNexusOperationStartHandlerFunc [NexusServiceType , InputT , OutputT ]],
184+ TemporalNexusOperationStartHandlerFunc [NexusServiceType , InputT , OutputT ],
185+ ]
186+ ):
187+ """Decorator marking a method as the start method for an operation that interacts with Temporal.
188+
189+ .. warning::
190+ This API is experimental and unstable.
191+ """
192+
193+ def decorator (
194+ start : TemporalNexusOperationStartHandlerFunc [
195+ NexusServiceType , InputT , OutputT
196+ ],
197+ ) -> TemporalNexusOperationStartHandlerFunc [NexusServiceType , InputT , OutputT ]:
198+ if not is_async_callable (start ):
199+ raise RuntimeError (
200+ f"{ start } is not an `async def` method. "
201+ "@temporal_operation must decorate an `async def` start method."
202+ )
203+ (
204+ input_type ,
205+ output_type ,
206+ ) = get_temporal_operation_start_method_input_and_output_type_annotations (start )
207+
208+ def operation_handler_factory (
209+ self : NexusServiceType ,
210+ ) -> OperationHandler [InputT , OutputT ]:
211+ async def _start (
212+ ctx : TemporalNexusStartOperationContext ,
213+ client : TemporalNexusClient ,
214+ input : InputT ,
215+ ) -> TemporalOperationResult [OutputT ]:
216+ return await start (
217+ self ,
218+ ctx ,
219+ client ,
220+ input ,
221+ )
222+
223+ class _TemporalNexusOperationHandler (TemporalNexusOperationHandler ):
224+ @override
225+ async def start_operation (
226+ self ,
227+ ctx : TemporalNexusStartOperationContext ,
228+ client : TemporalNexusClient ,
229+ input : InputT ,
230+ ) -> TemporalOperationResult [OutputT ]:
231+ return await _start (ctx , client , input )
232+
233+ _TemporalNexusOperationHandler .start_operation .__doc__ = start .__doc__
234+ return _TemporalNexusOperationHandler ()
235+
236+ method_name = get_callable_name (start )
237+ op = nexusrpc .Operation (
238+ name = name or method_name ,
239+ input_type = input_type ,
240+ output_type = output_type ,
241+ )
242+ op .method_name = method_name
243+ nexusrpc .set_operation (operation_handler_factory , op )
244+
245+ set_operation_factory (start , operation_handler_factory )
246+ return start
247+
248+ if start is None :
249+ return decorator
250+
251+ return decorator (start )
0 commit comments