diff --git a/CHANGELOG.md b/CHANGELOG.md index a786ae1..a962b3a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,80 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## Unreleased -N/A +ADDED + +- Added a pluggable `DataConverter` (`durabletask.serialization`) accepted by + `TaskHubGrpcWorker`, `TaskHubGrpcClient`, and `AsyncTaskHubGrpcClient` via a + `data_converter` argument. Every payload boundary (inputs, outputs, events, + custom status, entity state) routes through it. The default + `JsonDataConverter` preserves existing behavior, so a custom converter (for + example one backed by pydantic) is opt-in. Custom objects can opt in via a + `to_json()` hook and a `from_json(value)` classmethod. +- `OrchestrationContext.call_activity`, `call_sub_orchestrator`, and + `call_entity` accept an optional `return_type`, and `wait_for_external_event` + accepts an optional `data_type`. When provided, the result/event payload is + reconstructed as that type (dataclasses — including nested dataclass, + `Optional`, and `list` fields — and `from_json()`-capable types) and the + returned task is typed accordingly (e.g. `call_activity(..., return_type=Foo)` + yields `CompletableTask[Foo]`). When omitted, the raw deserialized JSON is + returned as before. +- Inbound payloads are reconstructed from function type annotations. When an + orchestrator, activity, or entity operation annotates its input parameter (or + an activity its return value) with a dataclass or `from_json()`-capable type, + the payload is reconstructed as that type. Builtins and unannotated/unknown + types are passed through unchanged. An explicit `return_type` takes precedence + over a discovered annotation. +- Added typed accessors to `client.OrchestrationState`: `get_input()`, + `get_output()`, and `get_custom_status()` each accept an optional + `expected_type` and deserialize the corresponding payload, reconstructing + dataclasses and `from_json()`-capable types. 
The raw `serialized_*` fields are + retained. +- Objects exposing a `to_json()` method are now JSON-serializable when passed as + activity/orchestrator inputs or outputs. +- Added `EntityMetadata.get_typed_state(intended_type=...)`, which deserializes + the entity's persisted state and reconstructs dataclasses and + `from_json()`-capable types. The existing `get_state()` is unchanged: with no + argument it returns the raw serialized JSON payload, and `get_state(some_type)` + applies constructor-based coercion (`some_type(raw)`). +- Entity runtime state retrieval (`EntityContext.get_state(intended_type=...)` / + `DurableEntity.get_state(...)`) now also reconstructs dataclasses and + `from_json()`-capable types, in addition to the existing constructor-based + coercion. + +CHANGED + +- Custom objects (dataclasses, `SimpleNamespace`, namedtuples) are now + serialized as plain JSON. Decoding such a payload *without* a type hint now + yields a plain `dict` (previously a `SimpleNamespace`; a namedtuple now + round-trips as a JSON array). To get the original type back, pass the new + `return_type` / `data_type` arguments, annotate the consuming function's + parameter or return type, or use the typed client accessors. Payloads produced + by older SDK versions still deserialize — including into a `SimpleNamespace` + when no type is supplied — so in-flight orchestrations continue to replay + across an upgrade. +- JSON serialization failures now raise a `TypeError` that chains the original + error (`__cause__`) and names the offending type. + +FIXED + +- Falsy entity states (`0`, `""`, `[]`, `{}`) are no longer dropped when an + entity batch is persisted. Previously a falsy current state was treated as + "no state" and written as `None`, effectively deleting it; only an actual + `None` state now clears the persisted entity state. 
+ +BREAKING CHANGES (type-level only — no runtime impact for typical users) + +These changes do not alter runtime behavior, but because the package ships +`py.typed`, consumers running strict type checkers (pyright/mypy) — or +subclassing the public abstract types — may need to update their code: + +- `OrchestrationContext.call_activity`, `call_sub_orchestrator`, `call_entity`, + and `wait_for_external_event` gained new keyword-only parameters + (`return_type` / `data_type`). Subclasses overriding these methods should add + the parameter to match the base signature. +- `client.OrchestrationState` gained a non-public `_data_converter` field + (excluded from equality and `repr`). Code constructing `OrchestrationState` + positionally should pass it via the new field or rely on its default. ## v1.6.0 diff --git a/docs/supported-patterns.md b/docs/supported-patterns.md index 354f400..34b5944 100644 --- a/docs/supported-patterns.md +++ b/docs/supported-patterns.md @@ -68,7 +68,8 @@ def purchase_order_workflow(ctx: task.OrchestrationContext, order: Order): yield ctx.call_activity(send_approval_request, input=order) # Approvals must be received within 24 hours or they will be cancelled. - approval_event = ctx.wait_for_external_event("approval_received") + # Passing ``data_type`` reconstructs the event payload as an ``Approval``. + approval_event = ctx.wait_for_external_event("approval_received", data_type=Approval) timeout_event = ctx.create_timer(timedelta(hours=24)) winner = yield task.when_any([approval_event, timeout_event]) if winner == timeout_event: @@ -81,9 +82,11 @@ def purchase_order_workflow(ctx: task.OrchestrationContext, order: Order): ``` As an aside, you'll also notice that the example orchestration above works with custom business -objects. Support for custom business objects includes support for custom classes, custom data -classes, and named tuples. Serialization and deserialization of these objects is handled -automatically by the SDK. +objects. 
Custom classes, data classes, and named tuples are serialized to plain JSON automatically. +To reconstruct the original type on the receiving side, supply the type — for example via the +`data_type` argument to `wait_for_external_event` (shown above), the `return_type` argument to +`call_activity` / `call_sub_orchestrator` / `call_entity`, or by annotating the consuming function's +input parameter. Without a type, the payload is returned as plain JSON (a `dict` or `list`). See the full [human interaction sample](../examples/human_interaction.py). diff --git a/durabletask/client.py b/durabletask/client.py index 47e711c..725bd3d 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -7,10 +7,10 @@ import time import uuid from collections.abc import AsyncIterable, Iterable, Sequence -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Any, Generic, Protocol, TypeVar, cast +from typing import Any, Generic, Protocol, TypeVar, cast, overload import grpc import grpc.aio @@ -50,10 +50,12 @@ ) from durabletask.payload import helpers as payload_helpers from durabletask.payload.store import PayloadStore +from durabletask.serialization import DEFAULT_DATA_CONVERTER, DataConverter, JsonDataConverter TInput = TypeVar('TInput') TOutput = TypeVar('TOutput') TItem = TypeVar('TItem') +T = TypeVar('T') class OrchestrationStatus(Enum): @@ -81,6 +83,102 @@ class OrchestrationState: serialized_output: str | None serialized_custom_status: str | None failure_details: task.FailureDetails | None + # Converter used by the typed accessors below. Defaults to the SDK's JSON + # converter; the client populates it with its own converter so custom + # serialization applies on the read side too. Excluded from equality/repr so + # two states with equal payloads remain equal regardless of converter. 
+ _data_converter: DataConverter = field( + default=DEFAULT_DATA_CONVERTER, compare=False, repr=False) + + @overload + def get_input(self, expected_type: type[T]) -> T | None: + ... + + @overload + def get_input(self, expected_type: None = ...) -> Any: + ... + + def get_input(self, expected_type: type | None = None) -> Any: + """Deserialize the orchestration's input. + + Parameters + ---------- + expected_type : type | None + Optional type used to reconstruct the input. When provided, the + payload is coerced to this type (dataclasses are constructed from + their dict payloads, types exposing a ``from_json()`` classmethod + are reconstructed via that hook) and the return value is typed as + ``expected_type | None``. When omitted, the raw deserialized JSON is + returned. + + Returns + ------- + Any + The deserialized input, or None if there is no input. + """ + if self.serialized_input is None: + return None + return self._data_converter.deserialize(self.serialized_input, expected_type) + + @overload + def get_output(self, expected_type: type[T]) -> T | None: + ... + + @overload + def get_output(self, expected_type: None = ...) -> Any: + ... + + def get_output(self, expected_type: type | None = None) -> Any: + """Deserialize the orchestration's output. + + Parameters + ---------- + expected_type : type | None + Optional type used to reconstruct the output. When provided, the + payload is coerced to this type (dataclasses are constructed from + their dict payloads, types exposing a ``from_json()`` classmethod + are reconstructed via that hook) and the return value is typed as + ``expected_type | None``. When omitted, the raw deserialized JSON is + returned. + + Returns + ------- + Any + The deserialized output, or None if there is no output. + """ + if self.serialized_output is None: + return None + return self._data_converter.deserialize(self.serialized_output, expected_type) + + @overload + def get_custom_status(self, expected_type: type[T]) -> T | None: + ... 
+ + @overload + def get_custom_status(self, expected_type: None = ...) -> Any: + ... + + def get_custom_status(self, expected_type: type | None = None) -> Any: + """Deserialize the orchestration's custom status. + + Parameters + ---------- + expected_type : type | None + Optional type used to reconstruct the custom status. When provided, + the payload is coerced to this type (dataclasses are constructed + from their dict payloads, types exposing a ``from_json()`` + classmethod are reconstructed via that hook) and the return value is + typed as ``expected_type | None``. When omitted, the raw + deserialized JSON is returned. + + Returns + ------- + Any + The deserialized custom status, or None if there is no custom status. + """ + if self.serialized_custom_status is None: + return None + return self._data_converter.deserialize(self.serialized_custom_status, expected_type) def raise_if_failed(self): if self.failure_details is not None: @@ -138,18 +236,22 @@ def failure_details(self): return self._failure_details -def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> OrchestrationState | None: +def new_orchestration_state( + instance_id: str, res: pb.GetInstanceResponse, + data_converter: DataConverter | None = None) -> OrchestrationState | None: if not res.exists: return None state = res.orchestrationState - new_state = parse_orchestration_state(state) + new_state = parse_orchestration_state(state, data_converter) new_state.instance_id = instance_id # Override instance_id with the one from the request, to match old behavior return new_state -def parse_orchestration_state(state: pb.OrchestrationState) -> OrchestrationState: +def parse_orchestration_state( + state: pb.OrchestrationState, + data_converter: DataConverter | None = None) -> OrchestrationState: failure_details = None if state.failureDetails.errorMessage != '' or state.failureDetails.errorType != '': failure_details = task.FailureDetails( @@ -166,7 +268,8 @@ def 
parse_orchestration_state(state: pb.OrchestrationState) -> OrchestrationStat state.input.value if not helpers.is_empty(state.input) else None, state.output.value if not helpers.is_empty(state.output) else None, state.customStatus.value if not helpers.is_empty(state.customStatus) else None, - failure_details) + failure_details, + data_converter if data_converter is not None else DEFAULT_DATA_CONVERTER) # Grace period before a retired SDK-owned channel is force-closed. Long enough @@ -309,9 +412,11 @@ def __init__(self, *, channel_options: GrpcChannelOptions | None = None, resiliency_options: GrpcClientResiliencyOptions | None = None, default_version: str | None = None, - payload_store: PayloadStore | None = None): + payload_store: PayloadStore | None = None, + data_converter: DataConverter | None = None): self._owns_channel = channel is None + self._data_converter = data_converter if data_converter is not None else JsonDataConverter() self._host_address = ( host_address if host_address else shared.get_default_host_address() ) @@ -514,7 +619,8 @@ def schedule_new_orchestration(self, orchestrator: task.Orchestrator[TInput, TOu req = build_schedule_new_orchestration_req( orchestrator, input=input, instance_id=instance_id, start_at=start_at, reuse_id_policy=reuse_id_policy, tags=tags, - version=version if version else self.default_version) + version=version if version else self.default_version, + data_converter=self._data_converter) # Inject the active PRODUCER span context into the request so the sidecar # stores it in the executionStarted event and the worker can parent all @@ -538,7 +644,7 @@ def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = Tr # De-externalize any large-payload tokens in the response if self._payload_store is not None and res.exists: payload_helpers.deexternalize_payloads(res, self._payload_store) - return new_orchestration_state(req.instanceId, res) + return new_orchestration_state(req.instanceId, res, 
self._data_converter) def get_orchestration_history(self, instance_id: str, *, @@ -594,7 +700,7 @@ def get_all_orchestration_states(self, resp: pb.QueryInstancesResponse = self._stub.QueryInstances(req) if self._payload_store is not None: payload_helpers.deexternalize_payloads(resp, self._payload_store) - states += [parse_orchestration_state(res) for res in resp.orchestrationState] + states += [parse_orchestration_state(res, self._data_converter) for res in resp.orchestrationState] if check_continuation_token(resp.continuationToken, _continuation_token, self._logger): _continuation_token = resp.continuationToken else: @@ -614,7 +720,7 @@ def wait_for_orchestration_start(self, instance_id: str, *, ) if self._payload_store is not None and res.exists: payload_helpers.deexternalize_payloads(res, self._payload_store) - return new_orchestration_state(req.instanceId, res) + return new_orchestration_state(req.instanceId, res, self._data_converter) except grpc.RpcError as rpc_error: if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore # Replace gRPC error with the built-in TimeoutError @@ -634,7 +740,7 @@ def wait_for_orchestration_completion(self, instance_id: str, *, ) if self._payload_store is not None and res.exists: payload_helpers.deexternalize_payloads(res, self._payload_store) - state = new_orchestration_state(req.instanceId, res) + state = new_orchestration_state(req.instanceId, res, self._data_converter) log_completion_state(self._logger, instance_id, state) return state except grpc.RpcError as rpc_error: @@ -646,7 +752,7 @@ def wait_for_orchestration_completion(self, instance_id: str, *, def raise_orchestration_event(self, instance_id: str, event_name: str, *, data: Any | None = None) -> None: with tracing.start_raise_event_span(event_name, instance_id): - req = build_raise_event_req(instance_id, event_name, data) + req = build_raise_event_req(instance_id, event_name, data, self._data_converter) self._logger.info(f"Raising event '{event_name}' 
for instance '{instance_id}'.") if self._payload_store is not None: payload_helpers.externalize_payloads( @@ -657,7 +763,7 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *, def terminate_orchestration(self, instance_id: str, *, output: Any | None = None, recursive: bool = True) -> None: - req = build_terminate_req(instance_id, output, recursive) + req = build_terminate_req(instance_id, output, recursive, self._data_converter) self._logger.info(f"Terminating instance '{instance_id}'.") if self._payload_store is not None: @@ -720,7 +826,7 @@ def signal_entity(self, entity_instance_id: EntityInstanceId, operation_name: str, input: Any | None = None) -> None: - req = build_signal_entity_req(entity_instance_id, operation_name, input) + req = build_signal_entity_req(entity_instance_id, operation_name, input, self._data_converter) self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.") if self._payload_store is not None: payload_helpers.externalize_payloads( @@ -739,7 +845,7 @@ def get_entity(self, return None if self._payload_store is not None: payload_helpers.deexternalize_payloads(res, self._payload_store) - return EntityMetadata.from_entity_metadata(res.entity, include_state) + return EntityMetadata.from_entity_metadata(res.entity, include_state, self._data_converter) def get_all_entities(self, entity_query: EntityQuery | None = None) -> list[EntityMetadata]: @@ -756,7 +862,7 @@ def get_all_entities(self, resp: pb.QueryEntitiesResponse = self._stub.QueryEntities(query_request) if self._payload_store is not None: payload_helpers.deexternalize_payloads(resp, self._payload_store) - entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities] + entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState, self._data_converter) for entity in resp.entities] if check_continuation_token(resp.continuationToken, 
_continuation_token, self._logger): _continuation_token = resp.continuationToken else: @@ -805,9 +911,11 @@ def __init__(self, *, channel_options: GrpcChannelOptions | None = None, resiliency_options: GrpcClientResiliencyOptions | None = None, default_version: str | None = None, - payload_store: PayloadStore | None = None): + payload_store: PayloadStore | None = None, + data_converter: DataConverter | None = None): self._owns_channel = channel is None + self._data_converter = data_converter if data_converter is not None else JsonDataConverter() self._host_address = ( host_address if host_address else shared.get_default_host_address() ) @@ -998,7 +1106,8 @@ async def schedule_new_orchestration(self, orchestrator: task.Orchestrator[TInpu req = build_schedule_new_orchestration_req( orchestrator, input=input, instance_id=instance_id, start_at=start_at, reuse_id_policy=reuse_id_policy, tags=tags, - version=version if version else self.default_version) + version=version if version else self.default_version, + data_converter=self._data_converter) parent_trace_ctx = tracing.get_current_trace_context() if parent_trace_ctx is not None: @@ -1019,7 +1128,7 @@ async def get_orchestration_state(self, instance_id: str, *, res: pb.GetInstanceResponse = await self._stub.GetInstance(req) if self._payload_store is not None and res.exists: await payload_helpers.deexternalize_payloads_async(res, self._payload_store) - return new_orchestration_state(req.instanceId, res) + return new_orchestration_state(req.instanceId, res, self._data_converter) async def get_orchestration_history(self, instance_id: str, *, @@ -1075,7 +1184,7 @@ async def get_all_orchestration_states(self, resp: pb.QueryInstancesResponse = await self._stub.QueryInstances(req) if self._payload_store is not None: await payload_helpers.deexternalize_payloads_async(resp, self._payload_store) - states += [parse_orchestration_state(res) for res in resp.orchestrationState] + states += [parse_orchestration_state(res, 
self._data_converter) for res in resp.orchestrationState] if check_continuation_token(resp.continuationToken, _continuation_token, self._logger): _continuation_token = resp.continuationToken else: @@ -1095,7 +1204,7 @@ async def wait_for_orchestration_start(self, instance_id: str, *, ) if self._payload_store is not None and res.exists: await payload_helpers.deexternalize_payloads_async(res, self._payload_store) - return new_orchestration_state(req.instanceId, res) + return new_orchestration_state(req.instanceId, res, self._data_converter) except grpc.aio.AioRpcError as rpc_error: if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: raise TimeoutError("Timed-out waiting for the orchestration to start") @@ -1114,7 +1223,7 @@ async def wait_for_orchestration_completion(self, instance_id: str, *, ) if self._payload_store is not None and res.exists: await payload_helpers.deexternalize_payloads_async(res, self._payload_store) - state = new_orchestration_state(req.instanceId, res) + state = new_orchestration_state(req.instanceId, res, self._data_converter) log_completion_state(self._logger, instance_id, state) return state except grpc.aio.AioRpcError as rpc_error: @@ -1126,7 +1235,7 @@ async def wait_for_orchestration_completion(self, instance_id: str, *, async def raise_orchestration_event(self, instance_id: str, event_name: str, *, data: Any | None = None) -> None: with tracing.start_raise_event_span(event_name, instance_id): - req = build_raise_event_req(instance_id, event_name, data) + req = build_raise_event_req(instance_id, event_name, data, self._data_converter) self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") if self._payload_store is not None: await payload_helpers.externalize_payloads_async( @@ -1137,7 +1246,7 @@ async def raise_orchestration_event(self, instance_id: str, event_name: str, *, async def terminate_orchestration(self, instance_id: str, *, output: Any | None = None, recursive: bool = True) -> None: - req = 
build_terminate_req(instance_id, output, recursive) + req = build_terminate_req(instance_id, output, recursive, self._data_converter) self._logger.info(f"Terminating instance '{instance_id}'.") if self._payload_store is not None: @@ -1200,7 +1309,7 @@ async def signal_entity(self, entity_instance_id: EntityInstanceId, operation_name: str, input: Any | None = None) -> None: - req = build_signal_entity_req(entity_instance_id, operation_name, input) + req = build_signal_entity_req(entity_instance_id, operation_name, input, self._data_converter) self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.") if self._payload_store is not None: await payload_helpers.externalize_payloads_async( @@ -1219,7 +1328,7 @@ async def get_entity(self, return None if self._payload_store is not None: await payload_helpers.deexternalize_payloads_async(res, self._payload_store) - return EntityMetadata.from_entity_metadata(res.entity, include_state) + return EntityMetadata.from_entity_metadata(res.entity, include_state, self._data_converter) async def get_all_entities(self, entity_query: EntityQuery | None = None) -> list[EntityMetadata]: @@ -1236,7 +1345,7 @@ async def get_all_entities(self, resp: pb.QueryEntitiesResponse = await self._stub.QueryEntities(query_request) if self._payload_store is not None: await payload_helpers.deexternalize_payloads_async(resp, self._payload_store) - entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities] + entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState, self._data_converter) for entity in resp.entities] if check_continuation_token(resp.continuationToken, _continuation_token, self._logger): _continuation_token = resp.continuationToken else: diff --git a/durabletask/entities/entity_context.py b/durabletask/entities/entity_context.py index 03ece71..c5435a2 100644 --- a/durabletask/entities/entity_context.py +++ 
b/durabletask/entities/entity_context.py @@ -1,22 +1,30 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import Any, TypeVar, overload +from typing import TYPE_CHECKING, Any, TypeVar, overload import uuid from durabletask.entities.entity_instance_id import EntityInstanceId -from durabletask.internal import helpers, shared +from durabletask.internal import helpers from durabletask.internal.entity_state_shim import StateShim import durabletask.internal.orchestrator_service_pb2 as pb +if TYPE_CHECKING: + from durabletask.serialization import DataConverter + TState = TypeVar("TState") class EntityContext: - def __init__(self, orchestration_id: str, operation: str, state: StateShim, entity_id: EntityInstanceId): + def __init__(self, orchestration_id: str, operation: str, state: StateShim, + entity_id: EntityInstanceId, data_converter: "DataConverter | None" = None): self._orchestration_id = orchestration_id self._operation = operation self._state = state self._entity_id = entity_id + if data_converter is None: + from durabletask.serialization import JsonDataConverter + data_converter = JsonDataConverter() + self._data_converter = data_converter @property def orchestration_id(self) -> str: @@ -95,7 +103,7 @@ def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, in input : Any, optional The input to provide to the entity for the operation. """ - encoded_input: str | None = shared.to_json(input) if input is not None else None + encoded_input: str | None = self._data_converter.serialize(input) self._state.add_operation_action( pb.OperationAction( sendSignal=pb.SendSignalAction( @@ -126,7 +134,7 @@ def schedule_new_orchestration(self, orchestration_name: str, input: Any | None str The instance ID of the scheduled orchestration. 
""" - encoded_input: str | None = shared.to_json(input) if input is not None else None + encoded_input: str | None = self._data_converter.serialize(input) if not instance_id: instance_id = uuid.uuid4().hex self._state.add_operation_action( diff --git a/durabletask/entities/entity_metadata.py b/durabletask/entities/entity_metadata.py index a2ed219..37c437e 100644 --- a/durabletask/entities/entity_metadata.py +++ b/durabletask/entities/entity_metadata.py @@ -2,11 +2,14 @@ # Licensed under the MIT License. from datetime import datetime, timezone -from typing import Any, TypeVar, overload +from typing import TYPE_CHECKING, Any, TypeVar, overload from durabletask.entities.entity_instance_id import EntityInstanceId import durabletask.internal.orchestrator_service_pb2 as pb +if TYPE_CHECKING: + from durabletask.serialization import DataConverter + TState = TypeVar("TState") @@ -32,7 +35,8 @@ def __init__(self, backlog_queue_size: int, locked_by: str, includes_state: bool, - state: Any | None): + state: Any | None, + data_converter: "DataConverter | None" = None): """Initializes a new instance of the EntityMetadata class. 
Args: @@ -44,13 +48,20 @@ def __init__(self, self._locked_by = locked_by self.includes_state = includes_state self._state = state + if data_converter is None: + from durabletask.serialization import JsonDataConverter + data_converter = JsonDataConverter() + self._data_converter = data_converter @staticmethod - def from_entity_response(entity_response: pb.GetEntityResponse, includes_state: bool): - return EntityMetadata.from_entity_metadata(entity_response.entity, includes_state) + def from_entity_response(entity_response: pb.GetEntityResponse, includes_state: bool, + data_converter: "DataConverter | None" = None): + return EntityMetadata.from_entity_metadata( + entity_response.entity, includes_state, data_converter) @staticmethod - def from_entity_metadata(entity: pb.EntityMetadata, includes_state: bool): + def from_entity_metadata(entity: pb.EntityMetadata, includes_state: bool, + data_converter: "DataConverter | None" = None): try: entity_id = EntityInstanceId.parse(entity.instanceId) except ValueError: @@ -64,7 +75,8 @@ def from_entity_metadata(entity: pb.EntityMetadata, includes_state: bool): backlog_queue_size=entity.backlogQueueSize, locked_by=entity.lockedBy.value, includes_state=includes_state, - state=entity_state + state=entity_state, + data_converter=data_converter, ) @overload @@ -76,7 +88,17 @@ def get_state(self, intended_type: None = None) -> Any: ... def get_state(self, intended_type: type[TState] | None = None) -> TState | Any | None: - """Get the current state of the entity, optionally converting it to a specified type.""" + """Get the entity's raw persisted state, optionally constructor-coerced. + + The state is held as the raw serialized JSON payload (a ``str``). With no + argument the raw payload is returned unchanged; passing ``intended_type`` + applies the legacy constructor-based coercion (``intended_type(raw)``) + and raises ``TypeError`` if that fails. + + This preserves the pre-existing contract. 
To deserialize the payload and + reconstruct dataclasses or ``from_json()``-capable types, use + :meth:`get_typed_state` instead. + """ if intended_type is None or self._state is None: return self._state @@ -90,6 +112,27 @@ def get_state(self, intended_type: type[TState] | None = None) -> TState | Any | f"Could not convert state of type '{type(self._state).__name__}' to '{intended_type.__name__}'" ) from ex + @overload + def get_typed_state(self, intended_type: type[TState]) -> TState | None: + ... + + @overload + def get_typed_state(self, intended_type: None = None) -> Any: + ... + + def get_typed_state(self, intended_type: type[TState] | None = None) -> TState | Any | None: + """Deserialize the entity's persisted state, optionally reconstructing a type. + + The state is stored as its raw serialized JSON payload and deserialized + here. When ``intended_type`` is provided the payload is reconstructed as + that type (dataclasses, ``from_json()``-capable types, etc.); otherwise + the plain deserialized JSON value is returned. + """ + if self._state is None: + return None + + return self._data_converter.deserialize(self._state, intended_type) + def get_locked_by(self) -> EntityInstanceId | None: """Get the identifier of the worker that currently holds the lock on the entity. 
diff --git a/durabletask/extensions/history_export/client.py b/durabletask/extensions/history_export/client.py index cc9f87c..fdd302a 100644 --- a/durabletask/extensions/history_export/client.py +++ b/durabletask/extensions/history_export/client.py @@ -52,7 +52,6 @@ from __future__ import annotations -import json import time import uuid from collections.abc import Iterator @@ -221,12 +220,8 @@ def get_job(self, job_id: str) -> ExportJobDescription | None: meta = self._client.get_entity(entity_id, include_state=True) if meta is None: return None - raw = meta.get_state(str) - if not raw: - return None - try: - state = json.loads(raw) - except (TypeError, ValueError): + state = meta.get_typed_state() + if not state: return None if not isinstance(state, dict): return None @@ -265,31 +260,24 @@ def list_jobs( # explicit entity-name check. if meta.id.entity != ENTITY_NAME.lower(): continue - raw = meta.get_state(str) + raw = meta.get_typed_state() if not raw: logger.warning( "list_jobs: skipping export-job entity %r with no " "persisted state", meta.id.key, ) continue - try: - state = json.loads(raw) - except (TypeError, ValueError) as ex: - logger.warning( - "list_jobs: skipping export-job entity %r; failed to " - "parse state JSON (%s)", meta.id.key, ex, - ) - continue - if not isinstance(state, dict): + if not isinstance(raw, dict): logger.warning( "list_jobs: skipping export-job entity %r; persisted " "state is not a JSON object (got %s)", - meta.id.key, type(state).__name__, + meta.id.key, type(raw).__name__, ) continue + state = cast("dict[str, Any]", raw) try: desc = ExportJobDescription.from_state_dict( - meta.id.key, cast("dict[str, Any]", state), + meta.id.key, state, ) except (KeyError, ValueError) as ex: logger.warning( diff --git a/durabletask/internal/client_helpers.py b/durabletask/internal/client_helpers.py index ef27c50..a42c3c4 100644 --- a/durabletask/internal/client_helpers.py +++ b/durabletask/internal/client_helpers.py @@ -28,11 +28,20 @@ 
OrchestrationStatus, ) from durabletask.entities import EntityInstanceId + from durabletask.serialization import DataConverter TInput = TypeVar('TInput') TOutput = TypeVar('TOutput') +def _serialize(value: Any, data_converter: DataConverter | None) -> str | None: + """Serialize ``value`` using the supplied converter, defaulting to the SDK's.""" + if data_converter is None: + from durabletask.serialization import DEFAULT_DATA_CONVERTER + data_converter = DEFAULT_DATA_CONVERTER + return data_converter.serialize(value) + + def prepare_sync_interceptors( metadata: list[tuple[str, str]] | None, interceptors: Sequence[shared.ClientInterceptor] | None @@ -70,13 +79,14 @@ def build_schedule_new_orchestration_req( start_at: datetime | None, reuse_id_policy: pb.OrchestrationIdReusePolicy | None, tags: dict[str, str] | None, - version: str | None) -> pb.CreateInstanceRequest: + version: str | None, + data_converter: DataConverter) -> pb.CreateInstanceRequest: """Build a CreateInstanceRequest for scheduling a new orchestration.""" name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator) return pb.CreateInstanceRequest( name=name, instanceId=instance_id if instance_id else uuid.uuid4().hex, - input=helpers.get_string_value(shared.to_json(input) if input is not None else None), + input=helpers.get_string_value(data_converter.serialize(input)), scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None, version=helpers.get_string_value(version), orchestrationIdReusePolicy=reuse_id_policy, @@ -168,23 +178,25 @@ def log_completion_state( def build_raise_event_req( instance_id: str, event_name: str, - data: Any | None = None) -> pb.RaiseEventRequest: + data: Any | None = None, + data_converter: DataConverter | None = None) -> pb.RaiseEventRequest: """Build a RaiseEventRequest for raising an orchestration event.""" return pb.RaiseEventRequest( instanceId=instance_id, name=event_name, - input=helpers.get_string_value(shared.to_json(data) 
if data is not None else None) + input=helpers.get_string_value(_serialize(data, data_converter)) ) def build_terminate_req( instance_id: str, output: Any | None = None, - recursive: bool = True) -> pb.TerminateRequest: + recursive: bool = True, + data_converter: DataConverter | None = None) -> pb.TerminateRequest: """Build a TerminateRequest for terminating an orchestration.""" return pb.TerminateRequest( instanceId=instance_id, - output=helpers.get_string_value(shared.to_json(output) if output is not None else None), + output=helpers.get_string_value(_serialize(output, data_converter)), recursive=recursive ) @@ -192,12 +204,13 @@ def build_terminate_req( def build_signal_entity_req( entity_instance_id: EntityInstanceId, operation_name: str, - input: Any | None = None) -> pb.SignalEntityRequest: + input: Any | None = None, + data_converter: DataConverter | None = None) -> pb.SignalEntityRequest: """Build a SignalEntityRequest for signaling an entity.""" return pb.SignalEntityRequest( instanceId=str(entity_instance_id), name=operation_name, - input=helpers.get_string_value(shared.to_json(input) if input is not None else None), + input=helpers.get_string_value(_serialize(input, data_converter)), requestId=str(uuid.uuid4()), scheduledTime=None, parentTraceContext=None, diff --git a/durabletask/internal/entity_state_shim.py b/durabletask/internal/entity_state_shim.py index 6d6fd25..99f2801 100644 --- a/durabletask/internal/entity_state_shim.py +++ b/durabletask/internal/entity_state_shim.py @@ -1,19 +1,26 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from typing import Any, TypeVar, overload +from typing import TYPE_CHECKING, Any, TypeVar, overload import durabletask.internal.orchestrator_service_pb2 as pb +if TYPE_CHECKING: + from durabletask.serialization import DataConverter + TState = TypeVar("TState") class StateShim: - def __init__(self, start_state: Any): + def __init__(self, start_state: Any, data_converter: "DataConverter | None" = None): self._current_state: Any = start_state self._checkpoint_state: Any = start_state self._operation_actions: list[pb.OperationAction] = [] self._actions_checkpoint_state: int = 0 + if data_converter is None: + from durabletask.serialization import JsonDataConverter + data_converter = JsonDataConverter() + self._data_converter = data_converter @overload def get_state(self, intended_type: type[TState], default: TState) -> TState: @@ -34,15 +41,22 @@ def get_state(self, intended_type: type[TState] | None = None, default: TState | if intended_type is None: return self._current_state - if isinstance(self._current_state, intended_type): - return self._current_state - - try: - return intended_type(self._current_state) # type: ignore[call-arg] - except Exception as ex: + coerced = self._data_converter.coerce(self._current_state, intended_type) + + # An explicit ``intended_type`` is a request to receive that type. The + # default converter is best-effort and would silently return the raw + # value on a failed coercion; restore the stricter contract here by + # raising when a non-None state could not be coerced to a concrete type. + # ``intended_type`` may be a typing generic (e.g. ``list[int]``) at + # runtime, which is not a ``type`` instance, so the guard is required. 
+ if (self._current_state is not None + and isinstance(intended_type, type) # pyright: ignore[reportUnnecessaryIsInstance] + and not isinstance(coerced, intended_type)): raise TypeError( f"Could not convert state of type '{type(self._current_state).__name__}' to '{intended_type.__name__}'" - ) from ex + ) + + return coerced def set_state(self, state: Any) -> None: self._current_state = state diff --git a/durabletask/internal/json_codec.py b/durabletask/internal/json_codec.py new file mode 100644 index 0000000..8fda0ea --- /dev/null +++ b/durabletask/internal/json_codec.py @@ -0,0 +1,183 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Internal JSON codec for Durable Task payloads. + +This module holds the low-level serialization *mechanism* -- the JSON string +encode/decode primitives and the value-level type coercion used to reconstruct +custom objects. Serialization *policy* (the public, pluggable strategy) lives in +:mod:`durabletask.serialization`; the default ``JsonDataConverter`` is the only +production consumer of ``to_json`` / ``from_json``, while ``coerce_to_type`` is +also used directly by entity state accessors that already hold a parsed value. +""" + +from __future__ import annotations + +import dataclasses +import json +import types +import typing +from collections.abc import Sequence +from types import SimpleNamespace +from typing import Any, cast + +# Marker formerly added to JSON payloads to flag objects for automatic +# deserialization into a SimpleNamespace. New code no longer emits this marker +# (objects are serialized as plain JSON), but the decoder still recognizes it so +# that orchestration histories produced by older SDK versions continue to replay. +AUTO_SERIALIZED = "__durabletask_autoobject__" + + +def to_json(obj: Any) -> str: + """Serialize a value to a JSON string. + + Builtins serialize to plain JSON. 
Dataclasses, ``SimpleNamespace`` + instances, and objects exposing a ``to_json()`` method are serialized to + plain JSON as well (without any type marker); custom objects can be + reconstructed on the receiving side by passing ``expected_type`` to + :func:`from_json`. + """ + try: + return json.dumps(obj, default=_encode_custom_object) + except TypeError as e: + # Preserve the original error as the cause so serialization failures are + # easier to diagnose, while naming the offending top-level type. + raise TypeError( + f"Failed to serialize object of type '{type(obj).__name__}' to JSON: {e}" + ) from e + + +def from_json(json_str: str | bytes | bytearray, expected_type: type | None = None) -> Any: + """Deserialize a JSON string, optionally coercing the result to a type. + + When ``expected_type`` is ``None`` (the default) the raw parsed JSON is + returned. For backwards compatibility, payloads carrying the legacy + :data:`AUTO_SERIALIZED` marker are reconstructed as ``SimpleNamespace`` + instances so that in-flight orchestrations produced by older SDK versions + continue to replay. + + When ``expected_type`` is provided, the legacy marker (if present) is + stripped and the parsed value is coerced to ``expected_type`` -- dataclasses + are constructed from their dict payloads, types exposing a ``from_json()`` + classmethod are reconstructed via that hook, and ``Optional``/``Union`` and + ``list`` type hints are honored recursively. The destination type is always + supplied by the caller; it is never read from the payload. + """ + if expected_type is None: + return json.loads(json_str, object_hook=_legacy_object_hook) + raw = json.loads(json_str, object_hook=_strip_legacy_marker) + return coerce_to_type(raw, expected_type) + + +def _encode_custom_object(o: Any) -> Any: + """``default`` hook for :func:`json.dumps` that emits plain JSON. + + Called only for values the JSON encoder cannot natively serialize. 
Note that + namedtuples are handled natively by the encoder (serialized as JSON arrays) + and never reach this hook. + """ + if dataclasses.is_dataclass(o) and not isinstance(o, type): + return dataclasses.asdict(o) + if isinstance(o, SimpleNamespace): + return vars(o) + # Custom objects may opt in via a ``to_json`` hook. It is resolved off the + # type and called with the instance (``type(o).to_json(o)``) so that both + # instance methods and ``@staticmethod`` hooks work -- matching the calling + # convention used by ``azure-functions-durable``. The hook returns a + # JSON-serializable value (a structure or a string), not a JSON document. + to_json_hook = getattr(cast(Any, type(o)), "to_json", None) + if callable(to_json_hook): + return to_json_hook(o) + # This will raise a TypeError describing the unsupported type. + raise TypeError(f"Object of type '{type(o).__name__}' is not JSON serializable") + + +def _legacy_object_hook(d: dict[str, Any]) -> Any: + # If the object carries the legacy marker, deserialize it as a SimpleNamespace. + if d.pop(AUTO_SERIALIZED, False): + return SimpleNamespace(**d) + return d + + +def _strip_legacy_marker(d: dict[str, Any]) -> dict[str, Any]: + # Discard the legacy marker so typed coercion sees a plain dict. + d.pop(AUTO_SERIALIZED, None) + return d + + +def coerce_to_type(value: Any, expected_type: Any) -> Any: + """Coerce an already-parsed JSON value to ``expected_type``. + + Handles ``None``/``Optional``/``Union`` and ``list`` type hints recursively, + types exposing a ``from_json()`` classmethod, and dataclasses (including + nested dataclass fields). The destination type is always caller-supplied and + never derived from the payload, keeping deserialization secure. 
+ """ + if expected_type is None or value is None: + return value + + origin = typing.get_origin(expected_type) + if origin is not None: + return _coerce_generic(value, expected_type, origin) + + if not isinstance(expected_type, type): + # Not a concrete, instantiable type (e.g. a typing special form we don't + # special-case) -- return the value unchanged. + return value + + if isinstance(value, expected_type): + return value + + from_json_hook = getattr(expected_type, "from_json", None) + if callable(from_json_hook): + return from_json_hook(value) + + if dataclasses.is_dataclass(expected_type) and isinstance(value, dict): + return _build_dataclass(expected_type, cast(dict[str, Any], value)) + + type_ctor = cast(Any, expected_type) + try: + return type_ctor(value) + except Exception as e: + type_name = getattr(type_ctor, "__name__", None) or str(type_ctor) + raise TypeError( + f"Could not coerce value of type '{type(value).__name__}' to " + f"'{type_name}'" + ) from e + + +def _coerce_generic(value: Any, expected_type: Any, origin: Any) -> Any: + args = typing.get_args(expected_type) + if origin is typing.Union or origin is types.UnionType: + # If the value already matches a member type, keep it as-is. + non_none = [a for a in args if a is not type(None)] + for arg in non_none: + if isinstance(arg, type) and isinstance(value, arg): + return value + # ``Optional[T]`` (exactly one non-None member): coerce to that member. + # For a genuine multi-member ``Union`` where the value matched none of + # the members, leave it untouched rather than guessing the first arg -- + # forcing a coercion there can silently mis-construct the wrong type. + if len(non_none) == 1: + return coerce_to_type(value, non_none[0]) + return value + if origin in (list, Sequence) and isinstance(value, list): + elem_type = args[0] if args else None + return [coerce_to_type(item, elem_type) for item in cast(list[Any], value)] + # Other generics (dict, tuple, ...) are returned as parsed JSON. 
+ return value + + +def _build_dataclass(cls: Any, data: dict[str, Any]) -> Any: + """Construct a dataclass from its dict payload, recursing into typed fields.""" + try: + hints = typing.get_type_hints(cls) + except Exception: + hints = {} + kwargs: dict[str, Any] = {} + for field in dataclasses.fields(cls): + if field.name not in data: + continue + field_type = hints.get(field.name) + kwargs[field.name] = coerce_to_type(data[field.name], field_type) + return cls(**kwargs) diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index f8afc7c..9ef136d 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -1,15 +1,21 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import dataclasses -import json import logging from collections.abc import Sequence -from types import SimpleNamespace -from typing import Any, TypeAlias +from typing import TypeAlias import grpc import grpc.aio + +# Backwards-compatibility re-exports. The JSON codec moved to +# ``durabletask.internal.json_codec``; these aliases keep older imports from +# ``durabletask.internal.shared`` working. 
+from durabletask.internal.json_codec import ( # noqa: F401 + AUTO_SERIALIZED as AUTO_SERIALIZED, + from_json as from_json, + to_json as to_json, +) from durabletask.grpc_options import GrpcChannelOptions ClientInterceptor: TypeAlias = ( @@ -26,10 +32,6 @@ | grpc.aio.StreamStreamClientInterceptor ) -# Field name used to indicate that an object was automatically serialized -# and should be deserialized as a SimpleNamespace -AUTO_SERIALIZED = "__durabletask_autoobject__" - SECURE_PROTOCOLS = ["https://", "grpcs://"] INSECURE_PROTOCOLS = ["http://", "grpc://"] @@ -156,51 +158,3 @@ def get_logger( datefmt='%Y-%m-%d %H:%M:%S') log_handler.setFormatter(log_formatter) return logger - - -def to_json(obj: Any) -> str: - return json.dumps(obj, cls=InternalJSONEncoder) - - -def from_json(json_str: str | bytes | bytearray) -> Any: - return json.loads(json_str, cls=InternalJSONDecoder) - - -class InternalJSONEncoder(json.JSONEncoder): - """JSON encoder that supports serializing specific Python types.""" - - def encode(self, o: Any) -> str: # pyright: ignore[reportIncompatibleMethodOverride] - # if the object is a namedtuple, convert it to a dict with the AUTO_SERIALIZED key added - if isinstance(o, tuple): - namedtuple_obj: Any = o # pyright: ignore[reportUnknownVariableType] - if hasattr(namedtuple_obj, "_fields") and hasattr(namedtuple_obj, "_asdict"): - d: dict[str, Any] = namedtuple_obj._asdict() - d[AUTO_SERIALIZED] = True - o = d - return super().encode(o) - - def default(self, o: Any) -> Any: # pyright: ignore[reportIncompatibleMethodOverride] - if dataclasses.is_dataclass(o) and not isinstance(o, type): - # Dataclasses are not serializable by default, so we convert them to a dict and mark them for - # automatic deserialization by the receiver - d: dict[str, Any] = dataclasses.asdict(o) - d[AUTO_SERIALIZED] = True - return d - elif isinstance(o, SimpleNamespace): - # Most commonly used for serializing custom objects that were previously serialized using our encoder - d = 
vars(o) - d[AUTO_SERIALIZED] = True - return d - # This will typically raise a TypeError - return json.JSONEncoder.default(self, o) - - -class InternalJSONDecoder(json.JSONDecoder): - def __init__(self, *args: Any, **kwargs: Any): - super().__init__(object_hook=self.dict_to_object, *args, **kwargs) - - def dict_to_object(self, d: dict[str, Any]) -> Any: - # If the object was serialized by the InternalJSONEncoder, deserialize it as a SimpleNamespace - if d.pop(AUTO_SERIALIZED, False): - return SimpleNamespace(**d) - return d diff --git a/durabletask/internal/type_discovery.py b/durabletask/internal/type_discovery.py new file mode 100644 index 0000000..58fd6f8 --- /dev/null +++ b/durabletask/internal/type_discovery.py @@ -0,0 +1,164 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Best-effort discovery of input type hints for user functions. + +These helpers resolve the annotation of the *input* parameter of an +orchestrator, activity, or entity function so that inbound payloads can be +reconstructed into the annotated custom type (a dataclass or a type exposing a +``from_json()`` classmethod) without the caller having to pass an explicit type. + +Discovery is intentionally conservative: it only returns an annotation when the +target is a *reconstructable* custom type (a dataclass, a ``from_json()``-capable +type, or an ``Optional`` / ``list`` wrapping one). Primitive and unknown +annotations resolve to ``None`` so that existing payloads are passed through +unchanged -- inbound type discovery never invokes an arbitrary constructor on +untrusted data, and never alters the value for builtins. + +All public helpers swallow exceptions and return ``None`` on failure; the caller +treats ``None`` as "no type information available" and uses the raw payload. 
+""" + +from __future__ import annotations + +import collections.abc +import dataclasses +import functools +import inspect +import types +import typing +from typing import Any, Callable, cast + + +def is_reconstructable(annotation: Any) -> bool: + """Return True if ``annotation`` names a custom type we can rebuild. + + Reconstructable targets are dataclasses, types exposing a callable + ``from_json``, and ``Optional`` / ``list`` hints wrapping such types. + Builtins (``int``, ``str``, ``dict``, ...) and unknown annotations are not + reconstructable and resolve to ``False``. + """ + origin = typing.get_origin(annotation) + if origin is not None: + args = typing.get_args(annotation) + if origin is typing.Union or origin is types.UnionType: + return any( + is_reconstructable(a) for a in args if a is not type(None) + ) + if origin in (list, collections.abc.Sequence): + return any(is_reconstructable(a) for a in args) + return False + if not isinstance(annotation, type): + return False + if dataclasses.is_dataclass(annotation): + return True + return callable(getattr(cast(Any, annotation), "from_json", None)) + + +# Bounded so a worker that registers dynamically-created functions or closures +# cannot accumulate cache entries unboundedly over the process lifetime. The +# common case (a fixed set of module-level orchestrators/activities) fits well +# within this bound. +@functools.lru_cache(maxsize=2048) +def _resolved_hints(fn: Callable[..., Any]) -> dict[str, Any] | None: + """Resolve a function's type hints, honoring postponed annotations. + + Results are memoized per function because discovery runs on every + orchestrator/activity/entity execution (including replay). + """ + try: + return typing.get_type_hints(fn) + except Exception: + return None + + +def _input_annotation(fn: Callable[..., Any], position: int) -> Any | None: + """Return the resolved annotation of the positional parameter at ``position``. 
+ + ``position`` is the zero-based index among positional parameters (so the + ``input`` parameter of a ``(ctx, input)`` function is at position 1, and the + ``input`` parameter of an unbound ``(self, input)`` entity method is also at + position 1). Returns ``None`` when the parameter is absent, unannotated, or + its annotation is not a reconstructable custom type. + """ + try: + sig = inspect.signature(fn) + except (TypeError, ValueError): + return None + + positional = [ + p for p in sig.parameters.values() + if p.kind in (inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD) + ] + if position >= len(positional): + return None + param = positional[position] + + annotation: Any = param.annotation + hints = _resolved_hints(fn) + if hints is not None and param.name in hints: + annotation = hints[param.name] + elif isinstance(annotation, str): + # Could not resolve a postponed (string) annotation -- give up. + return None + + if annotation is inspect.Parameter.empty or annotation is Any: + return None + return annotation if is_reconstructable(annotation) else None + + +def orchestrator_input_type(fn: Callable[..., Any]) -> Any | None: + """Discover the input type of an orchestrator function ``(ctx, input)``.""" + return _input_annotation(fn, 1) + + +def activity_input_type(fn: Callable[..., Any]) -> Any | None: + """Discover the input type of an activity function ``(ctx, input)``.""" + return _input_annotation(fn, 1) + + +def activity_output_type(fn: Any) -> Any | None: + """Discover the return type of an activity function. + + Returns the resolved return annotation when it names a reconstructable + custom type (a dataclass or a ``from_json()``-capable type, optionally + wrapped in ``Optional`` / ``list``). Returns ``None`` for plain callables + that are not annotated with such a type, for string activity names, or when + the annotation cannot be resolved. 
+ """ + if not callable(fn): + return None + try: + sig = inspect.signature(fn) + except (TypeError, ValueError): + return None + + annotation: Any = sig.return_annotation + hints = _resolved_hints(fn) + if hints is not None and "return" in hints: + annotation = hints["return"] + elif isinstance(annotation, str): + # Could not resolve a postponed (string) annotation -- give up. + return None + + if annotation is inspect.Signature.empty or annotation is Any or annotation is None: + return None + return annotation if is_reconstructable(annotation) else None + + +def entity_input_type(fn: Any, operation: str) -> Any | None: + """Discover the input type of an entity operation. + + For class-based entities (a ``DurableEntity`` subclass) the operation is a + method; its input is the first parameter after ``self``. For function-based + entities the signature is ``(ctx, input)``. Returns ``None`` when no + reconstructable input annotation is found. + """ + if isinstance(fn, type): + method = getattr(fn, operation, None) + if method is None or not callable(method): + return None + # Unbound method includes ``self`` at position 0, so ``input`` is at 1. + return _input_annotation(method, 1) + return _input_annotation(fn, 1) diff --git a/durabletask/serialization.py b/durabletask/serialization.py new file mode 100644 index 0000000..b1b469b --- /dev/null +++ b/durabletask/serialization.py @@ -0,0 +1,136 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Pluggable serialization for Durable Task payloads. + +All user payloads (orchestrator/activity/entity inputs and outputs, external +event data, custom status, and entity state) flow through a +:class:`DataConverter`. The worker and client both accept a converter and share +it across every serialization boundary, so a single object controls how Python +values become JSON on the wire and how they are reconstructed on the way back. 
+ +The default :class:`JsonDataConverter` preserves the SDK's built-in behavior: +builtins serialize as plain JSON, dataclasses / ``SimpleNamespace`` instances +and objects exposing a ``to_json()`` hook serialize to plain JSON structures, +and a caller-supplied ``target_type`` drives reconstruction on the read side +(the destination type is never read from the payload). + +To customize serialization -- for example to validate with pydantic, encode +custom ``datetime`` / ``Decimal`` formats, or integrate another model framework +-- implement :class:`DataConverter` and pass it to the worker and client:: + + converter = MyDataConverter() + worker = TaskHubGrpcWorker(data_converter=converter) + client = TaskHubGrpcClient(data_converter=converter) +""" + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from typing import Any + +from durabletask.internal import json_codec + +logger = logging.getLogger("durabletask") + + +class DataConverter(ABC): + """Strategy for serializing and deserializing Durable Task payloads. + + Implementations are used by both the worker and the client and must be + deterministic: the same value must always serialize to the same string so + that orchestration replay stays consistent. + """ + + @abstractmethod + def serialize(self, value: Any) -> str | None: + """Serialize ``value`` to a string, or ``None`` when ``value`` is ``None``.""" + ... + + @abstractmethod + def deserialize(self, data: str | None, target_type: type | None = None) -> Any: + """Deserialize ``data``, optionally coercing the result to ``target_type``. + + ``data`` is ``None`` (or empty) when there is no payload, in which case + ``None`` is returned. When ``target_type`` is provided the result is + reconstructed as that type; otherwise the raw deserialized value is + returned. The destination type is always supplied by the caller and is + never derived from the payload. 
+ + Whether a failure to coerce to ``target_type`` raises or falls back to + the raw value is an implementation choice. The default + :class:`JsonDataConverter` is best-effort and falls back; a validating + converter may instead raise. + """ + ... + + @abstractmethod + def coerce(self, value: Any, target_type: type | None = None) -> Any: + """Coerce an **already-deserialized** ``value`` to ``target_type``. + + Unlike :meth:`deserialize`, the input is a live Python value rather than + a serialized string. Used where the SDK holds a parsed value (for + example, durable entity state during a batch) and needs to reconstruct + the caller's requested type without re-serializing. When ``target_type`` + is ``None`` the value is returned unchanged. The same coercion policy + (strict vs. best-effort) that an implementation applies in + :meth:`deserialize` should apply here. + """ + ... + + +class JsonDataConverter(DataConverter): + """Default :class:`DataConverter` backed by the SDK's JSON codec. + + Serialization emits plain JSON. Custom objects may opt in by exposing a + ``to_json()`` method (called as ``type(obj).to_json(obj)``, so both instance + methods and ``@staticmethod`` hooks work) and a ``from_json(value)`` + classmethod used during type-directed reconstruction. This matches the + ``to_json`` / ``from_json`` convention used by ``azure-functions-durable``. + + Deserialization (and value-level :meth:`coerce`) is **best-effort**: when a + ``target_type`` is supplied and the value cannot be coerced to it, the raw + value is returned (and a debug message is logged) rather than raising. This + keeps the core SDK permissive; a stricter, validating converter can be + supplied for callers who want coercion failures to surface as errors. 
+ """ + + def serialize(self, value: Any) -> str | None: + if value is None: + return None + return json_codec.to_json(value) + + def deserialize(self, data: str | None, target_type: type | None = None) -> Any: + if data is None or data == "": + return None + if target_type is None: + return json_codec.from_json(data) + try: + return json_codec.from_json(data, target_type) + except Exception as e: + # Best-effort: fall back to the raw deserialized value rather than + # failing the operation. Logged so the mismatch remains discoverable. + self._log_coercion_fallback(target_type, e) + return json_codec.from_json(data) + + def coerce(self, value: Any, target_type: type | None = None) -> Any: + if target_type is None or value is None: + return value + try: + return json_codec.coerce_to_type(value, target_type) + except Exception as e: + self._log_coercion_fallback(target_type, e) + return value + + @staticmethod + def _log_coercion_fallback(target_type: type, error: Exception) -> None: + logger.debug( + "Could not coerce payload to '%s' (%s); returning the raw " + "deserialized value.", + getattr(target_type, "__name__", target_type), error, + ) + + +# Shared default instance used when no converter is supplied. 
+DEFAULT_DATA_CONVERTER: DataConverter = JsonDataConverter() diff --git a/durabletask/task.py b/durabletask/task.py index b1ae27c..18af6cc 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -9,7 +9,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Generator, Sequence from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar, cast, overload from durabletask.entities import DurableEntity, EntityInstanceId, EntityLock, EntityContext import durabletask.internal.helpers as pbh @@ -115,11 +115,28 @@ def create_timer(self, fire_at: datetime | timedelta) -> TimerTask: """ pass + @overload + def call_activity(self, activity: Activity[TInput, TOutput] | str, *, + input: TInput | None = ..., + retry_policy: RetryPolicy | None = ..., + tags: dict[str, str] | None = ..., + return_type: type[T]) -> CompletableTask[T]: + ... + + @overload + def call_activity(self, activity: Activity[TInput, TOutput] | str, *, + input: TInput | None = ..., + retry_policy: RetryPolicy | None = ..., + tags: dict[str, str] | None = ..., + return_type: None = ...) -> CompletableTask[TOutput]: + ... + @abstractmethod def call_activity(self, activity: Activity[TInput, TOutput] | str, *, input: TInput | None = None, retry_policy: RetryPolicy | None = None, - tags: dict[str, str] | None = None) -> CompletableTask[TOutput]: + tags: dict[str, str] | None = None, + return_type: type | None = None) -> CompletableTask[Any]: """Schedule an activity for execution. Parameters @@ -132,6 +149,16 @@ def call_activity(self, activity: Activity[TInput, TOutput] | str, *, The retry policy to use for this activity call. tags: dict[str, str] | None Optional tags to associate with the activity invocation. + return_type: type | None + Optional type used to deserialize the activity's result. 
When + provided, the result is coerced to this type (dataclasses are + constructed from their dict payloads, types exposing a + ``from_json()`` classmethod are reconstructed via that hook), and + the returned task is typed as ``CompletableTask[return_type]``. + When omitted, the return type is discovered from the activity + function's return annotation (if a function reference is passed and + it is annotated with a reconstructable type); otherwise the raw + deserialized JSON is returned. Returns ------- @@ -140,11 +167,31 @@ def call_activity(self, activity: Activity[TInput, TOutput] | str, *, """ pass + @overload + def call_entity(self, + entity: EntityInstanceId, + operation: str, + input: Any = ..., + *, + return_type: type[T]) -> CompletableTask[T]: + ... + + @overload + def call_entity(self, + entity: EntityInstanceId, + operation: str, + input: Any = ..., + *, + return_type: None = ...) -> CompletableTask[Any]: + ... + @abstractmethod def call_entity(self, entity: EntityInstanceId, operation: str, - input: Any = None) -> CompletableTask[Any]: + input: Any = None, + *, + return_type: type | None = None) -> CompletableTask[Any]: """Schedule entity function for execution. Parameters @@ -155,6 +202,11 @@ def call_entity(self, The name of the operation to invoke on the entity. input: TInput | None The optional JSON-serializable input to pass to the entity function. + return_type: type | None + Optional type used to deserialize the operation's result. When + provided, the result is coerced to this type and the returned task + is typed as ``CompletableTask[return_type]``; when omitted, the raw + deserialized JSON is returned. 
Returns ------- @@ -203,12 +255,31 @@ def lock_entities(self, entities: list[EntityInstanceId]) -> CompletableTask[Ent """ pass + @overload + def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput] | str, *, + input: TInput | None = ..., + instance_id: str | None = ..., + retry_policy: RetryPolicy | None = ..., + version: str | None = ..., + return_type: type[T]) -> CompletableTask[T]: + ... + + @overload + def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput] | str, *, + input: TInput | None = ..., + instance_id: str | None = ..., + retry_policy: RetryPolicy | None = ..., + version: str | None = ..., + return_type: None = ...) -> CompletableTask[TOutput]: + ... + @abstractmethod def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput] | str, *, input: TInput | None = None, instance_id: str | None = None, retry_policy: RetryPolicy | None = None, - version: str | None = None) -> CompletableTask[TOutput]: + version: str | None = None, + return_type: type | None = None) -> CompletableTask[Any]: """Schedule sub-orchestrator function for execution. Parameters @@ -222,6 +293,11 @@ def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput] | st random UUID will be used. retry_policy: RetryPolicy | None The retry policy to use for this sub-orchestrator call. + return_type: type | None + Optional type used to deserialize the sub-orchestrator's result. When + provided, the result is coerced to this type and the returned task is + typed as ``CompletableTask[return_type]``; when omitted, the raw + deserialized JSON is returned. Returns ------- @@ -232,14 +308,30 @@ def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput] | st # TOOD: Add a timeout parameter, which allows the task to be cancelled if the event is # not received within the specified timeout. This requires support for task cancellation. 
+ @overload + def wait_for_external_event(self, name: str, *, + data_type: type[T]) -> CancellableTask[T]: + ... + + @overload + def wait_for_external_event(self, name: str, *, + data_type: None = ...) -> CancellableTask[Any]: + ... + @abstractmethod - def wait_for_external_event(self, name: str) -> CancellableTask[Any]: + def wait_for_external_event(self, name: str, *, + data_type: type | None = None) -> CancellableTask[Any]: """Wait asynchronously for an event to be raised with the name `name`. Parameters ---------- name : str The event name of the event that the task is waiting for. + data_type : type | None + Optional type used to deserialize the event payload. When provided, + the payload is coerced to this type and the returned task is typed + as ``CancellableTask[data_type]``; when omitted, the raw + deserialized JSON is returned. Returns ------- @@ -474,9 +566,10 @@ def get_completed_tasks(self) -> int: class CompletableTask(Task[T]): - def __init__(self) -> None: + def __init__(self, expected_type: type | None = None) -> None: super().__init__() self._retryable_parent: RetryableTask[Any] | None = None + self._expected_type = expected_type def complete(self, result: T): if self._is_complete: @@ -498,8 +591,8 @@ def fail(self, message: str, details: Exception | pb.TaskFailureDetails): class CancellableTask(CompletableTask[T]): """A completable task that can be cancelled before it finishes.""" - def __init__(self) -> None: - super().__init__() + def __init__(self, expected_type: type | None = None) -> None: + super().__init__(expected_type) self._is_cancelled = False self._cancel_handler: Callable[[], None] | None = None @@ -541,8 +634,9 @@ class RetryableTask(CompletableTask[T]): """A task that can be retried according to a retry policy.""" def __init__(self, retry_policy: RetryPolicy, action: pb.OrchestratorAction, - start_time: datetime, is_sub_orch: bool) -> None: - super().__init__() + start_time: datetime, is_sub_orch: bool, + expected_type: type | None 
= None) -> None: + super().__init__(expected_type) self._action = action self._retry_policy = retry_policy self._attempt_count = 1 diff --git a/durabletask/worker.py b/durabletask/worker.py index aff8f4f..d3a3d86 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -38,6 +38,7 @@ import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared import durabletask.internal.tracing as tracing +import durabletask.internal.type_discovery as type_discovery from durabletask.internal.grpc_resiliency import ( FailureTracker, get_full_jitter_delay_seconds, @@ -47,9 +48,11 @@ from durabletask import task from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl from durabletask.payload.store import PayloadStore +from durabletask.serialization import DataConverter, JsonDataConverter TInput = TypeVar("TInput") TOutput = TypeVar("TOutput") +T = TypeVar("T") DATETIME_STRING_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ' DEFAULT_MAXIMUM_TIMER_INTERVAL = timedelta(days=3) _STREAM_CLOSED_SENTINEL = object() @@ -528,6 +531,7 @@ def __init__( concurrency_options: ConcurrencyOptions | None = None, maximum_timer_interval: timedelta | None = DEFAULT_MAXIMUM_TIMER_INTERVAL, payload_store: PayloadStore | None = None, + data_converter: DataConverter | None = None, ): self._registry = _Registry() self._host_address = ( @@ -537,6 +541,7 @@ def __init__( self._shutdown = Event() self._is_running = False self._channel = channel + self._data_converter = data_converter if data_converter is not None else JsonDataConverter() # The SDK owns (and may recreate) the gRPC channel only when the caller # did not provide one. Mirrors ``TaskHubGrpcClient._owns_channel`` so # both files use the same name for the same concept. 
@@ -1111,7 +1116,8 @@ def _execute_orchestrator( executor = _OrchestrationExecutor( self._registry, self._logger, persisted_orch_span_id=persisted_orch_span_id, - maximum_timer_interval=self.maximum_timer_interval) + maximum_timer_interval=self.maximum_timer_interval, + data_converter=self._data_converter) result = executor.execute(instance_id, req.pastEvents, req.newEvents) # Determine completion status for span @@ -1236,7 +1242,7 @@ def _execute_activity( if self._payload_store is not None: payload_helpers.deexternalize_payloads(req, self._payload_store) try: - executor = _ActivityExecutor(self._registry, self._logger) + executor = _ActivityExecutor(self._registry, self._logger, self._data_converter) with tracing.start_span( tracing.create_span_name("activity", req.name), trace_context=req.parentTraceContext, @@ -1310,7 +1316,9 @@ def _execute_entity_batch( if self._payload_store is not None: payload_helpers.deexternalize_payloads(req, self._payload_store) - entity_state = StateShim(shared.from_json(req.entityState.value) if req.entityState.value else None) + entity_state = StateShim( + self._data_converter.deserialize(req.entityState.value) if req.entityState.value else None, + self._data_converter) instance_id = req.instanceId try: @@ -1321,7 +1329,7 @@ def _execute_entity_batch( results: list[pb.OperationResult] = [] for operation in req.operations: start_time = datetime.now(timezone.utc) - executor = _EntityExecutor(self._registry, self._logger) + executor = _EntityExecutor(self._registry, self._logger, self._data_converter) operation_result = None @@ -1368,7 +1376,7 @@ def _execute_entity_batch( batch_result = pb.EntityBatchResult( results=results, actions=entity_state.get_operation_actions(), - entityState=helpers.get_string_value(shared.to_json(entity_state._current_state)) if entity_state._current_state else None, # pyright: ignore[reportPrivateUsage] + entityState=helpers.get_string_value(self._data_converter.serialize(entity_state._current_state)) if 
entity_state._current_state is not None else None, # pyright: ignore[reportPrivateUsage] failureDetails=None, completionToken=completionToken, operationInfos=operation_infos, @@ -1412,6 +1420,7 @@ def __init__(self, instance_id: str, registry: _Registry, maximum_timer_interval: timedelta | None = DEFAULT_MAXIMUM_TIMER_INTERVAL, + data_converter: DataConverter | None = None, ): self._generator = None self._is_replaying = True @@ -1432,7 +1441,7 @@ def __init__(self, self._entity_context = OrchestrationEntityContext(instance_id) self._version: str | None = None self._completion_status: pb.OrchestrationStatus | None = None - self._received_events: dict[str, list[Any]] = {} + self._received_events: dict[str, list[str | None]] = {} self._pending_events: dict[str, list[task.CancellableTask[Any]]] = {} self._new_input: Any | None = None self._save_events = False @@ -1440,6 +1449,7 @@ def __init__(self, self._parent_trace_context: pb.TraceContext | None = None self._orchestration_trace_context: pb.TraceContext | None = None self._maximum_timer_interval = maximum_timer_interval + self._data_converter = data_converter if data_converter is not None else JsonDataConverter() def run(self, generator: Generator[task.Task[Any], Any, Any]) -> None: self._generator = generator @@ -1497,7 +1507,7 @@ def set_complete( result_json: str | None = None if result is not None: try: - result_json = result if is_result_encoded else shared.to_json(result) + result_json = result if is_result_encoded else self._data_converter.serialize(result) except (ValueError, TypeError): self._is_complete = False self._result = None @@ -1557,18 +1567,15 @@ def get_actions(self) -> list[pb.OrchestratorAction]: # replayed when the new instance starts. 
for event_name, values in self._received_events.items(): for event_value in values: - encoded_value = ( - shared.to_json(event_value) if event_value else None - ) + # Buffered events are stored as their raw JSON payload + # (or None), so carry them over as-is without re-encoding. carryover_events.append( - ph.new_event_raised_event(event_name, encoded_value) + ph.new_event_raised_event(event_name, event_value) ) action = ph.new_complete_orchestration_action( self.next_sequence_number(), pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW, - result=shared.to_json(self._new_input) - if self._new_input is not None - else None, + result=self._data_converter.serialize(self._new_input), failure_details=None, carryover_events=carryover_events, ) @@ -1602,7 +1609,7 @@ def is_replaying(self) -> bool: def set_custom_status(self, custom_status: Any) -> None: self._encoded_custom_status = ( - shared.to_json(custom_status) if custom_status is not None else None + self._data_converter.serialize(custom_status) if custom_status is not None else None ) def create_timer(self, fire_at: datetime | timedelta) -> task.TimerTask: @@ -1644,6 +1651,30 @@ def _cancel_timer() -> None: self._pending_tasks[id] = timer_task return timer_task + @overload + def call_activity( + self, + activity: task.Activity[TInput, TOutput] | str, + *, + input: TInput | None = ..., + retry_policy: task.RetryPolicy | None = ..., + tags: dict[str, str] | None = ..., + return_type: type[T], + ) -> task.CompletableTask[T]: + ... + + @overload + def call_activity( + self, + activity: task.Activity[TInput, TOutput] | str, + *, + input: TInput | None = ..., + retry_policy: task.RetryPolicy | None = ..., + tags: dict[str, str] | None = ..., + return_type: None = ..., + ) -> task.CompletableTask[TOutput]: + ... 
+ def call_activity( self, activity: task.Activity[TInput, TOutput] | str, @@ -1651,24 +1682,57 @@ def call_activity( input: TInput | None = None, retry_policy: task.RetryPolicy | None = None, tags: dict[str, str] | None = None, - ) -> task.CompletableTask[TOutput]: + return_type: type | None = None, + ) -> task.CompletableTask[Any]: id = self.next_sequence_number() + # An explicit return_type takes precedence; otherwise, when an activity + # function reference is supplied, discover its return annotation. The + # converter decides how a coercion failure is handled (the default is + # best-effort). + if return_type is None and not isinstance(activity, str): + return_type = type_discovery.activity_output_type(activity) + self.call_activity_function_helper( - id, activity, input=input, retry_policy=retry_policy, is_sub_orch=False, tags=tags + id, activity, input=input, retry_policy=retry_policy, is_sub_orch=False, tags=tags, + return_type=return_type, ) - return cast(task.CompletableTask[TOutput], self._pending_tasks.get(id, task.CompletableTask[TOutput]())) + return self._pending_tasks.get(id, task.CompletableTask[Any]()) + + @overload + def call_entity( + self, + entity: EntityInstanceId, + operation: str, + input: Any = ..., + *, + return_type: type[T], + ) -> task.CompletableTask[T]: + ... + + @overload + def call_entity( + self, + entity: EntityInstanceId, + operation: str, + input: Any = ..., + *, + return_type: None = ..., + ) -> task.CompletableTask[Any]: + ... 
def call_entity( self, entity: EntityInstanceId, operation: str, input: Any = None, + *, + return_type: type | None = None, ) -> task.CompletableTask[Any]: id = self.next_sequence_number() self.call_entity_function_helper( - id, entity, operation, input=input + id, entity, operation, input=input, return_type=return_type ) return self._pending_tasks.get(id, task.CompletableTask[Any]()) @@ -1693,6 +1757,32 @@ def lock_entities(self, entities: list[EntityInstanceId]) -> task.CompletableTas ) return cast(task.CompletableTask[EntityLock], self._pending_tasks.get(id, task.CompletableTask[EntityLock]())) + @overload + def call_sub_orchestrator( + self, + orchestrator: task.Orchestrator[TInput, TOutput] | str, + *, + input: TInput | None = ..., + instance_id: str | None = ..., + retry_policy: task.RetryPolicy | None = ..., + version: str | None = ..., + return_type: type[T], + ) -> task.CompletableTask[T]: + ... + + @overload + def call_sub_orchestrator( + self, + orchestrator: task.Orchestrator[TInput, TOutput] | str, + *, + input: TInput | None = ..., + instance_id: str | None = ..., + retry_policy: task.RetryPolicy | None = ..., + version: str | None = ..., + return_type: None = ..., + ) -> task.CompletableTask[TOutput]: + ... 
+ def call_sub_orchestrator( self, orchestrator: task.Orchestrator[TInput, TOutput] | str, @@ -1701,7 +1791,8 @@ def call_sub_orchestrator( instance_id: str | None = None, retry_policy: task.RetryPolicy | None = None, version: str | None = None, - ) -> task.CompletableTask[TOutput]: + return_type: type | None = None, + ) -> task.CompletableTask[Any]: id = self.next_sequence_number() if isinstance(orchestrator, str): orchestrator_name = orchestrator @@ -1716,9 +1807,10 @@ def call_sub_orchestrator( retry_policy=retry_policy, is_sub_orch=True, instance_id=instance_id, - version=orchestrator_version + version=orchestrator_version, + return_type=return_type, ) - return cast(task.CompletableTask[TOutput], self._pending_tasks.get(id, task.CompletableTask[TOutput]())) + return self._pending_tasks.get(id, task.CompletableTask[Any]()) def call_activity_function_helper( self, @@ -1732,12 +1824,13 @@ def call_activity_function_helper( instance_id: str | None = None, fn_task: task.CompletableTask[TOutput] | None = None, version: str | None = None, + return_type: type | None = None, ): if id is None: id = self.next_sequence_number() if fn_task is None: - encoded_input = shared.to_json(input) if input is not None else None + encoded_input = self._data_converter.serialize(input) else: # Here, we don't need to convert the input to JSON because it is already converted. # We just need to take string representation of it. 
@@ -1785,13 +1878,14 @@ def call_activity_function_helper( if fn_task is None: if retry_policy is None: - fn_task = task.CompletableTask[TOutput]() + fn_task = task.CompletableTask[TOutput](expected_type=return_type) else: fn_task = task.RetryableTask[TOutput]( retry_policy=retry_policy, action=action, start_time=self.current_utc_datetime, is_sub_orch=is_sub_orch, + expected_type=return_type, ) self._pending_tasks[id] = fn_task @@ -1802,6 +1896,7 @@ def call_entity_function_helper( operation: str, *, input: Any = None, + return_type: type | None = None, ) -> None: if id is None: id = self.next_sequence_number() @@ -1810,11 +1905,11 @@ def call_entity_function_helper( if not transition_valid: raise RuntimeError(error_message) - encoded_input = shared.to_json(input) if input is not None else None + encoded_input = self._data_converter.serialize(input) action = ph.new_call_entity_action(id, self.instance_id, entity_id, operation, encoded_input, self.new_uuid()) self._pending_actions[id] = action - fn_task = task.CompletableTask[Any]() + fn_task = task.CompletableTask[Any](expected_type=return_type) self._pending_tasks[id] = fn_task def signal_entity_function_helper( @@ -1832,7 +1927,7 @@ def signal_entity_function_helper( if not transition_valid: raise RuntimeError(error_message) - encoded_input = shared.to_json(input) if input is not None else None + encoded_input = self._data_converter.serialize(input) action = ph.new_signal_entity_action(id, entity_id, operation, encoded_input, self.new_uuid()) self._pending_actions[id] = action @@ -1869,20 +1964,31 @@ def _exit_critical_section(self) -> None: action = pb.OrchestratorAction(id=task_id, sendEntityMessage=entity_unlock_message) self._pending_actions[task_id] = action - def wait_for_external_event(self, name: str) -> task.CancellableTask[Any]: + @overload + def wait_for_external_event(self, name: str, *, + data_type: type[T]) -> task.CancellableTask[T]: + ... 
+ + @overload + def wait_for_external_event(self, name: str, *, + data_type: None = ...) -> task.CancellableTask[Any]: + ... + + def wait_for_external_event(self, name: str, *, + data_type: type | None = None) -> task.CancellableTask[Any]: # Check to see if this event has already been received, in which case we # can return it immediately. Otherwise, record out intent to receive an # event with the given name so that we can resume the generator when it # arrives. If there are multiple events with the same name, we return # them in the order they were received. - external_event_task: task.CancellableTask[Any] = task.CancellableTask() + external_event_task: task.CancellableTask[Any] = task.CancellableTask(expected_type=data_type) event_name = name.casefold() event_list = self._received_events.get(event_name, None) if event_list: event_data = event_list.pop(0) if not event_list: del self._received_events[event_name] - external_event_task.complete(event_data) + external_event_task.complete(self._data_converter.deserialize(event_data, data_type)) else: task_list = self._pending_events.get(event_name, None) if not task_list: @@ -1945,9 +2051,11 @@ def __init__( logger: logging.Logger, persisted_orch_span_id: str | None = None, maximum_timer_interval: timedelta | None = DEFAULT_MAXIMUM_TIMER_INTERVAL, + data_converter: DataConverter | None = None, ): self._registry = registry self._logger = logger + self._data_converter = data_converter if data_converter is not None else JsonDataConverter() self._maximum_timer_interval = maximum_timer_interval self._is_suspended = False self._suspended_events: list[pb.HistoryEvent] = [] @@ -1986,6 +2094,7 @@ def execute( instance_id, self._registry, maximum_timer_interval=self._maximum_timer_interval, + data_converter=self._data_converter, ) try: # Rebuild local state by replaying old history into the orchestrator function @@ -2101,7 +2210,9 @@ def process_event( if ( event.executionStarted.HasField("input") and 
event.executionStarted.input.value != "" ): - input = shared.from_json(event.executionStarted.input.value) + input_type = type_discovery.orchestrator_input_type(fn) + input = self._data_converter.deserialize( + event.executionStarted.input.value, input_type) result = fn( ctx, input @@ -2250,7 +2361,9 @@ def _cancel_timer() -> None: ) result = None if not ph.is_empty(event.taskCompleted.result): - result = shared.from_json(event.taskCompleted.result.value) + result = self._data_converter.deserialize( + event.taskCompleted.result.value, activity_task._expected_type # pyright: ignore[reportPrivateUsage] + ) activity_task.complete(result) ctx.resume() elif event.HasField("taskFailed"): @@ -2359,8 +2472,9 @@ def _cancel_timer() -> None: ) result = None if not ph.is_empty(event.subOrchestrationInstanceCompleted.result): - result = shared.from_json( - event.subOrchestrationInstanceCompleted.result.value + result = self._data_converter.deserialize( + event.subOrchestrationInstanceCompleted.result.value, + sub_orch_task._expected_type, # pyright: ignore[reportPrivateUsage] ) sub_orch_task.complete(result) ctx.resume() @@ -2429,20 +2543,27 @@ def _cancel_timer() -> None: if task_list: event_task = task_list.pop(0) if not ph.is_empty(event.eventRaised.input): - decoded_result = shared.from_json(event.eventRaised.input.value) + decoded_result = self._data_converter.deserialize( + event.eventRaised.input.value, event_task._expected_type # pyright: ignore[reportPrivateUsage] + ) event_task.complete(decoded_result) if not task_list: del ctx._pending_events[event_name] # pyright: ignore[reportPrivateUsage] ctx.resume() else: - # buffer the event + # Buffer the raw event payload (the JSON string). It is + # deserialized -- with the waiter's expected type, if any + # -- when ``wait_for_external_event`` later consumes it, + # so buffered and non-buffered events flow through the + # same converter path. 
event_list = ctx._received_events.get(event_name, None) # pyright: ignore[reportPrivateUsage] if not event_list: event_list = [] ctx._received_events[event_name] = event_list # pyright: ignore[reportPrivateUsage] + buffered_payload: str | None = None if not ph.is_empty(event.eventRaised.input): - decoded_result = shared.from_json(event.eventRaised.input.value) - event_list.append(decoded_result) + buffered_payload = event.eventRaised.input.value + event_list.append(buffered_payload) if not ctx.is_replaying: self._logger.info( f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it." @@ -2566,7 +2687,10 @@ def _cancel_timer() -> None: return result = None if not ph.is_empty(event.entityOperationCompleted.output): - result = shared.from_json(event.entityOperationCompleted.output.value) + result = self._data_converter.deserialize( + event.entityOperationCompleted.output.value, + entity_task._expected_type, # pyright: ignore[reportPrivateUsage] + ) ctx._entity_context.recover_lock_after_call(entity_id) # pyright: ignore[reportPrivateUsage] entity_task.complete(result) ctx.resume() @@ -2647,7 +2771,15 @@ def _handle_entity_event_raised(self, result = None if not ph.is_empty(event.eventRaised.input): # TODO: Investigate why the event result is wrapped in a dict with "result" key - result = shared.from_json(event.eventRaised.input.value)["result"] + # The expected type applies to the unwrapped result value, not the + # transport wrapper. Unwrap first, then route the inner value back + # through the converter so custom converters and the expected type + # both apply. 
+ unwrapped = self._data_converter.deserialize(event.eventRaised.input.value)["result"] + result = self._data_converter.deserialize( + self._data_converter.serialize(unwrapped), + entity_task._expected_type, # pyright: ignore[reportPrivateUsage] + ) if is_lock_event: ctx._entity_context.complete_acquire(event.eventRaised.name) # pyright: ignore[reportPrivateUsage] entity_task.complete(EntityLock(ctx)) @@ -2700,9 +2832,11 @@ def compare_versions(self, source_version: str | None, default_version: str | No class _ActivityExecutor: - def __init__(self, registry: _Registry, logger: logging.Logger): + def __init__(self, registry: _Registry, logger: logging.Logger, + data_converter: DataConverter | None = None): self._registry = registry self._logger = logger + self._data_converter = data_converter if data_converter is not None else JsonDataConverter() def execute( self, @@ -2721,15 +2855,14 @@ def execute( f"Activity function named '{name}' was not registered!" ) - activity_input = shared.from_json(encoded_input) if encoded_input else None + input_type = type_discovery.activity_input_type(fn) if encoded_input else None + activity_input = self._data_converter.deserialize(encoded_input, input_type) ctx = task.ActivityContext(orchestration_id, task_id) # Execute the activity function activity_output = fn(ctx, activity_input) - encoded_output = ( - shared.to_json(activity_output) if activity_output is not None else None - ) + encoded_output = self._data_converter.serialize(activity_output) chars = len(encoded_output) if encoded_output else 0 self._logger.debug( f"{orchestration_id}/{task_id}: Activity '{name}' completed successfully with {chars} char(s) of encoded output." 
@@ -2738,9 +2871,11 @@ def execute( class _EntityExecutor: - def __init__(self, registry: _Registry, logger: logging.Logger): + def __init__(self, registry: _Registry, logger: logging.Logger, + data_converter: DataConverter | None = None): self._registry = registry self._logger = logger + self._data_converter = data_converter if data_converter is not None else JsonDataConverter() self._entity_method_cache: dict[tuple[type, str], bool] = {} def execute( @@ -2761,8 +2896,9 @@ def execute( f"Entity function named '{entity_id.entity}' was not registered!" ) - entity_input = shared.from_json(encoded_input) if encoded_input else None - ctx = EntityContext(orchestration_id, operation, state, entity_id) + input_type = type_discovery.entity_input_type(fn, operation) if encoded_input else None + entity_input = self._data_converter.deserialize(encoded_input, input_type) + ctx = EntityContext(orchestration_id, operation, state, entity_id, self._data_converter) if isinstance(fn, type) and issubclass(fn, DurableEntity): entity_instance = fn() @@ -2792,9 +2928,7 @@ def execute( # Execute the entity function entity_output = fn(ctx, entity_input) - encoded_output = ( - shared.to_json(entity_output) if entity_output is not None else None - ) + encoded_output = self._data_converter.serialize(entity_output) chars = len(encoded_output) if encoded_output else 0 self._logger.debug( f"{orchestration_id}: Entity '{entity_id}' completed successfully with {chars} char(s) of encoded output." 
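Reviewer note on the `worker.py` changes above: the sketch below illustrates, under stated assumptions, the converter contract those hunks rely on. `SketchConverter` is a hypothetical stand-in written for this note, not the SDK's `JsonDataConverter`; only the `serialize(value)` / `deserialize(data, expected_type)` call shape and the behaviors exercised (a `None` input serializes to `None`, an expected dataclass type is reconstructed, a failed coercion falls back to the raw value) are taken from the diff and its tests. The last lines contrast the old truthiness check on entity state with the new `is not None` check.

```python
import json
from dataclasses import asdict, dataclass, is_dataclass
from typing import Any, Optional


@dataclass
class Order:
    item: str
    quantity: int


class SketchConverter:
    """Hypothetical stand-in for illustration; not the SDK's converter."""

    def serialize(self, value: Any) -> Optional[str]:
        # None stays None, so callers no longer need their own None guard.
        if value is None:
            return None
        # Dataclass instances are written as plain JSON objects.
        if is_dataclass(value) and not isinstance(value, type):
            value = asdict(value)
        return json.dumps(value)

    def deserialize(self, data: Optional[str],
                    expected_type: Optional[type] = None) -> Any:
        # A missing or empty payload yields None.
        if not data:
            return None
        raw = json.loads(data)
        if (expected_type is not None and is_dataclass(expected_type)
                and isinstance(raw, dict)):
            try:
                return expected_type(**raw)
            except TypeError:
                # Best-effort: a failed coercion returns the raw value.
                return raw
        return raw


conv = SketchConverter()
typed = conv.deserialize('{"item": "x", "quantity": 1}', Order)
untyped = conv.deserialize('{"item": "x", "quantity": 1}')
fallback = conv.deserialize('{"item": "x"}', Order)  # missing 'quantity'

# Falsy entity state: the old truthiness check drops 0, the new check keeps it.
state = 0
old_encoding = conv.serialize(state) if state else None
new_encoding = conv.serialize(state) if state is not None else None
```

Here `typed` is an `Order`, `untyped` and `fallback` are plain dicts, `old_encoding` is `None`, and `new_encoding` is the JSON string `"0"`. Nested dataclass fields, `Optional`/`list` fields, and `from_json()` hooks are deliberately left out of this sketch.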
diff --git a/examples/human_interaction.py b/examples/human_interaction.py index fd92157..4025bc2 100644 --- a/examples/human_interaction.py +++ b/examples/human_interaction.py @@ -12,7 +12,7 @@ from collections.abc import Generator from dataclasses import dataclass from datetime import timedelta -from typing import Any, NamedTuple +from typing import Any from azure.identity import DefaultAzureCredential @@ -21,11 +21,6 @@ from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker -class Approval(NamedTuple): - """Represents an approval event payload""" - approver: str - - @dataclass class Order: """Represents a purchase order""" @@ -37,6 +32,12 @@ def __str__(self): return f'{self.Product} ({self.Quantity})' +@dataclass +class Approval: + """Represents an approval decision raised as an external event.""" + approver: str + + def send_approval_request(_: task.ActivityContext, order: Order) -> None: """Activity function that sends an approval request to the manager""" time.sleep(5) @@ -58,9 +59,11 @@ def purchase_order_workflow(ctx: task.OrchestrationContext, order: Order) -> Gen yield ctx.call_activity(send_approval_request, input=order) # Approvals must be received within 24 hours or they will be cancelled. - approval_event = ctx.wait_for_external_event("approval_received") + # Passing ``data_type`` reconstructs the event payload as an ``Approval``. 
+ approval_event = ctx.wait_for_external_event("approval_received", data_type=Approval) timeout_event = ctx.create_timer(timedelta(hours=24)) - winner = yield task.when_any([approval_event, timeout_event]) + pending: list[task.Task[Any]] = [approval_event, timeout_event] + winner = yield task.when_any(pending) if winner == timeout_event: return "Cancelled" @@ -96,7 +99,7 @@ def purchase_order_workflow(ctx: task.OrchestrationContext, order: Order) -> Gen def prompt_for_approval(): input("Press [ENTER] to approve the order...\n") - approval_event = Approval(args.approver) + approval_event = Approval(approver=args.approver) c.raise_orchestration_event(instance_id, "approval_received", data=approval_event) # Prompt the user for approval on a background thread @@ -142,7 +145,7 @@ def prompt_for_approval(): def prompt_for_approval(): input("Press [ENTER] to approve the order...\n") - approval_event = Approval(args.approver) + approval_event = Approval(approver=args.approver) c.raise_orchestration_event(instance_id, "approval_received", data=approval_event) # Prompt the user for approval on a background thread diff --git a/examples/in_memory_backend_example/src/workflows.py b/examples/in_memory_backend_example/src/workflows.py index eecbd2c..8d4ba6e 100644 --- a/examples/in_memory_backend_example/src/workflows.py +++ b/examples/in_memory_backend_example/src/workflows.py @@ -9,11 +9,12 @@ Note on serialization --------------------- -The Durable Task SDK serializes dataclass and namedtuple inputs to JSON. -When deserialized on the receiving side, top-level objects become -``SimpleNamespace`` instances while nested objects become plain ``dict``s. -Activities that receive complex inputs should therefore use dict-style -access (``item["quantity"]``) for nested data. +The Durable Task SDK serializes dataclass inputs to JSON. 
Annotating an +orchestrator's or activity's input parameter with its dataclass type lets the +SDK reconstruct that type on the receiving side (including nested dataclass +fields), so the functions below use attribute access (``order.items``, +``item.quantity``). Without a type annotation, payloads arrive as plain +``dict`` / ``list`` values and would need dict-style access instead. """ from collections.abc import Generator @@ -27,14 +28,14 @@ # --------------------------------------------------------------------------- # Data models # --------------------------------------------------------------------------- -# These dataclasses document the expected shape of the data. At runtime, -# they are serialized to JSON and arrive in activities as SimpleNamespace -# (top-level) or dict (nested) objects. +# These dataclasses describe the shape of the data. Because the orchestrators +# and activities annotate their inputs with these types, the SDK reconstructs +# them (including the nested ``OrderItem`` list) on the receiving side. @dataclass class OrderItem: - """A single item in an order (arrives as a ``dict`` inside activities).""" + """A single item in an order.""" name: str quantity: int unit_price: float @@ -42,7 +43,7 @@ class OrderItem: @dataclass class Order: - """An order containing one or more items (arrives as ``SimpleNamespace``).""" + """An order containing one or more items.""" customer: str items: list[OrderItem] @@ -52,7 +53,7 @@ class Order: # --------------------------------------------------------------------------- -def validate_order(ctx: task.ActivityContext, order: Any) -> None: +def validate_order(ctx: task.ActivityContext, order: Order) -> None: """Validate that the order has items and all quantities/prices are valid. Raises ``ValueError`` on invalid input. 
@@ -60,17 +61,17 @@ def validate_order(ctx: task.ActivityContext, order: Any) -> None: if not order.items: raise ValueError("Order must contain at least one item") for item in order.items: - if item["quantity"] <= 0: + if item.quantity <= 0: raise ValueError( - f"Invalid quantity for '{item['name']}': {item['quantity']}") - if item["unit_price"] < 0: + f"Invalid quantity for '{item.name}': {item.quantity}") + if item.unit_price < 0: raise ValueError( - f"Invalid price for '{item['name']}': {item['unit_price']}") + f"Invalid price for '{item.name}': {item.unit_price}") -def calculate_total(ctx: task.ActivityContext, items: list[Any]) -> float: - """Return the total cost for a list of item dicts.""" - return sum(item["quantity"] * item["unit_price"] for item in items) +def calculate_total(ctx: task.ActivityContext, items: list[OrderItem]) -> float: + """Return the total cost for a list of order items.""" + return sum(item.quantity * item.unit_price for item in items) def process_payment(ctx: task.ActivityContext, amount: float) -> str: @@ -97,7 +98,7 @@ def ship_item(ctx: task.ActivityContext, item_name: str) -> str: # --------------------------------------------------------------------------- -def process_order(ctx: task.OrchestrationContext, order: Any) -> Generator[task.Task[Any], Any, dict[str, Any]]: +def process_order(ctx: task.OrchestrationContext, order: Order) -> Generator[task.Task[Any], Any, dict[str, Any]]: """Process a complete order: validate, pay, ship items in parallel, confirm. Demonstrates: @@ -117,7 +118,7 @@ def process_order(ctx: task.OrchestrationContext, order: Any) -> Generator[task. # 4. 
Ship all items in parallel (fan-out / fan-in) ship_tasks: list[task.Task[str]] = [ - ctx.call_activity(ship_item, input=item["name"]) + ctx.call_activity(ship_item, input=item.name) for item in order.items ] tracking_ids: list[str] = yield task.when_all(ship_tasks) @@ -137,7 +138,7 @@ def process_order(ctx: task.OrchestrationContext, order: Any) -> Generator[task. } -def order_with_approval(ctx: task.OrchestrationContext, order: Any) -> Generator[task.Task[Any], Any, dict[str, Any]]: +def order_with_approval(ctx: task.OrchestrationContext, order: Order) -> Generator[task.Task[Any], Any, dict[str, Any]]: """Order workflow that requires manager approval for high-value orders. Demonstrates: diff --git a/tests/durabletask/extensions/history_export/test_entity.py b/tests/durabletask/extensions/history_export/test_entity.py index 5442a4d..434d9cb 100644 --- a/tests/durabletask/extensions/history_export/test_entity.py +++ b/tests/durabletask/extensions/history_export/test_entity.py @@ -10,7 +10,6 @@ from __future__ import annotations -import json from datetime import datetime, timezone from typing import Callable, Optional @@ -88,9 +87,9 @@ def _create_payload() -> dict: def _state_dict(metadata) -> dict: - raw = metadata.get_state(str) - assert raw is not None - return json.loads(raw) + state = metadata.get_typed_state() + assert isinstance(state, dict) + return state def _wait_for_state( @@ -106,12 +105,8 @@ def _check() -> Optional[dict]: meta = c.get_entity(entity_id, include_state=True) if meta is None: return None - raw = meta.get_state(str) - if not raw: - return None - try: - state = json.loads(raw) - except (TypeError, ValueError): + state = meta.get_typed_state() + if not isinstance(state, dict): return None return state if predicate(state) else None diff --git a/tests/durabletask/test_batch_actions.py b/tests/durabletask/test_batch_actions.py index 01f0941..65637df 100644 --- a/tests/durabletask/test_batch_actions.py +++ 
b/tests/durabletask/test_batch_actions.py @@ -387,6 +387,41 @@ def counter_entity(ctx: entities.EntityContext, input): worker.stop() +def test_entity_falsy_state_is_persisted(backend): + """A falsy entity state (e.g. 0) must be persisted, not dropped as None. + + Regression test: the batch result previously persisted state only when it + was truthy, so a valid falsy state such as ``0`` was written as ``None`` and + effectively deleted. Only an actual ``None`` should clear the state. + """ + def counter_entity(ctx: entities.EntityContext, input): + if ctx.operation == "set": + ctx.set_state(input) + elif ctx.operation == "get": + return ctx.get_state(int) + + worker = TaskHubGrpcWorker(host_address=HOST) + + worker.add_entity(counter_entity) + worker.start() + + try: + with TaskHubGrpcClient(host_address=HOST) as c: + entity_id = entities.EntityInstanceId("counter_entity", "falsyCounter") + # Set the state to a falsy-but-valid value. + c.signal_entity(entity_id, "set", input=0) + time.sleep(3) # Wait for the signal to be processed + + query = client.EntityQuery(include_state=True) + all_entities = c.get_all_entities(query) + matches = [e for e in all_entities if e.id == entity_id] + # The entity must still exist with its falsy state intact. + assert len(matches) == 1 + assert matches[0].get_state(int) == 0 + finally: + worker.stop() + + def test_get_entities_by_instance_id_prefix(backend): def counter_entity(ctx: entities.EntityContext, input): if ctx.operation == "set": diff --git a/tests/durabletask/test_data_converter.py b/tests/durabletask/test_data_converter.py new file mode 100644 index 0000000..2497940 --- /dev/null +++ b/tests/durabletask/test_data_converter.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Tests for the DataConverter abstraction and the default JsonDataConverter.""" + +import json +import logging +from dataclasses import dataclass +from typing import Any + +from durabletask.serialization import ( + DEFAULT_DATA_CONVERTER, + DataConverter, + JsonDataConverter, +) + + +@dataclass +class Order: + item: str + quantity: int + + +# ----- JsonDataConverter ----- + + +def test_serialize_none_returns_none(): + assert JsonDataConverter().serialize(None) is None + + +def test_serialize_dataclass_plain_json(): + assert json.loads(JsonDataConverter().serialize(Order("widget", 3))) == { + "item": "widget", + "quantity": 3, + } + + +def test_deserialize_none_or_empty_returns_none(): + conv = JsonDataConverter() + assert conv.deserialize(None) is None + assert conv.deserialize("") is None + assert conv.deserialize(None, Order) is None + + +def test_deserialize_without_type_returns_raw(): + conv = JsonDataConverter() + assert conv.deserialize('{"item": "x", "quantity": 1}') == {"item": "x", "quantity": 1} + + +def test_deserialize_coerces_to_type(): + conv = JsonDataConverter() + result = conv.deserialize('{"item": "x", "quantity": 1}', Order) + assert isinstance(result, Order) + assert result == Order("x", 1) + + +def test_deserialize_best_effort_falls_back_to_raw(caplog): + conv = JsonDataConverter() + # Missing required 'quantity' field -> coercion fails -> raw dict returned. 
+ with caplog.at_level(logging.DEBUG, logger="durabletask"): + result = conv.deserialize('{"item": "x"}', Order) + assert result == {"item": "x"} + assert any("coerce" in r.message.lower() for r in caplog.records) + + +def test_round_trip_through_converter(): + conv = JsonDataConverter() + encoded = conv.serialize(Order("book", 2)) + assert conv.deserialize(encoded, Order) == Order("book", 2) + + +def test_default_converter_is_json_converter(): + assert isinstance(DEFAULT_DATA_CONVERTER, JsonDataConverter) + + +# ----- Custom converter ----- + + +def test_custom_converter_is_a_dataconverter_subclass(): + class UpperConverter(DataConverter): + def serialize(self, value: Any) -> str | None: + return None if value is None else json.dumps(str(value).upper()) + + def deserialize(self, data: str | None, target_type: type | None = None) -> Any: + return None if data is None else json.loads(data) + + def coerce(self, value: Any, target_type: type | None = None) -> Any: + return value + + conv = UpperConverter() + assert conv.serialize("hello") == '"HELLO"' + assert conv.deserialize('"HELLO"') == "HELLO" + assert conv.coerce("HELLO") == "HELLO" diff --git a/tests/durabletask/test_entity_executor.py b/tests/durabletask/test_entity_executor.py index c3f369d..e853f13 100644 --- a/tests/durabletask/test_entity_executor.py +++ b/tests/durabletask/test_entity_executor.py @@ -132,3 +132,59 @@ def counter(ctx: entities.EntityContext, input): result = executor.execute("test-orch", entity_id, "get", state, None) assert result == "42" + + +class TestStateShimCoercion: + """Tests for StateShim.get_state type coercion via the data converter.""" + + def test_get_state_none_returns_default(self): + state = StateShim(None) + assert state.get_state(int, 0) == 0 + + def test_get_state_none_without_default_returns_none(self): + state = StateShim(None) + assert state.get_state(int) is None + + def test_get_state_passes_through_matching_type(self): + state = StateShim(5) + assert 
state.get_state(int) == 5 + + def test_get_state_constructor_coercion(self): + state = StateShim("5") + assert state.get_state(int) == 5 + + def test_get_state_coerces_dataclass(self): + from dataclasses import dataclass + + @dataclass + class Counter: + value: int + + # State is stored as a plain dict (as it would be after from_json). + state = StateShim({"value": 7}) + result = state.get_state(Counter) + assert isinstance(result, Counter) + assert result.value == 7 + + def test_get_state_uses_from_json_hook(self): + class Wrapped: + def __init__(self, n: int): + self.n = n + + @classmethod + def from_json(cls, data): + return cls(data["n"]) + + state = StateShim({"n": 3}) + result = state.get_state(Wrapped) + assert isinstance(result, Wrapped) + assert result.n == 3 + + def test_get_state_invalid_coercion_raises(self): + # An explicit intended_type that the state cannot be coerced to raises, + # restoring the pre-existing strict contract for entity state access. + import pytest + + state = StateShim("not-an-int") + with pytest.raises(TypeError): + state.get_state(int) diff --git a/tests/durabletask/test_orchestration_executor.py b/tests/durabletask/test_orchestration_executor.py index 0134b12..790f8e2 100644 --- a/tests/durabletask/test_orchestration_executor.py +++ b/tests/durabletask/test_orchestration_executor.py @@ -450,6 +450,151 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): assert complete_action.result.value == encoded_output +def test_activity_task_completion_with_return_type(): + """Tests that call_activity(return_type=...) 
coerces the result to a dataclass.""" + from dataclasses import dataclass + + @dataclass + class Result: + message: str + + def dummy_activity(ctx, _): + pass + + captured: dict = {} + + def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): + result = yield ctx.call_activity(dummy_activity, return_type=Result) + captured["type"] = type(result).__name__ + captured["message"] = result.message + return result.message + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + + old_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_task_scheduled_event(1, task.get_name(dummy_activity))] + + new_events = [helpers.new_task_completed_event(1, json.dumps({"message": "hi"}))] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + + complete_action = get_and_validate_complete_orchestration_action_list(1, result.actions) + assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED + assert captured["type"] == "Result" + assert captured["message"] == "hi" + + +def test_activity_return_type_discovered_from_annotation(): + """Tests that call_activity discovers the return type from the activity's annotation.""" + from dataclasses import dataclass + + @dataclass + class Result: + message: str + + def annotated_activity(ctx, _) -> Result: + ... 
+ + captured: dict = {} + + def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): + result = yield ctx.call_activity(annotated_activity) + captured["type"] = type(result).__name__ + captured["message"] = result.message + return result.message + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + + old_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_task_scheduled_event(1, task.get_name(annotated_activity))] + + new_events = [helpers.new_task_completed_event(1, json.dumps({"message": "hi"}))] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor.execute(TEST_INSTANCE_ID, old_events, new_events) + + assert captured["type"] == "Result" + assert captured["message"] == "hi" + + +def test_explicit_return_type_overrides_discovered_annotation(): + """Tests that an explicit return_type takes precedence over the annotation.""" + from dataclasses import dataclass + + @dataclass + class Annotated: + value: str + + @dataclass + class Override: + value: str + + def annotated_activity(ctx, _) -> Annotated: + ... 
+ + captured: dict = {} + + def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): + result = yield ctx.call_activity(annotated_activity, return_type=Override) + captured["type"] = type(result).__name__ + return result.value + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + + old_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_task_scheduled_event(1, task.get_name(annotated_activity))] + + new_events = [helpers.new_task_completed_event(1, json.dumps({"value": "x"}))] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor.execute(TEST_INSTANCE_ID, old_events, new_events) + + assert captured["type"] == "Override" + + +def test_orchestrator_input_type_discovery(): + """Tests that an orchestrator's dataclass input annotation is reconstructed.""" + from dataclasses import dataclass + + @dataclass + class StartArgs: + name: str + count: int + + captured: dict = {} + + def orchestrator(ctx: task.OrchestrationContext, args: StartArgs): + captured["type"] = type(args).__name__ + captured["name"] = args.name + captured["count"] = args.count + return args.count + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + + encoded_input = json.dumps({"name": "abc", "count": 5}) + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=encoded_input)] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor.execute(TEST_INSTANCE_ID, [], new_events) + + assert captured["type"] == "StartArgs" + assert captured["name"] == "abc" + assert captured["count"] == 5 + + def test_activity_task_failed(): """Tests the failure of an activity task""" def dummy_activity(ctx, _): diff --git a/tests/durabletask/test_orchestration_state.py b/tests/durabletask/test_orchestration_state.py new file 
mode 100644 index 0000000..fcd7797 --- /dev/null +++ b/tests/durabletask/test_orchestration_state.py @@ -0,0 +1,107 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for typed accessors on client.OrchestrationState.""" + +import json +from dataclasses import dataclass +from datetime import datetime, timezone + +from durabletask.client import OrchestrationState, OrchestrationStatus + + +@dataclass +class Result: + message: str + count: int + + +class Money: + def __init__(self, amount: int): + self.amount = amount + + def to_json(self) -> dict: + return {"amount": self.amount} + + @classmethod + def from_json(cls, data: dict) -> "Money": + return cls(data["amount"]) + + +def _state(serialized_input=None, serialized_output=None, serialized_custom_status=None) -> OrchestrationState: + now = datetime.now(timezone.utc) + return OrchestrationState( + instance_id="abc123", + name="test", + runtime_status=OrchestrationStatus.COMPLETED, + created_at=now, + last_updated_at=now, + serialized_input=serialized_input, + serialized_output=serialized_output, + serialized_custom_status=serialized_custom_status, + failure_details=None, + ) + + +# ----- get_output ----- + + +def test_get_output_none_returns_none(): + assert _state(serialized_output=None).get_output() is None + assert _state(serialized_output=None).get_output(Result) is None + + +def test_get_output_raw_without_type(): + state = _state(serialized_output=json.dumps({"message": "hi", "count": 2})) + assert state.get_output() == {"message": "hi", "count": 2} + + +def test_get_output_coerced_to_dataclass(): + state = _state(serialized_output=json.dumps({"message": "hi", "count": 2})) + result = state.get_output(Result) + assert isinstance(result, Result) + assert result == Result("hi", 2) + + +def test_get_output_uses_from_json_hook(): + state = _state(serialized_output=json.dumps({"amount": 50})) + result = state.get_output(Money) + assert isinstance(result, Money) + assert 
result.amount == 50 + + +def test_get_output_primitive(): + assert _state(serialized_output=json.dumps(42)).get_output(int) == 42 + + +# ----- get_input ----- + + +def test_get_input_none_returns_none(): + assert _state(serialized_input=None).get_input(Result) is None + + +def test_get_input_coerced_to_dataclass(): + state = _state(serialized_input=json.dumps({"message": "in", "count": 1})) + result = state.get_input(Result) + assert isinstance(result, Result) + assert result == Result("in", 1) + + +# ----- get_custom_status ----- + + +def test_get_custom_status_none_returns_none(): + assert _state(serialized_custom_status=None).get_custom_status(Result) is None + + +def test_get_custom_status_raw_without_type(): + state = _state(serialized_custom_status=json.dumps({"phase": "step1"})) + assert state.get_custom_status() == {"phase": "step1"} + + +def test_get_custom_status_coerced_to_dataclass(): + state = _state(serialized_custom_status=json.dumps({"message": "s", "count": 9})) + result = state.get_custom_status(Result) + assert isinstance(result, Result) + assert result == Result("s", 9) diff --git a/tests/durabletask/test_serialization.py b/tests/durabletask/test_serialization.py new file mode 100644 index 0000000..33003c2 --- /dev/null +++ b/tests/durabletask/test_serialization.py @@ -0,0 +1,266 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Unit tests for the JSON serialization codec in durabletask.internal.json_codec.""" + +import json +from collections import namedtuple +from dataclasses import dataclass +from types import SimpleNamespace + +import pytest + +from durabletask.internal import json_codec + + +# ----- Test fixtures ----- + + +@dataclass +class Address: + street: str + city: str + + +@dataclass +class Person: + name: str + age: int + address: Address | None = None + + +class Widget: + """A custom object using the to_json/from_json convention.""" + + def __init__(self, label: str, size: int): + self.label = label + self.size = size + + def to_json(self) -> dict: + return {"label": self.label, "size": self.size} + + @classmethod + def from_json(cls, data: dict) -> "Widget": + return cls(data["label"], data["size"]) + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, Widget) + and other.label == self.label + and other.size == self.size + ) + + +Point = namedtuple("Point", ["x", "y"]) + + +class StaticWidget: + """Custom object whose to_json/from_json are static methods returning a str. + + This mirrors the ``azure-functions-durable`` sample convention where + ``to_json(obj)`` is a ``@staticmethod`` that returns a string. + """ + + def __init__(self, name: str): + self.name = name + + @staticmethod + def to_json(obj: "StaticWidget") -> str: + return obj.name + + @staticmethod + def from_json(data: str) -> "StaticWidget": + return StaticWidget(data) + + def __eq__(self, other: object) -> bool: + return isinstance(other, StaticWidget) and other.name == self.name + + +def test_to_json_static_hook_receives_instance(): + # type(obj).to_json(obj) must invoke the @staticmethod with the instance. 
+ assert json_codec.to_json(StaticWidget("gizmo")) == '"gizmo"' + + +def test_static_hook_round_trips_with_expected_type(): + encoded = json_codec.to_json(StaticWidget("gizmo")) + result = json_codec.from_json(encoded, StaticWidget) + assert isinstance(result, StaticWidget) + assert result == StaticWidget("gizmo") + + +def test_instance_to_json_hook_receives_instance(): + # The same type(obj).to_json(obj) path works for plain instance methods. + assert json.loads(json_codec.to_json(Widget("gear", 5))) == {"label": "gear", "size": 5} + + +# ----- to_json ----- + + +def test_to_json_builtins_are_plain_json(): + assert json_codec.to_json({"a": 1, "b": [1, 2, 3]}) == json.dumps({"a": 1, "b": [1, 2, 3]}) + assert json_codec.to_json("hello") == '"hello"' + assert json_codec.to_json(42) == "42" + + +def test_to_json_dataclass_emits_plain_dict_without_marker(): + encoded = json_codec.to_json(Address("1 Main St", "Redmond")) + parsed = json.loads(encoded) + assert parsed == {"street": "1 Main St", "city": "Redmond"} + assert json_codec.AUTO_SERIALIZED not in encoded + + +def test_to_json_nested_dataclass_has_no_marker(): + encoded = json_codec.to_json(Person("Ada", 30, Address("1 Main St", "Redmond"))) + assert json_codec.AUTO_SERIALIZED not in encoded + parsed = json.loads(encoded) + assert parsed["address"] == {"street": "1 Main St", "city": "Redmond"} + + +def test_to_json_simplenamespace_emits_plain_dict(): + encoded = json_codec.to_json(SimpleNamespace(a=1, b="two")) + assert json_codec.AUTO_SERIALIZED not in encoded + assert json.loads(encoded) == {"a": 1, "b": "two"} + + +def test_to_json_custom_object_uses_to_json_hook(): + encoded = json_codec.to_json(Widget("gear", 5)) + assert json.loads(encoded) == {"label": "gear", "size": 5} + + +def test_to_json_namedtuple_serializes_as_array(): + # Without an expected_type the field names are not preserved on the wire. 
+ assert json_codec.to_json(Point(1, 2)) == "[1, 2]" + + +def test_to_json_unserializable_raises_typeerror_with_cause(): + class NotSerializable: + pass + + with pytest.raises(TypeError) as exc_info: + json_codec.to_json(NotSerializable()) + assert "NotSerializable" in str(exc_info.value) + assert exc_info.value.__cause__ is not None + + +# ----- from_json without expected_type (loose mode) ----- + + +def test_from_json_returns_raw_without_expected_type(): + assert json_codec.from_json('{"a": 1}') == {"a": 1} + assert json_codec.from_json("[1, 2, 3]") == [1, 2, 3] + assert json_codec.from_json("42") == 42 + + +def test_from_json_legacy_marker_decodes_to_simplenamespace(): + legacy = json.dumps({"a": 1, "b": 2, json_codec.AUTO_SERIALIZED: True}) + result = json_codec.from_json(legacy) + assert isinstance(result, SimpleNamespace) + assert result.a == 1 + assert result.b == 2 + + +def test_legacy_simplenamespace_reserializes_cleanly(): + legacy = json.dumps({"a": 1, json_codec.AUTO_SERIALIZED: True}) + ns = json_codec.from_json(legacy) + reencoded = json_codec.to_json(ns) + assert json_codec.AUTO_SERIALIZED not in reencoded + assert json.loads(reencoded) == {"a": 1} + + +# ----- from_json with expected_type (type-directed) ----- + + +def test_from_json_coerces_to_dataclass(): + encoded = json_codec.to_json(Address("1 Main St", "Redmond")) + result = json_codec.from_json(encoded, Address) + assert isinstance(result, Address) + assert result == Address("1 Main St", "Redmond") + + +def test_from_json_coerces_nested_dataclass(): + encoded = json_codec.to_json(Person("Ada", 30, Address("1 Main St", "Redmond"))) + result = json_codec.from_json(encoded, Person) + assert isinstance(result, Person) + assert isinstance(result.address, Address) + assert result.address.city == "Redmond" + + +def test_from_json_coerces_optional_dataclass_when_present(): + result = json_codec.from_json('{"name": "Ada", "age": 30, "address": null}', Person) + assert isinstance(result, Person) + 
assert result.address is None + + +def test_from_json_coerces_list_of_dataclasses(): + encoded = json_codec.to_json([Address("a", "b"), Address("c", "d")]) + result = json_codec.from_json(encoded, list[Address]) + assert all(isinstance(item, Address) for item in result) + assert result[1] == Address("c", "d") + + +def test_from_json_uses_from_json_hook(): + encoded = json_codec.to_json(Widget("gear", 5)) + result = json_codec.from_json(encoded, Widget) + assert isinstance(result, Widget) + assert result == Widget("gear", 5) + + +def test_from_json_primitive_passthrough_with_expected_type(): + assert json_codec.from_json("42", int) == 42 + assert json_codec.from_json('"hi"', str) == "hi" + + +def test_from_json_legacy_marker_with_expected_type_strips_and_builds(): + # A payload persisted by an older SDK version (with the marker) must still + # decode when the caller now passes an expected_type. + legacy = json.dumps( + {"street": "1 Main St", "city": "Redmond", json_codec.AUTO_SERIALIZED: True} + ) + result = json_codec.from_json(legacy, Address) + assert isinstance(result, Address) + assert result == Address("1 Main St", "Redmond") + + +def test_from_json_none_payload_with_expected_type(): + assert json_codec.from_json("null", Address) is None + + +# ----- coerce_to_type ----- + + +def test_coerce_to_type_none_type_returns_value(): + value = {"a": 1} + assert json_codec.coerce_to_type(value, None) is value + + +def test_coerce_to_type_already_correct_type(): + addr = Address("a", "b") + assert json_codec.coerce_to_type(addr, Address) is addr + + +def test_coerce_to_type_invalid_conversion_raises(): + with pytest.raises(TypeError): + json_codec.coerce_to_type("not-a-number", int) + + +def test_coerce_optional_dataclass_coerces_member(): + from typing import Optional + result = json_codec.coerce_to_type({"street": "a", "city": "b"}, Optional[Address]) + assert isinstance(result, Address) + + +def test_coerce_genuine_union_leaves_unmatched_value_untouched(): + from 
typing import Union + + @dataclass + class A: + x: int + + @dataclass + class B: + y: int + + # A dict matching neither A nor B by isinstance must be returned unchanged, + # not force-coerced into the first union member. + value = {"z": 1} + assert json_codec.coerce_to_type(value, Union[A, B]) == {"z": 1} diff --git a/tests/durabletask/test_type_discovery.py b/tests/durabletask/test_type_discovery.py new file mode 100644 index 0000000..0f7a6ec --- /dev/null +++ b/tests/durabletask/test_type_discovery.py @@ -0,0 +1,253 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for annotation-based input type discovery and inbound coercion.""" + +import json +import logging +from dataclasses import dataclass +from typing import Any, Optional + +from durabletask import entities, task, worker +from durabletask.internal import type_discovery +from durabletask.internal.entity_state_shim import StateShim + +TEST_LOGGER = logging.getLogger("tests") + + +# ----- fixtures ----- + + +@dataclass +class Order: + item: str + quantity: int + + +class Money: + def __init__(self, amount: int): + self.amount = amount + + def to_json(self) -> dict[str, Any]: + return {"amount": self.amount} + + @classmethod + def from_json(cls, data: dict[str, Any]) -> "Money": + return cls(data["amount"]) + + +# ----- type_discovery helper ----- + + +class TestIsReconstructable: + def test_dataclass_is_reconstructable(self): + assert type_discovery.is_reconstructable(Order) is True + + def test_from_json_type_is_reconstructable(self): + assert type_discovery.is_reconstructable(Money) is True + + def test_builtins_are_not_reconstructable(self): + assert type_discovery.is_reconstructable(int) is False + assert type_discovery.is_reconstructable(str) is False + assert type_discovery.is_reconstructable(dict) is False + + def test_optional_dataclass_is_reconstructable(self): + assert type_discovery.is_reconstructable(Optional[Order]) is True + + def 
test_list_of_dataclass_is_reconstructable(self): + assert type_discovery.is_reconstructable(list[Order]) is True + + def test_list_of_builtin_is_not_reconstructable(self): + assert type_discovery.is_reconstructable(list[int]) is False + + +class TestInputTypeDiscovery: + def test_orchestrator_input_type_dataclass(self): + def orch(ctx, order: Order): + ... + assert type_discovery.orchestrator_input_type(orch) is Order + + def test_activity_input_type_dataclass(self): + def act(ctx, order: Order): + ... + assert type_discovery.activity_input_type(act) is Order + + def test_input_type_builtin_returns_none(self): + def act(ctx, value: int): + ... + assert type_discovery.activity_input_type(act) is None + + def test_input_type_unannotated_returns_none(self): + def act(ctx, value): + ... + assert type_discovery.activity_input_type(act) is None + + def test_input_type_no_input_param_returns_none(self): + def orch(ctx): + ... + assert type_discovery.orchestrator_input_type(orch) is None + + def test_postponed_annotation_resolves(self): + # Annotation provided as a string (PEP 563 style) still resolves because + # Order is importable in this module's globals. + def act(ctx, order: "Order"): + ... + assert type_discovery.activity_input_type(act) is Order + + def test_function_entity_input_type(self): + def counter(ctx, order: Order): + ... + assert type_discovery.entity_input_type(counter, "any_op") is Order + + def test_class_entity_input_type_per_operation(self): + class Store(entities.DurableEntity): + def add(self, order: Order): + ... + + def clear(self): + ... + + assert type_discovery.entity_input_type(Store, "add") is Order + # Operation with no input parameter. + assert type_discovery.entity_input_type(Store, "clear") is None + # Unknown operation. + assert type_discovery.entity_input_type(Store, "missing") is None + + +class TestActivityOutputTypeDiscovery: + def test_dataclass_return_annotation(self): + def act(ctx, _) -> Order: + ... 
+ assert type_discovery.activity_output_type(act) is Order + + def test_from_json_return_annotation(self): + def act(ctx, _) -> Money: + ... + assert type_discovery.activity_output_type(act) is Money + + def test_builtin_return_annotation_returns_none(self): + def act(ctx, _) -> int: + ... + assert type_discovery.activity_output_type(act) is None + + def test_unannotated_return_returns_none(self): + def act(ctx, _): + ... + assert type_discovery.activity_output_type(act) is None + + def test_optional_dataclass_return(self): + def act(ctx, _) -> Optional[Order]: + ... + assert type_discovery.activity_output_type(act) is Optional[Order] + + def test_postponed_return_annotation_resolves(self): + def act(ctx, _) -> "Order": + ... + assert type_discovery.activity_output_type(act) is Order + + def test_string_name_returns_none(self): + assert type_discovery.activity_output_type("some_activity_name") is None + + +# ----- activity executor inbound coercion ----- + + +def _activity_executor(fn) -> tuple[worker._ActivityExecutor, str]: + registry = worker._Registry() + name = registry.add_activity(fn) + return worker._ActivityExecutor(registry, TEST_LOGGER), name + + +def test_activity_input_coerced_to_dataclass(): + seen: dict[str, Any] = {} + + def handle(ctx: task.ActivityContext, order: Order): + seen["type"] = type(order).__name__ + seen["item"] = order.item + return order.quantity + + executor, name = _activity_executor(handle) + result = executor.execute("orch1", name, 1, json.dumps({"item": "widget", "quantity": 3})) + assert seen["type"] == "Order" + assert seen["item"] == "widget" + assert json.loads(result) == 3 + + +def test_activity_input_coerced_via_from_json(): + seen: dict[str, Any] = {} + + def handle(ctx: task.ActivityContext, money: Money): + seen["type"] = type(money).__name__ + seen["amount"] = money.amount + return money.amount + + executor, name = _activity_executor(handle) + result = executor.execute("orch1", name, 1, json.dumps({"amount": 50})) + 
assert seen["type"] == "Money" + assert seen["amount"] == 50 + assert json.loads(result) == 50 + + +def test_activity_builtin_input_unchanged(): + seen: dict[str, Any] = {} + + def handle(ctx: task.ActivityContext, value: int): + seen["type"] = type(value).__name__ + return value + + executor, name = _activity_executor(handle) + executor.execute("orch1", name, 1, json.dumps(7)) + assert seen["type"] == "int" + + +def test_activity_input_coercion_failure_falls_back_to_raw(): + seen: dict[str, Any] = {} + + def handle(ctx: task.ActivityContext, order: Order): + # Payload is missing the required 'quantity' field, so coercion fails and + # the raw dict is passed through instead of raising. + seen["type"] = type(order).__name__ + return "ok" + + executor, name = _activity_executor(handle) + result = executor.execute("orch1", name, 1, json.dumps({"item": "widget"})) + assert seen["type"] == "dict" + assert json.loads(result) == "ok" + + +# ----- entity executor inbound coercion ----- + + +def test_function_entity_input_coerced_to_dataclass(): + seen: dict[str, Any] = {} + + def store(ctx: entities.EntityContext, order: Order): + seen["type"] = type(order).__name__ + seen["item"] = order.item + + registry = worker._Registry() + registry.add_entity(store) + executor = worker._EntityExecutor(registry, TEST_LOGGER) + entity_id = entities.EntityInstanceId("store", "k1") + state = StateShim(None) + executor.execute("orch1", entity_id, "save", state, json.dumps({"item": "book", "quantity": 2})) + assert seen["type"] == "Order" + assert seen["item"] == "book" + + +def test_class_entity_input_coerced_per_operation(): + seen: dict[str, Any] = {} + + class Store(entities.DurableEntity): + def save(self, order: Order): + seen["type"] = type(order).__name__ + seen["item"] = order.item + + registry = worker._Registry() + registry.add_entity(Store) + executor = worker._EntityExecutor(registry, TEST_LOGGER) + entity_id = entities.EntityInstanceId("store", "k1") + state = StateShim(None) 
+ executor.execute("orch1", entity_id, "save", state, json.dumps({"item": "book", "quantity": 2})) + assert seen["type"] == "Order" + assert seen["item"] == "book"