Skip to content

Commit 91abc81

Browse files
Narrow overloads on the SANO client (#1552)
* Narrow overloads on the sano client. * remove result_type params for overloads that don't need them * Add test to show that invalid functions produce a type error * address linter errors
1 parent f0c6afb commit 91abc81

2 files changed

Lines changed: 133 additions & 53 deletions

File tree

temporalio/client/_nexus.py

Lines changed: 103 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4-
from collections.abc import Callable, Mapping, Sequence
4+
from collections.abc import Awaitable, Callable, Mapping, Sequence
55
from dataclasses import dataclass, field
66
from datetime import datetime, timedelta, timezone
77
from typing import TYPE_CHECKING, Any, Generic, cast, overload
@@ -473,7 +473,7 @@ class NexusClient(ABC, Generic[NexusServiceType]):
473473
Use :py:meth:`Client.create_nexus_client` to create a client.
474474
"""
475475

476-
# Overload for nexusrpc.Operation with input
476+
# Overload for nexusrpc.Operation
477477
@overload
478478
@abstractmethod
479479
async def start_operation(
@@ -494,18 +494,18 @@ async def start_operation(
494494
rpc_timeout: timedelta | None = None,
495495
) -> NexusOperationHandle[OutputT]: ...
496496

497-
# Overload for Callable with result_type
497+
# Overload for string operation name
498498
@overload
499499
@abstractmethod
500500
async def start_operation(
501501
self,
502-
operation: Callable[..., Any],
502+
operation: str,
503503
arg: Any,
504504
*,
505505
id: str,
506506
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
507507
id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL,
508-
result_type: type[OutputT],
508+
result_type: type[OutputT] | None = None,
509509
schedule_to_close_timeout: timedelta | None = None,
510510
schedule_to_start_timeout: timedelta | None = None,
511511
start_to_close_timeout: timedelta | None = None,
@@ -516,13 +516,16 @@ async def start_operation(
516516
rpc_timeout: timedelta | None = None,
517517
) -> NexusOperationHandle[OutputT]: ...
518518

519-
# Overload for Callable without result_type
519+
# Overload for workflow_run_operation methods
520520
@overload
521521
@abstractmethod
522522
async def start_operation(
523523
self,
524-
operation: Callable[..., Any],
525-
arg: Any,
524+
operation: Callable[
525+
[NexusServiceType, temporalio.nexus.WorkflowRunOperationContext, InputT],
526+
Awaitable[temporalio.nexus.WorkflowHandle[OutputT]],
527+
],
528+
arg: InputT,
526529
*,
527530
id: str,
528531
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
@@ -535,20 +538,22 @@ async def start_operation(
535538
headers: Mapping[str, str] | None = None,
536539
rpc_metadata: Mapping[str, str | bytes] = {},
537540
rpc_timeout: timedelta | None = None,
538-
) -> NexusOperationHandle[Any]: ...
541+
) -> NexusOperationHandle[OutputT]: ...
539542

540-
# Overload for str with result_type
543+
# Overload for sync_operation methods (async def)
541544
@overload
542545
@abstractmethod
543546
async def start_operation(
544547
self,
545-
operation: str,
546-
arg: Any,
548+
operation: Callable[
549+
[NexusServiceType, nexusrpc.handler.StartOperationContext, InputT],
550+
Awaitable[OutputT],
551+
],
552+
arg: InputT,
547553
*,
548554
id: str,
549555
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
550556
id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL,
551-
result_type: type[OutputT],
552557
schedule_to_close_timeout: timedelta | None = None,
553558
schedule_to_start_timeout: timedelta | None = None,
554559
start_to_close_timeout: timedelta | None = None,
@@ -559,13 +564,16 @@ async def start_operation(
559564
rpc_timeout: timedelta | None = None,
560565
) -> NexusOperationHandle[OutputT]: ...
561566

562-
# Overload for str without result_type
567+
# Overload for sync_operation methods (def)
563568
@overload
564569
@abstractmethod
565570
async def start_operation(
566571
self,
567-
operation: str,
568-
arg: Any,
572+
operation: Callable[
573+
[NexusServiceType, nexusrpc.handler.StartOperationContext, InputT],
574+
OutputT,
575+
],
576+
arg: InputT,
569577
*,
570578
id: str,
571579
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
@@ -578,7 +586,30 @@ async def start_operation(
578586
headers: Mapping[str, str] | None = None,
579587
rpc_metadata: Mapping[str, str | bytes] = {},
580588
rpc_timeout: timedelta | None = None,
581-
) -> NexusOperationHandle[Any]: ...
589+
) -> NexusOperationHandle[OutputT]: ...
590+
591+
# Overload for operation_handler
592+
@overload
593+
@abstractmethod
594+
async def start_operation(
595+
self,
596+
operation: Callable[
597+
[NexusServiceType], nexusrpc.handler.OperationHandler[InputT, OutputT]
598+
],
599+
arg: InputT,
600+
*,
601+
id: str,
602+
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
603+
id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL,
604+
schedule_to_close_timeout: timedelta | None = None,
605+
schedule_to_start_timeout: timedelta | None = None,
606+
start_to_close_timeout: timedelta | None = None,
607+
search_attributes: temporalio.common.TypedSearchAttributes | None = None,
608+
summary: str | None = None,
609+
headers: Mapping[str, str] | None = None,
610+
rpc_metadata: Mapping[str, str | bytes] = {},
611+
rpc_timeout: timedelta | None = None,
612+
) -> NexusOperationHandle[OutputT]: ...
582613

583614
@abstractmethod
584615
async def start_operation(
@@ -611,7 +642,8 @@ async def start_operation(
611642
id: Unique identifier for this operation.
612643
id_reuse_policy: Policy for reusing operation IDs.
613644
id_conflict_policy: Policy for handling ID conflicts.
614-
result_type: The result type to deserialize into.
645+
result_type: For string operation names, this can set the specific
646+
result type hint to deserialize into.
615647
schedule_to_close_timeout: End-to-end timeout for the Nexus
616648
operation. If unset, defaults to the maximum allowed by the
617649
Temporal server.
@@ -633,7 +665,7 @@ async def start_operation(
633665
"""
634666
...
635667

636-
# Overload for nexusrpc.Operation with input
668+
# Overload for nexusrpc.Operation
637669
@overload
638670
@abstractmethod
639671
async def execute_operation(
@@ -654,18 +686,18 @@ async def execute_operation(
654686
rpc_timeout: timedelta | None = None,
655687
) -> OutputT: ...
656688

657-
# Overload for Callable with result_type
689+
# Overload for string operation name
658690
@overload
659691
@abstractmethod
660692
async def execute_operation(
661693
self,
662-
operation: Callable[..., Any],
694+
operation: str,
663695
arg: Any,
664696
*,
665697
id: str,
666698
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
667699
id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL,
668-
result_type: type[OutputT],
700+
result_type: type[OutputT] | None = None,
669701
schedule_to_close_timeout: timedelta | None = None,
670702
schedule_to_start_timeout: timedelta | None = None,
671703
start_to_close_timeout: timedelta | None = None,
@@ -676,13 +708,16 @@ async def execute_operation(
676708
rpc_timeout: timedelta | None = None,
677709
) -> OutputT: ...
678710

679-
# Overload for Callable without result_type
711+
# Overload for workflow_run_operation methods
680712
@overload
681713
@abstractmethod
682714
async def execute_operation(
683715
self,
684-
operation: Callable[..., Any],
685-
arg: Any,
716+
operation: Callable[
717+
[NexusServiceType, temporalio.nexus.WorkflowRunOperationContext, InputT],
718+
Awaitable[temporalio.nexus.WorkflowHandle[OutputT]],
719+
],
720+
arg: InputT,
686721
*,
687722
id: str,
688723
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
@@ -695,20 +730,22 @@ async def execute_operation(
695730
headers: Mapping[str, str] | None = None,
696731
rpc_metadata: Mapping[str, str | bytes] = {},
697732
rpc_timeout: timedelta | None = None,
698-
) -> Any: ...
733+
) -> OutputT: ...
699734

700-
# Overload for str with result_type
735+
# Overload for sync_operation methods (async def)
701736
@overload
702737
@abstractmethod
703738
async def execute_operation(
704739
self,
705-
operation: str,
706-
arg: Any,
740+
operation: Callable[
741+
[NexusServiceType, nexusrpc.handler.StartOperationContext, InputT],
742+
Awaitable[OutputT],
743+
],
744+
arg: InputT,
707745
*,
708746
id: str,
709747
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
710748
id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL,
711-
result_type: type[OutputT],
712749
schedule_to_close_timeout: timedelta | None = None,
713750
schedule_to_start_timeout: timedelta | None = None,
714751
start_to_close_timeout: timedelta | None = None,
@@ -719,13 +756,40 @@ async def execute_operation(
719756
rpc_timeout: timedelta | None = None,
720757
) -> OutputT: ...
721758

722-
# Overload for str without result_type
759+
# Overload for sync_operation methods (async def)
723760
@overload
724761
@abstractmethod
725762
async def execute_operation(
726763
self,
727-
operation: str,
728-
arg: Any,
764+
operation: Callable[
765+
[NexusServiceType, nexusrpc.handler.StartOperationContext, InputT],
766+
OutputT,
767+
],
768+
arg: InputT,
769+
*,
770+
id: str,
771+
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
772+
id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL,
773+
schedule_to_close_timeout: timedelta | None = None,
774+
schedule_to_start_timeout: timedelta | None = None,
775+
start_to_close_timeout: timedelta | None = None,
776+
search_attributes: temporalio.common.TypedSearchAttributes | None = None,
777+
summary: str | None = None,
778+
headers: Mapping[str, str] | None = None,
779+
rpc_metadata: Mapping[str, str | bytes] = {},
780+
rpc_timeout: timedelta | None = None,
781+
) -> OutputT: ...
782+
783+
# Overload for operation_handler
784+
@overload
785+
@abstractmethod
786+
async def execute_operation(
787+
self,
788+
operation: Callable[
789+
[NexusServiceType],
790+
nexusrpc.handler.OperationHandler[InputT, OutputT],
791+
],
792+
arg: InputT,
729793
*,
730794
id: str,
731795
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
@@ -738,7 +802,7 @@ async def execute_operation(
738802
headers: Mapping[str, str] | None = None,
739803
rpc_metadata: Mapping[str, str | bytes] = {},
740804
rpc_timeout: timedelta | None = None,
741-
) -> Any: ...
805+
) -> OutputT: ...
742806

743807
@abstractmethod
744808
async def execute_operation(
@@ -773,7 +837,8 @@ async def execute_operation(
773837
id: Unique identifier for this operation.
774838
id_reuse_policy: Policy for reusing operation IDs.
775839
id_conflict_policy: Policy for handling ID conflicts.
776-
result_type: The result type to deserialize into.
840+
result_type: For string operation names, this can set the specific
841+
result type hint to deserialize into.
777842
schedule_to_close_timeout: End-to-end timeout for the Nexus
778843
operation. If unset, defaults to the maximum allowed by the
779844
Temporal server.
@@ -860,7 +925,9 @@ async def start_operation(
860925
This API is experimental and unstable.
861926
"""
862927
op_name, output_type = self._resolve_operation(operation)
863-
final_result_type: type | None = result_type or output_type
928+
final_result_type: type | None = (
929+
result_type if isinstance(operation, str) else output_type
930+
)
864931

865932
return await self._client._impl.start_nexus_operation(
866933
StartNexusOperationInput(

tests/nexus/test_nexus_type_errors.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -238,14 +238,13 @@ async def standalone_operation_type_tests():
238238
start_to_close_timeout=timedelta(seconds=2),
239239
)
240240

241-
# result_type overrides output type from operation definition
242-
# conflicting result_type and annotation on variable cause type error
243-
# assert-type-error-pyright: 'Type "str" is not assignable to declared type "MyOutput"'
244-
_bad_result_type_output: MyOutput = await nexus_client.execute_operation( # type: ignore
245-
MyServiceHandler.my_sync_operation,
241+
# result_type is not allowed when an operation is provided
242+
await nexus_client.execute_operation(
243+
# assert-type-error-pyright: 'cannot be assigned to parameter "operation" of type "str"'
244+
MyService.my_sync_operation, # type: ignore
246245
MyInput(),
247246
id="op-1",
248-
result_type=str, # type: ignore
247+
result_type=str,
249248
)
250249

251250
# string operation name and result_type infers output type
@@ -337,19 +336,14 @@ async def standalone_operation_type_tests():
337336
)
338337
_defn_handle_output: MyOutput = await _defn_handle.result()
339338

340-
# result_type overrides output type from operation definition
341-
# conflicting result_type and annotation on variable cause type error
342-
_result_type_handle: NexusOperationHandle[
343-
MyOutput
344-
# assert-type-error-pyright: 'Type "NexusOperationHandle\[str\]" is not assignable to declared type "NexusOperationHandle\[MyOutput\]"'
345-
] = await nexus_client.start_operation( # type: ignore
346-
MyServiceHandler.my_sync_operation,
339+
# result_type is not allowed when an operation is provided
340+
await nexus_client.start_operation(
341+
# assert-type-error-pyright: 'cannot be assigned to parameter "operation" of type "str"'
342+
MyServiceHandler.my_sync_operation, # type: ignore
347343
MyInput(),
348344
id="op-1",
349-
result_type=str, # type: ignore
345+
result_type=str,
350346
)
351-
# handle still respects type declaration on the variable
352-
_result_type_handle_output: MyOutput = await _result_type_handle.result()
353347

354348
# starting with string operation name and result_type infers output type on the handle
355349
# and result from the handle
@@ -389,11 +383,30 @@ async def standalone_operation_type_tests():
389383
)
390384
)
391385

392-
# mismatched types on get_nexus_operation_handle produces type error
386+
# mismatched types on get_nexus_operation_handle produce a type error
393387
# assert-type-error-pyright: 'Type "NexusOperationHandle\[str\]" is not assignable to declared type "NexusOperationHandle\[MyOutput\]"'
394388
_mismatch_handle: NexusOperationHandle[MyOutput] = (
395389
client.get_nexus_operation_handle( # type: ignore
396390
"op-1",
397391
result_type=str, # type: ignore
398392
)
399393
)
394+
395+
# functions with invalid signatures produce a type error
396+
class InvalidServiceHandler:
397+
async def invalid(self, _ctx: str, _input: str) -> str:
398+
raise NotImplementedError()
399+
400+
# assert-type-error-pyright: 'No overloads for "start_operation" match'
401+
_invalid_handle: NexusOperationHandle[str] = await nexus_client.start_operation(
402+
InvalidServiceHandler.invalid, # type: ignore
403+
"foo",
404+
id="invalid",
405+
)
406+
407+
# assert-type-error-pyright: 'No overloads for "execute_operation" match'
408+
_invalid_result: str = await nexus_client.execute_operation(
409+
InvalidServiceHandler.invalid, # type: ignore
410+
"foo",
411+
id="invalid",
412+
)

0 commit comments

Comments
 (0)