diff --git a/CHANGELOG.md b/CHANGELOG.md
index a962b3a2..a85930d8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,54 +12,44 @@ ADDED
 - Added a pluggable `DataConverter` (`durabletask.serialization`) accepted by
   `TaskHubGrpcWorker`, `TaskHubGrpcClient`, and `AsyncTaskHubGrpcClient` via a
   `data_converter` argument. Every payload boundary (inputs, outputs, events,
-  custom status, entity state) routes through it. The default
+  custom status, entity state) routes through it, so one converter controls how
+  Python values become JSON and how they are reconstructed. The default
   `JsonDataConverter` preserves existing behavior, so a custom converter (for
-  example one backed by pydantic) is opt-in. Custom objects can opt in via a
-  `to_json()` hook and a `from_json(value)` classmethod.
-- `OrchestrationContext.call_activity`, `call_sub_orchestrator`, and
-  `call_entity` accept an optional `return_type`, and `wait_for_external_event`
-  accepts an optional `data_type`. When provided, the result/event payload is
-  reconstructed as that type (dataclasses — including nested dataclass,
-  `Optional`, and `list` fields — and `from_json()`-capable types) and the
-  returned task is typed accordingly (e.g. `call_activity(..., return_type=Foo)`
-  yields `CompletableTask[Foo]`). When omitted, the raw deserialized JSON is
-  returned as before.
-- Inbound payloads are reconstructed from function type annotations. When an
-  orchestrator, activity, or entity operation annotates its input parameter (or
-  an activity its return value) with a dataclass or `from_json()`-capable type,
-  the payload is reconstructed as that type. Builtins and unannotated/unknown
-  types are passed through unchanged. An explicit `return_type` takes precedence
-  over a discovered annotation.
-- Added typed accessors to `client.OrchestrationState`: `get_input()`,
-  `get_output()`, and `get_custom_status()` each accept an optional
-  `expected_type` and deserialize the corresponding payload, reconstructing
-  dataclasses and `from_json()`-capable types. The raw `serialized_*` fields are
-  retained.
-- Objects exposing a `to_json()` method are now JSON-serializable when passed as
-  activity/orchestrator inputs or outputs.
-- Added `EntityMetadata.get_typed_state(intended_type=...)`, which deserializes
-  the entity's persisted state and reconstructs dataclasses and
-  `from_json()`-capable types. The existing `get_state()` is unchanged: with no
-  argument it returns the raw serialized JSON payload, and `get_state(some_type)`
-  applies constructor-based coercion (`some_type(raw)`).
-- Entity runtime state retrieval (`EntityContext.get_state(intended_type=...)` /
-  `DurableEntity.get_state(...)`) now also reconstructs dataclasses and
-  `from_json()`-capable types, in addition to the existing constructor-based
-  coercion.
+  example one backed by pydantic) is fully opt-in.
+- Custom objects can participate in serialization by exposing a `to_json()`
+  method and a `from_json(value)` classmethod. Both are honored recursively, so
+  nested custom objects round-trip through their own hooks.
+- Payloads are reconstructed into a caller-supplied type — dataclasses
+  (including nested fields), `from_json()`-capable types, and `enum.Enum`
+  members, recursing through `list`, `dict`, `tuple`, and `Optional`/`Union`
+  hints. The type comes from a function's annotations, from an explicit
+  `return_type` on `call_activity` / `call_sub_orchestrator` / `call_entity`
+  (or `data_type` on `wait_for_external_event`), or from the typed accessors
+  `get_input()` / `get_output()` / `get_custom_status()` on
+  `client.OrchestrationState` and `EntityMetadata.get_typed_state(...)`. It is
+  never inferred from the payload.
Which annotated types are eligible is decided
+  by the converter via the overridable `DataConverter.can_reconstruct(...)`; a
+  custom converter can override it to recognize its own types (for example
+  `pydantic.BaseModel` subclasses).
 
 CHANGED
 
 - Custom objects (dataclasses, `SimpleNamespace`, namedtuples) are now
   serialized as plain JSON. Decoding such a payload *without* a type hint now
   yields a plain `dict` (previously a `SimpleNamespace`; a namedtuple now
-  round-trips as a JSON array). To get the original type back, pass the new
-  `return_type` / `data_type` arguments, annotate the consuming function's
-  parameter or return type, or use the typed client accessors. Payloads produced
-  by older SDK versions still deserialize — including into a `SimpleNamespace`
-  when no type is supplied — so in-flight orchestrations continue to replay
-  across an upgrade.
+  round-trips as a JSON array). To get the original type back, supply a type via
+  one of the mechanisms above. Payloads produced by older SDK versions still
+  deserialize — including into a `SimpleNamespace` when no type is supplied — so
+  in-flight orchestrations continue to replay across an upgrade.
 - JSON serialization failures now raise a `TypeError` that chains the original
   error (`__cause__`) and names the offending type.
+- `EntityContext.get_state()` / `DurableEntity.get_state()` now return a freshly
+  reconstructed value on every call rather than a reference to a single cached
+  object. This changes v1.6.0 behavior: mutating the returned value in place no
+  longer affects persisted state — write it back with `set_state()`. State is
+  also serialized eagerly at `set_state()` time, so a non-serializable value
+  fails inside the operation (which rolls back) instead of after the batch has
+  run.
 
 FIXED
 
@@ -68,19 +58,31 @@ FIXED
   "no state" and written as `None`, effectively deleting it; only an actual
   `None` state now clears the persisted entity state.
 
-BREAKING CHANGES (type-level only — no runtime impact for typical users)
+DEPRECATED
+
+- `durabletask.internal.shared.to_json` and `durabletask.internal.shared.from_json`
+  are deprecated and now emit a `DeprecationWarning`. Use a
+  `durabletask.serialization.DataConverter` (for example the default
+  `JsonDataConverter`) instead. The functions continue to work for backwards
+  compatibility.
+
+BREAKING CHANGES (no runtime impact for typical users)
 
-These changes do not alter runtime behavior, but because the package ships
-`py.typed`, consumers running strict type checkers (pyright/mypy) — or
-subclassing the public abstract types — may need to update their code:
+Most of these are type-level only: because the package ships `py.typed`,
+consumers running strict type checkers (pyright/mypy) — or subclassing the
+public abstract types — may need to update their code. The constructor change
+below also affects callers who *directly* construct the named classes, which is
+uncommon since they are normally handed to you by the SDK.
 
 - `OrchestrationContext.call_activity`, `call_sub_orchestrator`, `call_entity`,
   and `wait_for_external_event` gained new keyword-only parameters
   (`return_type` / `data_type`). Subclasses overriding these methods should add
   the parameter to match the base signature.
-- `client.OrchestrationState` gained a non-public `_data_converter` field
-  (excluded from equality and `repr`). Code constructing `OrchestrationState`
-  positionally should pass it via the new field or rely on its default.
+- `EntityContext` and `EntityMetadata` (and its `from_entity_metadata` /
+  `from_entity_response` factories) now require a `data_converter` argument.
+  These objects are normally constructed by the SDK — you receive an
+  `EntityContext` in an entity function and an `EntityMetadata` from the client —
+  so this only affects code that constructs them directly.
 
 ## v1.6.0
 
diff --git a/durabletask-azuremanaged/CHANGELOG.md b/durabletask-azuremanaged/CHANGELOG.md
index 005cb666..e9eb2b09 100644
--- a/durabletask-azuremanaged/CHANGELOG.md
+++ b/durabletask-azuremanaged/CHANGELOG.md
@@ -7,7 +7,12 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## Unreleased
 
-N/A
+ADDED
+
+- `DurableTaskSchedulerWorker`, `DurableTaskSchedulerClient`, and the async
+  client now accept a `data_converter` argument and forward it to the base
+  worker/client, so a custom `durabletask.serialization.DataConverter` (for
+  example a pydantic-backed one) can be used with the Durable Task Scheduler.
 
 ## v1.6.0
 
diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/client.py b/durabletask-azuremanaged/durabletask/azuremanaged/client.py
index 308c3411..fa3875f3 100644
--- a/durabletask-azuremanaged/durabletask/azuremanaged/client.py
+++ b/durabletask-azuremanaged/durabletask/azuremanaged/client.py
@@ -20,6 +20,7 @@
 )
 import durabletask.internal.shared as shared
 from durabletask.payload.store import PayloadStore
+from durabletask.serialization import DataConverter
 
 
 # Client class used for Durable Task Scheduler (DTS)
@@ -35,6 +36,7 @@ def __init__(self, *,
                  resiliency_options: GrpcClientResiliencyOptions | None = None,
                  default_version: str | None = None,
                  payload_store: PayloadStore | None = None,
+                 data_converter: DataConverter | None = None,
                  log_handler: logging.Handler | None = None,
                  log_formatter: logging.Formatter | None = None):
 
@@ -59,7 +61,8 @@ def __init__(self, *,
             channel_options=channel_options,
             resiliency_options=resiliency_options,
             default_version=default_version,
-            payload_store=payload_store)
+            payload_store=payload_store,
+            data_converter=data_converter)
 
 
 # Async client class used for Durable Task Scheduler (DTS)
@@ -113,6 +116,7 @@ def __init__(self, *,
                  resiliency_options: GrpcClientResiliencyOptions | None = None,
                  default_version: str | None = None,
                  payload_store: PayloadStore | None = None,
+                 data_converter:
DataConverter | None = None,
                  log_handler: logging.Handler | None = None,
                  log_formatter: logging.Formatter | None = None):
 
@@ -137,4 +141,5 @@ def __init__(self, *,
             channel_options=channel_options,
             resiliency_options=resiliency_options,
             default_version=default_version,
-            payload_store=payload_store)
+            payload_store=payload_store,
+            data_converter=data_converter)
diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py
index 8acbad51..c27ebc1f 100644
--- a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py
+++ b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py
@@ -18,6 +18,7 @@
 )
 import durabletask.internal.shared as shared
 from durabletask.payload.store import PayloadStore
+from durabletask.serialization import DataConverter
 from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker
 
 
@@ -81,6 +82,7 @@ def __init__(self, *,
                  resiliency_options: GrpcWorkerResiliencyOptions | None = None,
                  concurrency_options: ConcurrencyOptions | None = None,
                  payload_store: PayloadStore | None = None,
+                 data_converter: DataConverter | None = None,
                  log_handler: logging.Handler | None = None,
                  log_formatter: logging.Formatter | None = None):
 
@@ -110,5 +112,6 @@ def __init__(self, *,
             concurrency_options=concurrency_options,
             # DTS natively supports long timers so chunking is unnecessary
             maximum_timer_interval=None,
-            payload_store=payload_store
+            payload_store=payload_store,
+            data_converter=data_converter
         )
diff --git a/durabletask/entities/entity_context.py b/durabletask/entities/entity_context.py
index c5435a20..7a44c22e 100644
--- a/durabletask/entities/entity_context.py
+++ b/durabletask/entities/entity_context.py
@@ -16,14 +16,11 @@
 
 class EntityContext:
     def __init__(self, orchestration_id: str, operation: str, state: StateShim,
-                 entity_id: EntityInstanceId, data_converter: "DataConverter | None" = None):
+                 entity_id: EntityInstanceId, data_converter: "DataConverter"):
         self._orchestration_id = orchestration_id
         self._operation = operation
         self._state = state
         self._entity_id = entity_id
-        if data_converter is None:
-            from durabletask.serialization import JsonDataConverter
-            data_converter = JsonDataConverter()
         self._data_converter = data_converter
 
     @property
diff --git a/durabletask/entities/entity_metadata.py b/durabletask/entities/entity_metadata.py
index 37c437e7..91e329eb 100644
--- a/durabletask/entities/entity_metadata.py
+++ b/durabletask/entities/entity_metadata.py
@@ -36,7 +36,7 @@ def __init__(self,
                  locked_by: str,
                  includes_state: bool,
                  state: Any | None,
-                 data_converter: "DataConverter | None" = None):
+                 data_converter: "DataConverter"):
         """Initializes a new instance of the EntityMetadata class.
 
         Args:
@@ -48,20 +48,17 @@ def __init__(self,
         self._locked_by = locked_by
         self.includes_state = includes_state
         self._state = state
-        if data_converter is None:
-            from durabletask.serialization import JsonDataConverter
-            data_converter = JsonDataConverter()
         self._data_converter = data_converter
 
     @staticmethod
     def from_entity_response(entity_response: pb.GetEntityResponse, includes_state: bool,
-                             data_converter: "DataConverter | None" = None):
+                             data_converter: "DataConverter"):
         return EntityMetadata.from_entity_metadata(
             entity_response.entity, includes_state, data_converter)
 
     @staticmethod
     def from_entity_metadata(entity: pb.EntityMetadata, includes_state: bool,
-                             data_converter: "DataConverter | None" = None):
+                             data_converter: "DataConverter"):
         try:
             entity_id = EntityInstanceId.parse(entity.instanceId)
         except ValueError:
diff --git a/durabletask/internal/entity_state_shim.py b/durabletask/internal/entity_state_shim.py
index 99f2801e..b18d3d63 100644
--- a/durabletask/internal/entity_state_shim.py
+++ b/durabletask/internal/entity_state_shim.py
@@ -12,15 +12,49 @@
 
 
 class StateShim:
-    def __init__(self, start_state: Any, data_converter: "DataConverter | None" = None):
-        self._current_state: Any = start_state
-        self._checkpoint_state:
Any = start_state
+    """In-memory view of an entity's state during a batch.
+
+    The state is held internally as its serialized JSON string at all times.
+    The raw payload off the wire is stored verbatim; a live value supplied via
+    :meth:`set_state` (or as a non-serialized constructor argument) is
+    serialized immediately. Keeping a single, always-serialized representation
+    has two consequences worth noting:
+
+    * Deserialization is deferred to :meth:`get_state`, so the caller's
+      requested type reaches the data converter together with the original
+      payload (a custom converter can deserialize the string directly into the
+      target type), and the unmodified wire payload is handed back by
+      :meth:`encode_state` without being re-encoded.
+    * Serialization errors surface inside the failing operation (at
+      :meth:`set_state`) rather than after the batch has run, so a bad write
+      rolls back just that operation.
+
+    Because the held value is always the serialized form, :meth:`get_state`
+    returns a freshly reconstructed object on every call; it does **not** return
+    a reference to a stored live object. Mutating a value read from
+    :meth:`get_state` therefore has no effect on the persisted state unless it
+    is written back with :meth:`set_state`.
+    """
+
+    def __init__(self, start_state: Any, data_converter: "DataConverter",
+                 *, is_serialized: bool = False):
+        self._data_converter = data_converter
+        # The state is normalized to its serialized string form. ``is_serialized``
+        # marks ``start_state`` as a raw payload already off the wire (stored
+        # verbatim); otherwise a live value is serialized now. ``None`` stays
+        # ``None`` (no persisted state).
+        serialized_start = self._serialize(start_state, is_serialized)
+        self._current_state: str | None = serialized_start
+        self._checkpoint_state: str | None = serialized_start
         self._operation_actions: list[pb.OperationAction] = []
         self._actions_checkpoint_state: int = 0
-        if data_converter is None:
-            from durabletask.serialization import JsonDataConverter
-            data_converter = JsonDataConverter()
-        self._data_converter = data_converter
+
+    def _serialize(self, state: Any, is_serialized: bool = False) -> str | None:
+        if state is None:
+            return None
+        if is_serialized:
+            return state
+        return self._data_converter.serialize(state)
 
     @overload
     def get_state(self, intended_type: type[TState], default: TState) -> TState:
@@ -35,13 +69,14 @@ def get_state(self, intended_type: None = None, default: Any = None) -> Any:
         ...
 
     def get_state(self, intended_type: type[TState] | None = None, default: TState | None = None) -> TState | Any | None:
-        if self._current_state is None and default is not None:
+        if self._current_state is None:
             return default
 
+        # Deferred deserialization: the converter receives the raw payload
+        # together with the requested type.
         if intended_type is None:
-            return self._current_state
-
-        coerced = self._data_converter.coerce(self._current_state, intended_type)
+            return self._data_converter.deserialize(self._current_state)
+        result = self._data_converter.deserialize(self._current_state, intended_type)
 
         # An explicit ``intended_type`` is a request to receive that type. The
         # default converter is best-effort and would silently return the raw
@@ -49,17 +84,28 @@ def get_state(self, intended_type: type[TState] | None = None, default: TState |
         # raising when a non-None state could not be coerced to a concrete type.
         # ``intended_type`` may be a typing generic (e.g. ``list[int]``) at
         # runtime, which is not a ``type`` instance, so the guard is required.
-        if (self._current_state is not None
-                and isinstance(intended_type, type)  # pyright: ignore[reportUnnecessaryIsInstance]
-                and not isinstance(coerced, intended_type)):
+        if (isinstance(intended_type, type)  # pyright: ignore[reportUnnecessaryIsInstance]
+                and not isinstance(result, intended_type)):
             raise TypeError(
-                f"Could not convert state of type '{type(self._current_state).__name__}' to '{intended_type.__name__}'"
+                f"Could not convert state of type '{type(result).__name__}' to '{intended_type.__name__}'"
             )
 
-        return coerced
+        return result
 
     def set_state(self, state: Any) -> None:
-        self._current_state = state
+        # Serialize eagerly so the held value is always the wire form and any
+        # serialization error surfaces here, inside the failing operation.
+        self._current_state = self._serialize(state)
+
+    def encode_state(self) -> str | None:
+        """Return the serialized current state for persistence back to the wire.
+
+        The state is already held in serialized form, so this is the stored
+        value verbatim: ``None`` when there is no state (which clears the
+        persisted entity state), otherwise the JSON string. No re-encoding
+        occurs, so a payload that was never modified round-trips unchanged.
+        """
+        return self._current_state
 
     def add_operation_action(self, action: pb.OperationAction) -> None:
         self._operation_actions.append(action)
diff --git a/durabletask/internal/json_codec.py b/durabletask/internal/json_codec.py
deleted file mode 100644
index 8fda0eae..00000000
--- a/durabletask/internal/json_codec.py
+++ /dev/null
@@ -1,183 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-"""Internal JSON codec for Durable Task payloads.
-
-This module holds the low-level serialization *mechanism* -- the JSON string
-encode/decode primitives and the value-level type coercion used to reconstruct
-custom objects.
Serialization *policy* (the public, pluggable strategy) lives in
-:mod:`durabletask.serialization`; the default ``JsonDataConverter`` is the only
-production consumer of ``to_json`` / ``from_json``, while ``coerce_to_type`` is
-also used directly by entity state accessors that already hold a parsed value.
-"""
-
-from __future__ import annotations
-
-import dataclasses
-import json
-import types
-import typing
-from collections.abc import Sequence
-from types import SimpleNamespace
-from typing import Any, cast
-
-# Marker formerly added to JSON payloads to flag objects for automatic
-# deserialization into a SimpleNamespace. New code no longer emits this marker
-# (objects are serialized as plain JSON), but the decoder still recognizes it so
-# that orchestration histories produced by older SDK versions continue to replay.
-AUTO_SERIALIZED = "__durabletask_autoobject__"
-
-
-def to_json(obj: Any) -> str:
-    """Serialize a value to a JSON string.
-
-    Builtins serialize to plain JSON. Dataclasses, ``SimpleNamespace``
-    instances, and objects exposing a ``to_json()`` method are serialized to
-    plain JSON as well (without any type marker); custom objects can be
-    reconstructed on the receiving side by passing ``expected_type`` to
-    :func:`from_json`.
-    """
-    try:
-        return json.dumps(obj, default=_encode_custom_object)
-    except TypeError as e:
-        # Preserve the original error as the cause so serialization failures are
-        # easier to diagnose, while naming the offending top-level type.
-        raise TypeError(
-            f"Failed to serialize object of type '{type(obj).__name__}' to JSON: {e}"
-        ) from e
-
-
-def from_json(json_str: str | bytes | bytearray, expected_type: type | None = None) -> Any:
-    """Deserialize a JSON string, optionally coercing the result to a type.
-
-    When ``expected_type`` is ``None`` (the default) the raw parsed JSON is
-    returned.
For backwards compatibility, payloads carrying the legacy
-    :data:`AUTO_SERIALIZED` marker are reconstructed as ``SimpleNamespace``
-    instances so that in-flight orchestrations produced by older SDK versions
-    continue to replay.
-
-    When ``expected_type`` is provided, the legacy marker (if present) is
-    stripped and the parsed value is coerced to ``expected_type`` -- dataclasses
-    are constructed from their dict payloads, types exposing a ``from_json()``
-    classmethod are reconstructed via that hook, and ``Optional``/``Union`` and
-    ``list`` type hints are honored recursively. The destination type is always
-    supplied by the caller; it is never read from the payload.
-    """
-    if expected_type is None:
-        return json.loads(json_str, object_hook=_legacy_object_hook)
-    raw = json.loads(json_str, object_hook=_strip_legacy_marker)
-    return coerce_to_type(raw, expected_type)
-
-
-def _encode_custom_object(o: Any) -> Any:
-    """``default`` hook for :func:`json.dumps` that emits plain JSON.
-
-    Called only for values the JSON encoder cannot natively serialize. Note that
-    namedtuples are handled natively by the encoder (serialized as JSON arrays)
-    and never reach this hook.
-    """
-    if dataclasses.is_dataclass(o) and not isinstance(o, type):
-        return dataclasses.asdict(o)
-    if isinstance(o, SimpleNamespace):
-        return vars(o)
-    # Custom objects may opt in via a ``to_json`` hook. It is resolved off the
-    # type and called with the instance (``type(o).to_json(o)``) so that both
-    # instance methods and ``@staticmethod`` hooks work -- matching the calling
-    # convention used by ``azure-functions-durable``. The hook returns a
-    # JSON-serializable value (a structure or a string), not a JSON document.
-    to_json_hook = getattr(cast(Any, type(o)), "to_json", None)
-    if callable(to_json_hook):
-        return to_json_hook(o)
-    # This will raise a TypeError describing the unsupported type.
-    raise TypeError(f"Object of type '{type(o).__name__}' is not JSON serializable")
-
-
-def _legacy_object_hook(d: dict[str, Any]) -> Any:
-    # If the object carries the legacy marker, deserialize it as a SimpleNamespace.
-    if d.pop(AUTO_SERIALIZED, False):
-        return SimpleNamespace(**d)
-    return d
-
-
-def _strip_legacy_marker(d: dict[str, Any]) -> dict[str, Any]:
-    # Discard the legacy marker so typed coercion sees a plain dict.
-    d.pop(AUTO_SERIALIZED, None)
-    return d
-
-
-def coerce_to_type(value: Any, expected_type: Any) -> Any:
-    """Coerce an already-parsed JSON value to ``expected_type``.
-
-    Handles ``None``/``Optional``/``Union`` and ``list`` type hints recursively,
-    types exposing a ``from_json()`` classmethod, and dataclasses (including
-    nested dataclass fields). The destination type is always caller-supplied and
-    never derived from the payload, keeping deserialization secure.
-    """
-    if expected_type is None or value is None:
-        return value
-
-    origin = typing.get_origin(expected_type)
-    if origin is not None:
-        return _coerce_generic(value, expected_type, origin)
-
-    if not isinstance(expected_type, type):
-        # Not a concrete, instantiable type (e.g. a typing special form we don't
-        # special-case) -- return the value unchanged.
-        return value
-
-    if isinstance(value, expected_type):
-        return value
-
-    from_json_hook = getattr(expected_type, "from_json", None)
-    if callable(from_json_hook):
-        return from_json_hook(value)
-
-    if dataclasses.is_dataclass(expected_type) and isinstance(value, dict):
-        return _build_dataclass(expected_type, cast(dict[str, Any], value))
-
-    type_ctor = cast(Any, expected_type)
-    try:
-        return type_ctor(value)
-    except Exception as e:
-        type_name = getattr(type_ctor, "__name__", None) or str(type_ctor)
-        raise TypeError(
-            f"Could not coerce value of type '{type(value).__name__}' to "
-            f"'{type_name}'"
-        ) from e
-
-
-def _coerce_generic(value: Any, expected_type: Any, origin: Any) -> Any:
-    args = typing.get_args(expected_type)
-    if origin is typing.Union or origin is types.UnionType:
-        # If the value already matches a member type, keep it as-is.
-        non_none = [a for a in args if a is not type(None)]
-        for arg in non_none:
-            if isinstance(arg, type) and isinstance(value, arg):
-                return value
-        # ``Optional[T]`` (exactly one non-None member): coerce to that member.
-        # For a genuine multi-member ``Union`` where the value matched none of
-        # the members, leave it untouched rather than guessing the first arg --
-        # forcing a coercion there can silently mis-construct the wrong type.
-        if len(non_none) == 1:
-            return coerce_to_type(value, non_none[0])
-        return value
-    if origin in (list, Sequence) and isinstance(value, list):
-        elem_type = args[0] if args else None
-        return [coerce_to_type(item, elem_type) for item in cast(list[Any], value)]
-    # Other generics (dict, tuple, ...) are returned as parsed JSON.
-    return value
-
-
-def _build_dataclass(cls: Any, data: dict[str, Any]) -> Any:
-    """Construct a dataclass from its dict payload, recursing into typed fields."""
-    try:
-        hints = typing.get_type_hints(cls)
-    except Exception:
-        hints = {}
-    kwargs: dict[str, Any] = {}
-    for field in dataclasses.fields(cls):
-        if field.name not in data:
-            continue
-        field_type = hints.get(field.name)
-        kwargs[field.name] = coerce_to_type(data[field.name], field_type)
-    return cls(**kwargs)
diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py
index 9ef136d8..3708e498 100644
--- a/durabletask/internal/shared.py
+++ b/durabletask/internal/shared.py
@@ -2,22 +2,58 @@
 # Licensed under the MIT License.
 
 import logging
+import warnings
 from collections.abc import Sequence
-from typing import TypeAlias
+from typing import Any, TypeAlias
 
 import grpc
 import grpc.aio
 
-# Backwards-compatibility re-exports. The JSON codec moved to
-# ``durabletask.internal.json_codec``; these aliases keep older imports from
-# ``durabletask.internal.shared`` working.
-from durabletask.internal.json_codec import (  # noqa: F401
-    AUTO_SERIALIZED as AUTO_SERIALIZED,
-    from_json as from_json,
-    to_json as to_json,
-)
+# Backwards-compatibility shims. The JSON codec moved into
+# ``durabletask.serialization`` and its functions are now private; the supported
+# surface is the pluggable ``DataConverter`` (and the default
+# ``JsonDataConverter``). These thin wrappers keep older imports from
+# ``durabletask.internal.shared`` working while steering callers to the new API.
+# They deliberately reach into the now-private serialization mechanism.
+from durabletask import serialization as _serialization
 from durabletask.grpc_options import GrpcChannelOptions
 
+# Legacy marker constant, re-exported for backwards compatibility.
+AUTO_SERIALIZED = _serialization._AUTO_SERIALIZED  # pyright: ignore[reportPrivateUsage]
+
+_SERIALIZATION_DEPRECATION = (
+    "durabletask.internal.shared.{name} is deprecated and will be removed in a "
+    "future release. Use a durabletask.serialization.DataConverter (e.g. the "
+    "default JsonDataConverter) instead."
+)
+
+
+def to_json(obj: Any) -> str:
+    """Deprecated. Use a ``durabletask.serialization.DataConverter`` instead."""
+    warnings.warn(
+        _SERIALIZATION_DEPRECATION.format(name="to_json"),
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    return _serialization._to_json(obj)  # pyright: ignore[reportPrivateUsage]
+
+
+def from_json(json_str: str | bytes | bytearray, expected_type: type | None = None) -> Any:
+    """Deprecated. Use a ``durabletask.serialization.DataConverter`` instead.
+
+    This legacy shim does not thread a ``DataConverter`` into reconstruction, so
+    a converter-aware ``from_json(cls, value, converter)`` hook is invoked
+    without the converter (its single-argument form). Call
+    ``JsonDataConverter().deserialize(...)`` to get the converter-aware path.
+    """
+    warnings.warn(
+        _SERIALIZATION_DEPRECATION.format(name="from_json"),
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    return _serialization._from_json(json_str, expected_type)  # pyright: ignore[reportPrivateUsage]
+
+
 ClientInterceptor: TypeAlias = (
     grpc.UnaryUnaryClientInterceptor
     | grpc.UnaryStreamClientInterceptor
diff --git a/durabletask/internal/type_discovery.py b/durabletask/internal/type_discovery.py
index 58fd6f83..942ac8b1 100644
--- a/durabletask/internal/type_discovery.py
+++ b/durabletask/internal/type_discovery.py
@@ -5,15 +5,18 @@
 
 These helpers resolve the annotation of the *input* parameter of an
 orchestrator, activity, or entity function so that inbound payloads can be
-reconstructed into the annotated custom type (a dataclass or a type exposing a
-``from_json()`` classmethod) without the caller having to pass an explicit type.
+reconstructed into the annotated custom type without the caller having to pass
+an explicit type.
 
 Discovery is intentionally conservative: it only returns an annotation when the
-target is a *reconstructable* custom type (a dataclass, a ``from_json()``-capable
-type, or an ``Optional`` / ``list`` wrapping one). Primitive and unknown
-annotations resolve to ``None`` so that existing payloads are passed through
-unchanged -- inbound type discovery never invokes an arbitrary constructor on
-untrusted data, and never alters the value for builtins.
+active :class:`~durabletask.serialization.DataConverter` reports it as
+*reconstructable* via :meth:`DataConverter.can_reconstruct`. The default
+converter recognizes dataclasses, ``from_json()``-capable types, and ``Optional``
+/ ``list`` hints wrapping them; a custom converter can recognize its own types
+(e.g. ``pydantic.BaseModel``). Primitive and unknown annotations resolve to
+``None`` so that existing payloads are passed through unchanged -- inbound type
+discovery never invokes an arbitrary constructor on untrusted data, and never
+alters the value for builtins.
 
 All public helpers swallow exceptions and return ``None`` on failure; the caller
 treats ``None`` as "no type information available" and uses the raw payload.
@@ -21,38 +24,17 @@
 
 from __future__ import annotations
 
-import collections.abc
-import dataclasses
 import functools
 import inspect
-import types
 import typing
-from typing import Any, Callable, cast
+from typing import Any, Callable
 
+from durabletask.serialization import DEFAULT_DATA_CONVERTER, DataConverter
 
-def is_reconstructable(annotation: Any) -> bool:
-    """Return True if ``annotation`` names a custom type we can rebuild.
 
-    Reconstructable targets are dataclasses, types exposing a callable
-    ``from_json``, and ``Optional`` / ``list`` hints wrapping such types.
-    Builtins (``int``, ``str``, ``dict``, ...) and unknown annotations are not
-    reconstructable and resolve to ``False``.
-    """
-    origin = typing.get_origin(annotation)
-    if origin is not None:
-        args = typing.get_args(annotation)
-        if origin is typing.Union or origin is types.UnionType:
-            return any(
-                is_reconstructable(a) for a in args if a is not type(None)
-            )
-        if origin in (list, collections.abc.Sequence):
-            return any(is_reconstructable(a) for a in args)
-        return False
-    if not isinstance(annotation, type):
-        return False
-    if dataclasses.is_dataclass(annotation):
-        return True
-    return callable(getattr(cast(Any, annotation), "from_json", None))
+def _resolve_converter(converter: DataConverter | None) -> DataConverter:
+    """Return the supplied converter, or the shared default when ``None``."""
+    return converter if converter is not None else DEFAULT_DATA_CONVERTER
 
 
 # Bounded so a worker that registers dynamically-created functions or closures
@@ -72,14 +54,15 @@ def _resolved_hints(fn: Callable[..., Any]) -> dict[str, Any] | None:
         return None
 
 
-def _input_annotation(fn: Callable[..., Any], position: int) -> Any | None:
+def _input_annotation(fn: Callable[..., Any], position: int,
+                      converter: DataConverter | None = None) -> Any | None:
     """Return the resolved annotation of the positional parameter at ``position``.
 
     ``position`` is the zero-based index among positional parameters (so the
     ``input`` parameter of a ``(ctx, input)`` function is at position 1, and the
     ``input`` parameter of an unbound ``(self, input)`` entity method is also at
     position 1). Returns ``None`` when the parameter is absent, unannotated, or
-    its annotation is not a reconstructable custom type.
+    its annotation is not reconstructable by ``converter``.
""" try: sig = inspect.signature(fn) @@ -105,27 +88,29 @@ def _input_annotation(fn: Callable[..., Any], position: int) -> Any | None: if annotation is inspect.Parameter.empty or annotation is Any: return None - return annotation if is_reconstructable(annotation) else None + return annotation if _resolve_converter(converter).can_reconstruct(annotation) else None -def orchestrator_input_type(fn: Callable[..., Any]) -> Any | None: +def orchestrator_input_type(fn: Callable[..., Any], + converter: DataConverter | None = None) -> Any | None: """Discover the input type of an orchestrator function ``(ctx, input)``.""" - return _input_annotation(fn, 1) + return _input_annotation(fn, 1, converter) -def activity_input_type(fn: Callable[..., Any]) -> Any | None: +def activity_input_type(fn: Callable[..., Any], + converter: DataConverter | None = None) -> Any | None: """Discover the input type of an activity function ``(ctx, input)``.""" - return _input_annotation(fn, 1) + return _input_annotation(fn, 1, converter) -def activity_output_type(fn: Any) -> Any | None: +def activity_output_type(fn: Any, converter: DataConverter | None = None) -> Any | None: """Discover the return type of an activity function. - Returns the resolved return annotation when it names a reconstructable - custom type (a dataclass or a ``from_json()``-capable type, optionally - wrapped in ``Optional`` / ``list``). Returns ``None`` for plain callables - that are not annotated with such a type, for string activity names, or when - the annotation cannot be resolved. + Returns the resolved return annotation when ``converter`` reports it as + reconstructable (the default converter recognizes a dataclass or a + ``from_json()``-capable type, optionally wrapped in ``Optional`` / ``list``). + Returns ``None`` for plain callables that are not annotated with such a type, + for string activity names, or when the annotation cannot be resolved. 
""" if not callable(fn): return None @@ -144,10 +129,11 @@ def activity_output_type(fn: Any) -> Any | None: if annotation is inspect.Signature.empty or annotation is Any or annotation is None: return None - return annotation if is_reconstructable(annotation) else None + return annotation if _resolve_converter(converter).can_reconstruct(annotation) else None -def entity_input_type(fn: Any, operation: str) -> Any | None: +def entity_input_type(fn: Any, operation: str, + converter: DataConverter | None = None) -> Any | None: """Discover the input type of an entity operation. For class-based entities (a ``DurableEntity`` subclass) the operation is a @@ -160,5 +146,5 @@ def entity_input_type(fn: Any, operation: str) -> Any | None: if method is None or not callable(method): return None # Unbound method includes ``self`` at position 0, so ``input`` is at 1. - return _input_annotation(method, 1) - return _input_annotation(fn, 1) + return _input_annotation(method, 1, converter) + return _input_annotation(fn, 1, converter) diff --git a/durabletask/serialization.py b/durabletask/serialization.py index b1b469b4..5d1dad89 100644 --- a/durabletask/serialization.py +++ b/durabletask/serialization.py @@ -10,10 +10,10 @@ values become JSON on the wire and how they are reconstructed on the way back. The default :class:`JsonDataConverter` preserves the SDK's built-in behavior: -builtins serialize as plain JSON, dataclasses / ``SimpleNamespace`` instances -and objects exposing a ``to_json()`` hook serialize to plain JSON structures, -and a caller-supplied ``target_type`` drives reconstruction on the read side -(the destination type is never read from the payload). +builtins serialize as plain JSON; objects exposing a ``to_json()`` hook, then +dataclasses, then ``SimpleNamespace`` instances serialize to plain JSON +structures; and a caller-supplied ``target_type`` drives reconstruction on the +read side (the destination type is never read from the payload). 
To customize serialization -- for example to validate with pydantic, encode custom ``datetime`` / ``Decimal`` formats, or integrate another model framework @@ -22,18 +22,40 @@ converter = MyDataConverter() worker = TaskHubGrpcWorker(data_converter=converter) client = TaskHubGrpcClient(data_converter=converter) + +This module is the single home for both serialization *policy* (the public, +pluggable :class:`DataConverter` strategy) and the low-level JSON codec +*mechanism* (the private ``_to_json`` / ``_from_json`` / ``_coerce_to_type`` +helpers). The mechanism is intentionally private: the supported, stable surface +is the :class:`DataConverter` abstraction. """ from __future__ import annotations +import dataclasses +import enum +import functools +import inspect +import json import logging +import sys +import types +import typing from abc import ABC, abstractmethod -from typing import Any - -from durabletask.internal import json_codec +from collections.abc import Mapping, Sequence +from types import SimpleNamespace +from typing import Any, cast logger = logging.getLogger("durabletask") +# Marker formerly added to JSON payloads to flag objects for automatic +# deserialization into a SimpleNamespace. New code no longer emits this marker +# (objects are serialized as plain JSON), but the decoder still recognizes it so +# that orchestration histories produced by older SDK versions continue to +# replay. Internal detail; re-exported from ``durabletask.internal.shared`` only +# for backwards compatibility. +_AUTO_SERIALIZED = "__durabletask_autoobject__" + class DataConverter(ABC): """Strategy for serializing and deserializing Durable Task payloads. @@ -79,6 +101,27 @@ def coerce(self, value: Any, target_type: type | None = None) -> Any: """ ... + def can_reconstruct(self, target_type: Any) -> bool: + """Return True if this converter can rebuild ``target_type`` from a payload. 
+ + Inbound type-discovery calls this to decide whether a function's + annotated *input* type (or an activity's *return* annotation) should be + passed to :meth:`deserialize` / :meth:`coerce`. When it returns ``False`` + the SDK passes the raw deserialized payload through unchanged -- this + gate is what stops the SDK from invoking reconstruction on a type the + converter does not actually handle. + + The base implementation is conservative and returns ``False``: a + converter makes no reconstruction claims unless it opts in. + :class:`JsonDataConverter` overrides this to recognize the types its + codec can rebuild (dataclasses and ``from_json()``-capable types, plus + ``Optional`` / ``list`` / ``Sequence`` hints wrapping them). Override + this in a custom converter to teach the SDK about its own types (for + example ``pydantic.BaseModel`` subclasses) so that inputs annotated with + them are reconstructed instead of arriving as raw JSON. + """ + return False + class JsonDataConverter(DataConverter): """Default :class:`DataConverter` backed by the SDK's JSON codec. @@ -89,40 +132,61 @@ class JsonDataConverter(DataConverter): classmethod used during type-directed reconstruction. This matches the ``to_json`` / ``from_json`` convention used by ``azure-functions-durable``. + The ``to_json`` hook takes precedence over the built-in dataclass / + ``SimpleNamespace`` handling, and nested values are encoded recursively, so + a dataclass (or any object) with a ``to_json`` hook -- including one nested + inside another dataclass, ``list``, or ``dict`` -- round-trips through its + own hooks. + Deserialization (and value-level :meth:`coerce`) is **best-effort**: when a ``target_type`` is supplied and the value cannot be coerced to it, the raw value is returned (and a debug message is logged) rather than raising. This keeps the core SDK permissive; a stricter, validating converter can be supplied for callers who want coercion failures to surface as errors. 
+ + > [!NOTE] + > Type-directed reconstruction recurses through dataclass fields, + > ``list``/``Sequence``, ``dict``/``Mapping`` values, ``tuple`` elements, + > and ``Optional`` / ``Union`` hints. A type that exposes a custom + > ``from_json()`` classmethod is responsible for reconstructing its own + > nested values. To help, the converter passes itself to ``from_json`` when + > the hook declares a second parameter -- ``from_json(cls, value, converter)`` + > -- so the hook can call ``converter.coerce(nested_value, NestedType)`` (or + > ``converter.deserialize(...)``) to reconstruct nested values the built-in + > recursion does not cover. The SDK never infers nested types from the + > payload; the destination type is always supplied by the caller. """ def serialize(self, value: Any) -> str | None: if value is None: return None - return json_codec.to_json(value) + return _to_json(value) def deserialize(self, data: str | None, target_type: type | None = None) -> Any: if data is None or data == "": return None if target_type is None: - return json_codec.from_json(data) + return _from_json(data) try: - return json_codec.from_json(data, target_type) + return _from_json(data, target_type, converter=self) except Exception as e: # Best-effort: fall back to the raw deserialized value rather than # failing the operation. Logged so the mismatch remains discoverable. 
self._log_coercion_fallback(target_type, e) - return json_codec.from_json(data) + return _from_json(data) def coerce(self, value: Any, target_type: type | None = None) -> Any: if target_type is None or value is None: return value try: - return json_codec.coerce_to_type(value, target_type) + return _coerce_to_type(value, target_type, converter=self) except Exception as e: self._log_coercion_fallback(target_type, e) return value + def can_reconstruct(self, target_type: Any) -> bool: + return _can_reconstruct(self, target_type) + @staticmethod def _log_coercion_fallback(target_type: type, error: Exception) -> None: logger.debug( @@ -134,3 +198,362 @@ def _log_coercion_fallback(target_type: type, error: Exception) -> None: # Shared default instance used when no converter is supplied. DEFAULT_DATA_CONVERTER: DataConverter = JsonDataConverter() + + +# --------------------------------------------------------------------------- +# Private JSON codec mechanism. +# +# These are the low-level encode/decode primitives and the value-level type +# coercion used to reconstruct custom objects. They are deliberately private: +# the supported, pluggable surface is :class:`DataConverter`. ``_coerce_to_type`` +# is also used internally by entity state accessors that already hold a parsed +# value. +# --------------------------------------------------------------------------- + + +def _can_reconstruct(converter: DataConverter, target_type: Any) -> bool: + """:class:`JsonDataConverter`'s reconstruction policy. + + Recognizes dataclasses and ``from_json()``-capable types, plus ``Optional`` + / ``list`` / ``Sequence`` hints wrapping them; builtins and unknown + annotations are excluded. Recurses through ``converter.can_reconstruct`` + (not itself) so a :class:`JsonDataConverter` subclass that overrides + ``can_reconstruct`` still participates in the element-type checks of + ``Optional`` / ``list`` hints. 
+ """ + origin = typing.get_origin(target_type) + if origin is not None: + args = typing.get_args(target_type) + if origin is typing.Union or origin is types.UnionType: + return any( + converter.can_reconstruct(a) for a in args if a is not type(None) + ) + if origin in (list, Sequence): + return any(converter.can_reconstruct(a) for a in args) + return False + if not isinstance(target_type, type): + return False + if dataclasses.is_dataclass(target_type): + return True + return callable(getattr(cast(Any, target_type), "from_json", None)) + + +def _to_json(obj: Any) -> str: + """Serialize a value to a JSON string. + + Builtins serialize to plain JSON. Objects exposing a ``to_json()`` method, + dataclasses, and ``SimpleNamespace`` instances are serialized to plain JSON + as well (without any type marker), recursively. Custom objects can be + reconstructed on the receiving side by passing ``expected_type`` to + :func:`_from_json`. + """ + try: + return json.dumps(obj, default=_encode_custom_object) + except TypeError as e: + # Preserve the original error as the cause so serialization failures are + # easier to diagnose, while naming the offending top-level type. + raise TypeError( + f"Failed to serialize object of type '{type(obj).__name__}' to JSON: {e}" + ) from e + + +def _from_json( + json_str: str | bytes | bytearray, + expected_type: type | None = None, + converter: DataConverter | None = None, +) -> Any: + """Deserialize a JSON string, optionally coercing the result to a type. + + When ``expected_type`` is ``None`` (the default) the raw parsed JSON is + returned. For backwards compatibility, payloads carrying the legacy + :data:`_AUTO_SERIALIZED` marker are reconstructed as ``SimpleNamespace`` + instances so that in-flight orchestrations produced by older SDK versions + continue to replay. 
+ + When ``expected_type`` is provided, the legacy marker (if present) is + stripped and the parsed value is coerced to ``expected_type`` -- dataclasses + are constructed from their dict payloads, types exposing a ``from_json()`` + classmethod are reconstructed via that hook, and ``Optional``/``Union``, + ``list``, ``dict``, and ``tuple`` type hints are honored recursively. The + destination type is always supplied by the caller; it is never read from the + payload. + + ``converter`` is threaded through to ``from_json`` hooks that declare a + second parameter, letting them recursively reconstruct nested values. + """ + if expected_type is None: + return json.loads(json_str, object_hook=_legacy_object_hook) + raw = json.loads(json_str, object_hook=_strip_legacy_marker) + return _coerce_to_type(raw, expected_type, converter) + + +def _encode_custom_object(o: Any) -> Any: + """``default`` hook for :func:`json.dumps` that emits plain JSON. + + Called only for values the JSON encoder cannot natively serialize. Note that + namedtuples are handled natively by the encoder (serialized as JSON arrays) + and never reach this hook. + """ + # Custom objects may opt in via a ``to_json`` hook. It is resolved off the + # type and called with the instance (``type(o).to_json(o)``) so that both + # instance methods and ``@staticmethod`` hooks work -- matching the calling + # convention used by ``azure-functions-durable``. The hook returns a + # JSON-serializable value (a structure or a string), not a JSON document. + # + # The hook is checked *before* the dataclass / ``SimpleNamespace`` branches + # so a dataclass (or namespace) that needs custom handling -- for example + # one with a field that is not JSON-serializable on its own -- can opt in. + # Resolving off the type (rather than the instance) avoids mistaking a data + # attribute named ``to_json`` for a hook, and mirrors the decode side, which + # prefers ``from_json`` over the dataclass branch. 
+ to_json_hook = getattr(cast(Any, type(o)), "to_json", None) + if callable(to_json_hook): + return to_json_hook(o) + if isinstance(o, enum.Enum): + # Emit the member's underlying value (a primitive). Reconstruction is + # type-directed: passing the enum type to ``deserialize`` rebuilds the + # member via ``EnumType(value)``. ``IntEnum`` / ``IntFlag`` members are + # ints and serialize natively without reaching this hook; this branch + # covers string- and other-valued enums so they round-trip too. Emitting + # only the value (never the member name or type) keeps the wire format a + # plain primitive and avoids leaking the Python type into the payload. + return o.value + if dataclasses.is_dataclass(o) and not isinstance(o, type): + # Return a *shallow* mapping of the dataclass's fields rather than + # ``dataclasses.asdict``. asdict recursively converts nested dataclasses + # to plain dicts (bypassing their ``to_json`` hooks) and deep-copies + # every leaf value. Emitting the live field values lets ``json.dumps`` + # recurse so each nested value is encoded through this same hook, + # honoring nested ``to_json`` hooks and avoiding the deep copy. + return {f.name: getattr(o, f.name) for f in dataclasses.fields(o)} + if isinstance(o, SimpleNamespace): + return vars(o) + # This will raise a TypeError describing the unsupported type. + raise TypeError(f"Object of type '{type(o).__name__}' is not JSON serializable") + + +def _legacy_object_hook(d: dict[str, Any]) -> Any: + # If the object carries the legacy marker, deserialize it as a SimpleNamespace. + if d.pop(_AUTO_SERIALIZED, False): + return SimpleNamespace(**d) + return d + + +def _strip_legacy_marker(d: dict[str, Any]) -> dict[str, Any]: + # Discard the legacy marker so typed coercion sees a plain dict. + d.pop(_AUTO_SERIALIZED, None) + return d + + +def _coerce_to_type(value: Any, expected_type: Any, converter: DataConverter | None = None) -> Any: + """Coerce an already-parsed JSON value to ``expected_type``. 
+ + Handles ``None``/``Optional``/``Union``, ``list``, ``dict``, and ``tuple`` + type hints recursively, types exposing a ``from_json()`` classmethod, and + dataclasses (including nested dataclass fields). The destination type is + always caller-supplied and never derived from the payload, keeping + deserialization secure. + + ``converter`` is passed to ``from_json`` hooks that declare a second + parameter so they can recursively reconstruct nested values. + """ + if expected_type is None or value is None: + return value + + origin = typing.get_origin(expected_type) + if origin is not None: + return _coerce_generic(value, expected_type, origin, converter) + + if not isinstance(expected_type, type): + # Not a concrete, instantiable type (e.g. a typing special form we don't + # special-case) -- return the value unchanged. + return value + + if isinstance(value, expected_type): + return value + + from_json_hook = getattr(expected_type, "from_json", None) + if callable(from_json_hook): + return _invoke_from_json(from_json_hook, value, converter) + + if dataclasses.is_dataclass(expected_type) and isinstance(value, dict): + return _build_dataclass(expected_type, cast(dict[str, Any], value), converter) + + type_ctor = cast(Any, expected_type) + try: + return type_ctor(value) + except Exception as e: + type_name = getattr(type_ctor, "__name__", None) or str(type_ctor) + raise TypeError( + f"Could not coerce value of type '{type(value).__name__}' to " + f"'{type_name}'" + ) from e + + +def _invoke_from_json(hook: Any, value: Any, converter: DataConverter | None) -> Any: + """Call a ``from_json`` hook, passing the converter when the hook accepts it. + + A hook declared as ``from_json(cls, value)`` (the common, single-argument + convention) is called with just the value. 
A hook declared as + ``from_json(cls, value, converter)`` additionally receives the active + :class:`DataConverter`, letting it reconstruct nested values recursively via + ``converter.coerce(...)`` / ``converter.deserialize(...)``. + + > [!NOTE] + > ``from_json`` must be a ``@classmethod`` or ``@staticmethod`` -- the hook + > is resolved off the *type* (no instance exists yet during reconstruction). + > A plain instance method would have ``self`` consume the value and is + > unsupported regardless of the converter detection below. + > + > The converter is passed positionally as the second argument, so a hook + > opts in only by naming that parameter exactly ``converter``. This reserved + > name avoids misreading an unrelated second parameter (e.g. + > ``from_json(cls, value, strict=False)``) as converter-aware. + """ + if converter is not None and _hook_accepts_converter(hook): + return hook(value, converter) + return hook(value) + + +@functools.lru_cache(maxsize=2048) +def _hook_accepts_converter(hook: Any) -> bool: + """Return True if a bound ``from_json`` hook opts in to receiving the converter. + + The hook is inspected as accessed off the type (``cls``/``self`` already + bound), so a classmethod ``from_json(cls, value, converter)`` presents as + ``(value, converter)``. A hook is treated as converter-aware only when its + second positional parameter is named exactly ``converter`` -- the reserved + name for this argument -- so an unrelated second parameter such as + ``strict=False`` is not mistaken for it. Results are cached because + reconstruction runs on hot paths; bound classmethods hash equal across + attribute accesses, so the cache stays effective and bounded. 
+ """ + try: + sig = inspect.signature(hook) + except (TypeError, ValueError): + return False + positional = [ + param for param in sig.parameters.values() + if param.kind in (inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD) + ] + return len(positional) >= 2 and positional[1].name == "converter" + + +def _coerce_generic(value: Any, expected_type: Any, origin: Any, + converter: DataConverter | None = None) -> Any: + args = typing.get_args(expected_type) + if origin is typing.Union or origin is types.UnionType: + # If the value already matches a member type, keep it as-is. + non_none = [a for a in args if a is not type(None)] + for arg in non_none: + if isinstance(arg, type) and isinstance(value, arg): + return value + # ``Optional[T]`` (exactly one non-None member): coerce to that member. + # For a genuine multi-member ``Union`` where the value matched none of + # the members, leave it untouched rather than guessing the first arg -- + # forcing a coercion there can silently mis-construct the wrong type. + if len(non_none) == 1: + return _coerce_to_type(value, non_none[0], converter) + return value + if origin in (list, Sequence) and isinstance(value, list): + elem_type = args[0] if args else None + return [_coerce_to_type(item, elem_type, converter) for item in cast(list[Any], value)] + if origin in (dict, Mapping) and isinstance(value, dict): + # JSON object keys are always strings, so only the values can carry a + # reconstructable type. Keys are passed through unchanged. + mapping = cast("dict[Any, Any]", value) + val_type = args[1] if len(args) == 2 else None + if val_type is None: + return mapping + return {k: _coerce_to_type(v, val_type, converter) for k, v in mapping.items()} + if origin is tuple and isinstance(value, list): + # Tuples serialize to JSON arrays, so the parsed value is a list. + if not args: + return tuple(cast(list[Any], value)) + # Homogeneous ``tuple[T, ...]``. 
+ if len(args) == 2 and args[1] is Ellipsis: + return tuple(_coerce_to_type(item, args[0], converter) for item in cast(list[Any], value)) + # Fixed-length ``tuple[T1, T2, ...]`` -- coerce element-wise by position. + arr = cast(list[Any], value) + if len(arr) != len(args): + raise TypeError( + f"Could not coerce JSON array of length {len(arr)} to " + f"tuple of length {len(args)}" + ) + return tuple(_coerce_to_type(item, t, converter) for item, t in zip(arr, args)) + # Other generics are returned as parsed JSON. + return value + + +def _build_dataclass(cls: Any, data: dict[str, Any], + converter: DataConverter | None = None) -> Any: + """Construct a dataclass from its dict payload, recursing into typed fields.""" + try: + hints = typing.get_type_hints(cls) + except Exception: + hints = {} + globalns = _type_namespace(cls) + kwargs: dict[str, Any] = {} + for field in dataclasses.fields(cls): + if field.name not in data: + continue + # ``get_type_hints`` on Python 3.10 does not deep-resolve forward + # references nested inside container args (e.g. the ``"TreeNode"`` in + # ``list["TreeNode"]`` on a self-referential dataclass), leaving a bare + # string or ``ForwardRef`` that the coercion below would skip. Resolve + # them against the class's defining module so reconstruction behaves the + # same as it does on 3.11+. + field_type = _resolve_forward_refs(hints.get(field.name), globalns) + kwargs[field.name] = _coerce_to_type(data[field.name], field_type, converter) + return cls(**kwargs) + + +def _type_namespace(cls: Any) -> dict[str, Any]: + """Build the namespace used to resolve forward references in ``cls``'s hints. + + Forward references in a class's annotations are resolved against the + module in which the class is defined, plus the class's own name (so a + self-referential type like ``list["TreeNode"]`` resolves). 
+ """ + module = sys.modules.get(getattr(cls, "__module__", None) or "") + ns: dict[str, Any] = dict(getattr(module, "__dict__", {})) + name = getattr(cls, "__name__", None) + if name: + ns.setdefault(name, cls) + return ns + + +def _resolve_forward_refs(tp: Any, globalns: dict[str, Any]) -> Any: + """Resolve string / ``ForwardRef`` leaves in a type hint, recursing into args. + + Returns ``tp`` unchanged when it (or a nested name) cannot be resolved, so an + unresolvable hint simply falls back to "leave the value as parsed JSON" + rather than raising. Only the supported generic shapes (``Union``, ``list``, + ``dict``, ``tuple``, etc.) are rebuilt; the destination type is still + entirely caller-supplied, so this does not weaken the security model. + """ + if isinstance(tp, str): + try: + tp = eval(tp, globalns) # noqa: S307 - resolves caller-authored annotations + except Exception: + return tp + elif isinstance(tp, typing.ForwardRef): + try: + tp = eval(tp.__forward_arg__, globalns) # noqa: S307 + except Exception: + return tp + + origin = typing.get_origin(tp) + if origin is None: + return tp + args = typing.get_args(tp) + if not args: + return tp + resolved = [_resolve_forward_refs(a, globalns) for a in args] + if origin is typing.Union or origin is types.UnionType: + return typing.Union[tuple(resolved)] + try: + return origin[tuple(resolved) if len(resolved) > 1 else resolved[0]] + except TypeError: + return tp diff --git a/durabletask/worker.py b/durabletask/worker.py index d3a3d86e..66381c97 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -1317,8 +1317,9 @@ def _execute_entity_batch( payload_helpers.deexternalize_payloads(req, self._payload_store) entity_state = StateShim( - self._data_converter.deserialize(req.entityState.value) if req.entityState.value else None, - self._data_converter) + req.entityState.value if req.entityState.value else None, + self._data_converter, + is_serialized=True) instance_id = req.instanceId try: @@ -1376,7 
+1377,7 @@ def _execute_entity_batch( batch_result = pb.EntityBatchResult( results=results, actions=entity_state.get_operation_actions(), - entityState=helpers.get_string_value(self._data_converter.serialize(entity_state._current_state)) if entity_state._current_state is not None else None, # pyright: ignore[reportPrivateUsage] + entityState=helpers.get_string_value(entity_state.encode_state()), failureDetails=None, completionToken=completionToken, operationInfos=operation_infos, @@ -1419,8 +1420,8 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext): def __init__(self, instance_id: str, registry: _Registry, + data_converter: DataConverter, maximum_timer_interval: timedelta | None = DEFAULT_MAXIMUM_TIMER_INTERVAL, - data_converter: DataConverter | None = None, ): self._generator = None self._is_replaying = True @@ -1449,7 +1450,7 @@ def __init__(self, self._parent_trace_context: pb.TraceContext | None = None self._orchestration_trace_context: pb.TraceContext | None = None self._maximum_timer_interval = maximum_timer_interval - self._data_converter = data_converter if data_converter is not None else JsonDataConverter() + self._data_converter = data_converter def run(self, generator: Generator[task.Task[Any], Any, Any]) -> None: self._generator = generator @@ -1691,7 +1692,7 @@ def call_activity( # converter decides how a coercion failure is handled (the default is # best-effort). 
if return_type is None and not isinstance(activity, str): - return_type = type_discovery.activity_output_type(activity) + return_type = type_discovery.activity_output_type(activity, self._data_converter) self.call_activity_function_helper( id, activity, input=input, retry_policy=retry_policy, is_sub_orch=False, tags=tags, @@ -2049,13 +2050,13 @@ def __init__( self, registry: _Registry, logger: logging.Logger, + data_converter: DataConverter, persisted_orch_span_id: str | None = None, maximum_timer_interval: timedelta | None = DEFAULT_MAXIMUM_TIMER_INTERVAL, - data_converter: DataConverter | None = None, ): self._registry = registry self._logger = logger - self._data_converter = data_converter if data_converter is not None else JsonDataConverter() + self._data_converter = data_converter self._maximum_timer_interval = maximum_timer_interval self._is_suspended = False self._suspended_events: list[pb.HistoryEvent] = [] @@ -2210,7 +2211,7 @@ def process_event( if ( event.executionStarted.HasField("input") and event.executionStarted.input.value != "" ): - input_type = type_discovery.orchestrator_input_type(fn) + input_type = type_discovery.orchestrator_input_type(fn, self._data_converter) input = self._data_converter.deserialize( event.executionStarted.input.value, input_type) @@ -2772,12 +2773,12 @@ def _handle_entity_event_raised(self, if not ph.is_empty(event.eventRaised.input): # TODO: Investigate why the event result is wrapped in a dict with "result" key # The expected type applies to the unwrapped result value, not the - # transport wrapper. Unwrap first, then route the inner value back - # through the converter so custom converters and the expected type - # both apply. + # transport wrapper. Unwrap first, then coerce the already-parsed + # inner value to the expected type via the converter (no redundant + # re-serialization round-trip). 
unwrapped = self._data_converter.deserialize(event.eventRaised.input.value)["result"] - result = self._data_converter.deserialize( - self._data_converter.serialize(unwrapped), + result = self._data_converter.coerce( + unwrapped, entity_task._expected_type, # pyright: ignore[reportPrivateUsage] ) if is_lock_event: @@ -2833,10 +2834,10 @@ def compare_versions(self, source_version: str | None, default_version: str | No class _ActivityExecutor: def __init__(self, registry: _Registry, logger: logging.Logger, - data_converter: DataConverter | None = None): + data_converter: DataConverter): self._registry = registry self._logger = logger - self._data_converter = data_converter if data_converter is not None else JsonDataConverter() + self._data_converter = data_converter def execute( self, @@ -2855,7 +2856,7 @@ def execute( f"Activity function named '{name}' was not registered!" ) - input_type = type_discovery.activity_input_type(fn) if encoded_input else None + input_type = type_discovery.activity_input_type(fn, self._data_converter) if encoded_input else None activity_input = self._data_converter.deserialize(encoded_input, input_type) ctx = task.ActivityContext(orchestration_id, task_id) @@ -2872,10 +2873,10 @@ def execute( class _EntityExecutor: def __init__(self, registry: _Registry, logger: logging.Logger, - data_converter: DataConverter | None = None): + data_converter: DataConverter): self._registry = registry self._logger = logger - self._data_converter = data_converter if data_converter is not None else JsonDataConverter() + self._data_converter = data_converter self._entity_method_cache: dict[tuple[type, str], bool] = {} def execute( @@ -2896,7 +2897,7 @@ def execute( f"Entity function named '{entity_id.entity}' was not registered!" 
) - input_type = type_discovery.entity_input_type(fn, operation) if encoded_input else None + input_type = type_discovery.entity_input_type(fn, operation, self._data_converter) if encoded_input else None entity_input = self._data_converter.deserialize(encoded_input, input_type) ctx = EntityContext(orchestration_id, operation, state, entity_id, self._data_converter) diff --git a/examples/README.md b/examples/README.md index 3a4ce41f..7476e4d4 100644 --- a/examples/README.md +++ b/examples/README.md @@ -169,6 +169,12 @@ python activity_sequence.py > account and an additional install step. See the > [large payload README](./large_payload/README.md) for details. +> [!NOTE] +> The `custom_data_converter/` example is a self-contained subproject that +> shows how to plug a third-party serialization library (pydantic) into the +> SDK via a custom `DataConverter`, with in-process tests that need no backend. +> See the [custom DataConverter README](./custom_data_converter/README.md). + ### Review Orchestration History and Status To access the Durable Task Scheduler Dashboard, follow these steps: diff --git a/examples/custom_data_converter/README.md b/examples/custom_data_converter/README.md new file mode 100644 index 00000000..e375689a --- /dev/null +++ b/examples/custom_data_converter/README.md @@ -0,0 +1,176 @@ +# Custom DataConverter Example (pydantic) + +This example shows how to plug a **third-party serialization library** into the +Durable Task Python SDK by implementing a custom +[`DataConverter`](../../durabletask/serialization.py). It uses +[pydantic](https://docs.pydantic.dev/) to serialize and *validate* payloads, and +proves the integration works end-to-end with in-process tests that need no +sidecar, emulator, or Azure resources. + +You can copy this entire folder into a new directory and run it as a standalone +project. + +## Why a custom converter? 
+ +By default the SDK serializes payloads with its built-in JSON codec +(`JsonDataConverter`), which understands builtins, dataclasses, and objects that +expose `to_json()` / `from_json()` hooks. It does **not** know about pydantic +models. Supplying a custom `DataConverter` lets you route serialization through +any library you like — pydantic, `attrs`, `marshmallow`, a schema registry, an +encryption layer, etc. — at every payload boundary the SDK touches. + +## How the seam works + +Both the worker and the client accept a `data_converter` argument, and the SDK +routes **every** payload boundary through it — orchestrator / activity / entity +inputs and outputs, external events, and custom status: + +```python +converter = PydanticDataConverter() + +worker = DurableTaskSchedulerWorker(..., data_converter=converter) +client = DurableTaskSchedulerClient(..., data_converter=converter) +``` + +> [!IMPORTANT] +> Pass an equivalent converter to **both** the worker and the client. A payload +> serialized by one side is reconstructed by the other, so they must agree on +> the format. + +A `DataConverter` implements three methods: + +| Method | Direction | Used when | +|---|---|---| +| `serialize(value)` | Python value → JSON string | Any value leaves the process | +| `deserialize(data, target_type)` | JSON string → Python value (optionally typed) | A value arrives and the SDK knows the target type (from a function annotation, `return_type=`, or a typed client accessor) | +| `coerce(value, target_type)` | already-parsed value → typed value | The SDK already holds a parsed value (e.g. entity state) | + +The converter in [src/converter.py](src/converter.py) recognizes +`pydantic.BaseModel` subclasses and uses pydantic for them, **delegating +everything else** to the default `JsonDataConverter`. This "handle my types, +delegate the rest" shape is a good starting point for a real converter — it +costs nothing for non-pydantic payloads. 
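The "handle my types, delegate the rest" shape described above can be sketched without pydantic or the SDK. The block below is a minimal, standard-library-only illustration under stated assumptions: `FallbackConverter`, `Celsius`, and `CelsiusConverter` are hypothetical names invented for this sketch and are not SDK classes. The example's actual converter is the one in `src/converter.py`.

```python
import json
from typing import Any, Optional


class FallbackConverter:
    """Hypothetical stand-in for a default JSON codec (not the SDK class)."""

    def serialize(self, value: Any) -> Optional[str]:
        return None if value is None else json.dumps(value)

    def deserialize(self, data: Optional[str], target_type: Optional[type] = None) -> Any:
        # The stand-in ignores target_type and returns plain parsed JSON.
        return json.loads(data) if data else None

    def coerce(self, value: Any, target_type: Optional[type] = None) -> Any:
        return value


class Celsius:
    """Hypothetical custom type that the fallback codec cannot serialize."""

    def __init__(self, degrees: float) -> None:
        self.degrees = degrees


class CelsiusConverter:
    """Handles Celsius itself and delegates every other value."""

    def __init__(self) -> None:
        self._fallback = FallbackConverter()

    def serialize(self, value: Any) -> Optional[str]:
        if isinstance(value, Celsius):
            return json.dumps({"degrees": value.degrees})
        return self._fallback.serialize(value)

    def deserialize(self, data: Optional[str], target_type: Optional[type] = None) -> Any:
        if data and target_type is Celsius:
            return Celsius(json.loads(data)["degrees"])
        return self._fallback.deserialize(data, target_type)

    def coerce(self, value: Any, target_type: Optional[type] = None) -> Any:
        if value is not None and target_type is Celsius:
            return Celsius(value["degrees"])
        return self._fallback.coerce(value, target_type)


converter = CelsiusConverter()
wire = converter.serialize(Celsius(21.5))          # handled by the custom branch
restored = converter.deserialize(wire, Celsius)    # rebuilt because the type is supplied
plain = converter.deserialize(converter.serialize([1, 2, 3]))  # delegated untouched
```

The custom branch only runs when the value or the supplied target type is one the converter owns, so every other payload takes exactly the path it would take without the custom converter.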
+ +> [!NOTE] +> To stay focused on the seam, this example only intercepts when a model is the +> *top-level* type. A model nested inside another model (like `Order.items`) +> round-trips because pydantic recurses on its own, but a model nested in a +> *top-level generic* the SDK rebuilds directly — e.g. `return_type=list[OrderItem]` +> or an input annotated `dict[str, OrderItem]` — is not intercepted and falls to +> the default codec, which leaves the elements as raw dicts. A production +> converter would recurse into such generics (for example via +> `pydantic.TypeAdapter`). + +## Inbound inputs: `can_reconstruct` + +There is one extra detail for reconstructing **inbound** orchestrator/activity +*inputs*. Before the SDK hands an input to your converter, it asks the converter +whether the function's annotated input type is something it can rebuild, via +`DataConverter.can_reconstruct(target_type)`. The default `JsonDataConverter` +recognizes dataclasses and `from_json()`-capable types (and `Optional` / `list` +wrappers) — it does **not** know about pydantic models, so without an override an +input annotated `order: Order` would arrive as a plain `dict`. + +The converter overrides `can_reconstruct` to also recognize +`pydantic.BaseModel` subclasses, deferring everything else to the same +`JsonDataConverter` fallback it uses for serialization: + +```python +def can_reconstruct(self, target_type): + if _is_model_type(target_type): + return True + return self._fallback.can_reconstruct(target_type) # dataclasses, from_json, ... +``` + +The base `DataConverter.can_reconstruct` is conservative — it returns `False`, +so a converter only claims the types it actually rebuilds. Outbound values, +`return_type=` arguments, and typed client accessors (`state.get_output(Receipt)`) +don't depend on this hook — they pass the type to the converter directly.
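As a rough illustration of how a `can_reconstruct`-style gate could behave, the sketch below reads a function's input annotation and only reports it when the converter claims the type. This is an assumption-laden toy: `discover_input_type`, `ConservativeConverter`, and `DataclassAwareConverter` are hypothetical names for this sketch, and the SDK's real type-discovery code may work differently.

```python
import inspect
from dataclasses import dataclass, is_dataclass
from typing import Any


class ConservativeConverter:
    """Hypothetical converter that claims no types at all."""

    def can_reconstruct(self, target_type: Any) -> bool:
        return False


class DataclassAwareConverter(ConservativeConverter):
    """Hypothetical converter that claims dataclass types only."""

    def can_reconstruct(self, target_type: Any) -> bool:
        return isinstance(target_type, type) and is_dataclass(target_type)


def discover_input_type(fn: Any, converter: Any) -> Any:
    """Return the annotated input type only if the converter claims it.

    Assumes the convention used in this example: the first parameter is
    the context and the second is the input payload.
    """
    params = list(inspect.signature(fn).parameters.values())
    if len(params) < 2:
        return None
    hint = params[1].annotation
    if hint is inspect.Parameter.empty:
        return None
    return hint if converter.can_reconstruct(hint) else None


@dataclass
class Order:
    customer: str


def handle(ctx: object, order: Order) -> str:
    return order.customer


unclaimed = discover_input_type(handle, ConservativeConverter())
claimed = discover_input_type(handle, DataclassAwareConverter())
```

When the gate returns `None`, the caller has no target type to pass along, which is why an unclaimed annotation leaves the input as plain parsed JSON.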
+ + + +## Folder structure + +```text +custom_data_converter/ +├── README.md +├── requirements.txt +├── src/ +│ ├── __init__.py +│ ├── converter.py # PydanticDataConverter — the integration point +│ ├── workflows.py # pydantic models + orchestrator/activities +│ └── app.py # runs against a real DTS backend / emulator +└── test/ + ├── __init__.py + └── test_custom_converter.py # in-process proof using the in-memory backend +``` + +## What the example proves + +The models in [src/workflows.py](src/workflows.py) are plain +`pydantic.BaseModel` subclasses — *not* dataclasses — so they only round-trip +correctly **because** of the custom converter. The tests in +[test/test_custom_converter.py](test/test_custom_converter.py) verify: + +1. A pydantic `Order` passed as orchestration input arrives at the + orchestrator/activity as a **validated model instance** (attribute access), + not a raw dict. +2. The orchestration's pydantic `Receipt` result is reconstructed, typed, on + the client via `state.get_output(Receipt)`. +3. The wire payload is genuine pydantic JSON (`model_dump_json`), confirming the + custom converter — not the default codec — handled it. +4. An input that violates a pydantic constraint fails the orchestration with a + validation error, instead of passing bad data through. +5. For contrast, the default `JsonDataConverter` **cannot serialize** a pydantic + model at all (it raises `TypeError`) — which is exactly what motivates the + custom converter. + +## Getting started + +1. Copy this folder to a new location and `cd` into it: + + ```bash + cd custom_data_converter + ``` + +1. Create and activate a virtual environment: + + Bash: + + ```bash + python -m venv .venv + source .venv/bin/activate + ``` + + PowerShell: + + ```powershell + python -m venv .venv + .\.venv\Scripts\Activate.ps1 + ``` + +1. 
Install dependencies: + + ```bash + pip install -r requirements.txt + ``` + +## Running the tests (no backend required) + +From the `custom_data_converter/` directory: + +```bash +pytest test/ +``` + +This is the self-contained proof: it runs the full orchestration in-process +against the in-memory backend. + +## Running the app against the emulator + +Start the [DTS emulator](../README.md#running-with-the-emulator), then from the +`custom_data_converter/` directory: + +```bash +python -m src.app +``` diff --git a/examples/custom_data_converter/requirements.txt b/examples/custom_data_converter/requirements.txt new file mode 100644 index 00000000..a0c9a1b4 --- /dev/null +++ b/examples/custom_data_converter/requirements.txt @@ -0,0 +1,5 @@ +durabletask +durabletask-azuremanaged +azure-identity +pydantic>=2 +pytest diff --git a/examples/custom_data_converter/src/__init__.py b/examples/custom_data_converter/src/__init__.py new file mode 100644 index 00000000..59e481eb --- /dev/null +++ b/examples/custom_data_converter/src/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/examples/custom_data_converter/src/app.py b/examples/custom_data_converter/src/app.py new file mode 100644 index 00000000..caa9eb19 --- /dev/null +++ b/examples/custom_data_converter/src/app.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Run the pydantic-backed workflow against a Durable Task Scheduler backend. + +The only thing that differs from an ordinary app is the ``data_converter=`` +argument passed to **both** the worker and the client. Pass the *same* +converter instance (or an equivalent one) to both sides so payloads serialized +by one are reconstructed by the other. 
+ +Usage (emulator -- no env vars needed): + python -m src.app + +Usage (Azure): + export ENDPOINT=https://.durabletask.io + export TASKHUB= + python -m src.app + +For a self-contained proof that needs no backend at all, run the tests instead +(see ``test/test_custom_converter.py``). +""" + +import os + +from azure.identity import DefaultAzureCredential + +from durabletask import client +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +from src.converter import PydanticDataConverter +from src.workflows import Receipt, calculate_total, charge_payment, process_order, sample_order + + +def main() -> None: + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + print(f"Using taskhub: {taskhub_name}") + print(f"Using endpoint: {endpoint}") + + secure_channel = endpoint.startswith("https://") + credential = DefaultAzureCredential() if secure_channel else None + + # The single line that wires in the third-party converter. The same + # converter is handed to the worker and the client below. + converter = PydanticDataConverter() + + with DurableTaskSchedulerWorker( + host_address=endpoint, + secure_channel=secure_channel, + taskhub=taskhub_name, + token_credential=credential, + data_converter=converter, + ) as w: + w.add_orchestrator(process_order) + w.add_activity(calculate_total) + w.add_activity(charge_payment) + w.start() + + c = DurableTaskSchedulerClient( + host_address=endpoint, + secure_channel=secure_channel, + taskhub=taskhub_name, + token_credential=credential, + data_converter=converter, + ) + + # ``input`` is a pydantic ``Order``; the converter serializes it. 
+ instance_id = c.schedule_new_orchestration(process_order, input=sample_order()) + print(f"Scheduled orchestration: {instance_id}") + + state = c.wait_for_orchestration_completion(instance_id, timeout=60) + if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: + # ``get_output(Receipt)`` reconstructs the typed, validated result. + receipt = state.get_output(Receipt) + assert receipt is not None + print("Orchestration completed. Typed receipt:") + print(f" customer = {receipt.customer}") + print(f" total = {receipt.total}") + print(f" item_count = {receipt.item_count}") + print(f" confirmation_id = {receipt.confirmation_id}") + elif state: + print(f"Orchestration failed: {state.failure_details}") + else: + print("Orchestration timed out.") + + +if __name__ == "__main__": + main() diff --git a/examples/custom_data_converter/src/converter.py b/examples/custom_data_converter/src/converter.py new file mode 100644 index 00000000..360fee66 --- /dev/null +++ b/examples/custom_data_converter/src/converter.py @@ -0,0 +1,118 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Pydantic-backed :class:`DataConverter` for the Durable Task Python SDK. + +This is the heart of the example: a small, self-contained converter that plugs +a third-party serialization/validation library (`pydantic +`_) into the SDK's pluggable serialization seam. + +How the seam works +------------------ +Both the worker and the client accept a ``data_converter`` argument. The SDK +routes **every** payload boundary -- orchestrator/activity/entity inputs and +outputs, external events, and custom status -- through that single converter, +so one object controls how Python values become JSON on the wire and how they +are reconstructed on the way back. + +A converter implements three methods: + +* ``serialize(value)`` -- Python value -> JSON string (or ``None``). 
+* ``deserialize(data, t)`` -- JSON string -> Python value, optionally + reconstructed as ``t`` (the type the SDK learned from a function's + annotation, a ``return_type=`` argument, or a typed client accessor). +* ``coerce(value, t)`` -- already-parsed value -> reconstructed as ``t`` + (used where the SDK already holds a parsed value, e.g. entity state). + +It may also override one hook: + +* ``can_reconstruct(t)`` -- tells the SDK's inbound type-discovery that an + *input* annotated with type ``t`` should be handed to ``deserialize`` / + ``coerce`` (rather than passed through as raw JSON). The default recognizes + dataclasses and ``from_json()``-capable types; override it to add your own + (here, pydantic models). + +This converter recognizes :class:`pydantic.BaseModel` subclasses and uses +pydantic for them (gaining validation, aliasing, custom field types, etc.). +For everything else -- ``str``, ``int``, ``list``, dataclasses, ... -- it +delegates to the SDK's default :class:`JsonDataConverter`, so plugging it in +costs nothing for non-pydantic payloads. This "handle my types, delegate the +rest" shape is the recommended pattern for a real custom converter. +""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel + +from durabletask.serialization import DataConverter, JsonDataConverter + + +def _is_model_type(target_type: Any) -> bool: + """Return True when ``target_type`` is a pydantic model class.""" + return isinstance(target_type, type) and issubclass(target_type, BaseModel) + + +class PydanticDataConverter(DataConverter): + """A :class:`DataConverter` that serializes pydantic models with pydantic. + + Pydantic models are serialized via ``model_dump_json()`` and reconstructed + (with full validation) via ``model_validate_json()`` / ``model_validate()`` + whenever the SDK supplies the model type. Every other value falls through to + the default :class:`JsonDataConverter`. 
+ + > [!NOTE] + > For simplicity this converter only intercepts when a pydantic model is the + > *top-level* ``target_type``. A model nested **inside another model** (like + > ``Order.items: list[OrderItem]`` below) round-trips fine because pydantic + > handles that recursion itself. But a model nested in a *top-level generic* + > the SDK is asked to rebuild directly -- e.g. ``return_type=list[OrderItem]`` + > or an input annotated ``dict[str, OrderItem]`` -- is **not** intercepted + > here; it falls to the default codec, which cannot construct a pydantic + > model from a positional ``dict`` and so leaves the elements as raw dicts. A + > production converter would recurse into such generics (for example via + > ``pydantic.TypeAdapter``); this example keeps the surface minimal. + """ + + def __init__(self) -> None: + # Delegate non-pydantic payloads to the SDK's built-in JSON codec. + self._fallback = JsonDataConverter() + + def serialize(self, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, BaseModel): + # ``model_dump_json`` honors pydantic field serializers, aliases, + # and custom types (e.g. ``datetime`` -> ISO 8601) automatically. + return value.model_dump_json() + return self._fallback.serialize(value) + + def deserialize(self, data: str | None, target_type: type | None = None) -> Any: + if data is None or data == "": + return None + if _is_model_type(target_type): + # ``model_validate_json`` parses *and validates* the payload, + # raising ``pydantic.ValidationError`` on malformed data. 
+ return target_type.model_validate_json(data) # type: ignore[union-attr] + return self._fallback.deserialize(data, target_type) + + def coerce(self, value: Any, target_type: type | None = None) -> Any: + if value is None: + return None + if _is_model_type(target_type): + return target_type.model_validate(value) # type: ignore[union-attr] + return self._fallback.coerce(value, target_type) + + def can_reconstruct(self, target_type: Any) -> bool: + # Teach the SDK's inbound type-discovery that pydantic models are + # reconstructable, so an orchestrator/activity input annotated with a + # model type is rebuilt (and validated) by this converter instead of + # arriving as a plain ``dict``. For everything else, defer to the same + # ``JsonDataConverter`` fallback this converter uses for serialization, + # so its reconstruction claims match what it actually handles + # (dataclasses, ``from_json`` types, ``Optional`` / ``list`` wrappers; + # builtins excluded). + if _is_model_type(target_type): + return True + return self._fallback.can_reconstruct(target_type) diff --git a/examples/custom_data_converter/src/workflows.py b/examples/custom_data_converter/src/workflows.py new file mode 100644 index 00000000..61e8bb0d --- /dev/null +++ b/examples/custom_data_converter/src/workflows.py @@ -0,0 +1,119 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Pydantic models and the workflow that exercises them. + +These models are plain :class:`pydantic.BaseModel` subclasses -- *not* +dataclasses -- which is what makes them a meaningful proof: without the custom +:class:`~src.converter.PydanticDataConverter`, the SDK's default codec would not +know how to validate or reconstruct them. The orchestrator and activities +annotate their inputs/outputs with these model types, so the SDK hands those +types to the converter at every boundary. 
+""" + +from __future__ import annotations + +from collections.abc import Generator +from datetime import datetime, timezone +from typing import Any + +from pydantic import BaseModel, Field, field_validator + +from durabletask import task + + +# --------------------------------------------------------------------------- +# Pydantic models (validated at every serialization boundary) +# --------------------------------------------------------------------------- +# These are plain ``pydantic.BaseModel`` subclasses -- no special hooks. The +# custom ``PydanticDataConverter`` both serializes them and (because it +# overrides ``can_reconstruct``) teaches the SDK to reconstruct them for +# inbound orchestrator/activity inputs. + + +class OrderItem(BaseModel): + """A single line item. Pydantic validates the constraints below.""" + + name: str = Field(min_length=1) + quantity: int = Field(gt=0) + unit_price: float = Field(ge=0) + + +class Order(BaseModel): + """An order placed by a customer at a point in time.""" + + customer: str = Field(min_length=1) + placed_at: datetime + items: list[OrderItem] + + @field_validator("items") + @classmethod + def _must_have_items(cls, items: list[OrderItem]) -> list[OrderItem]: + if not items: + raise ValueError("an order must contain at least one item") + return items + + +class Receipt(BaseModel): + """The orchestration's typed result.""" + + customer: str + total: float + item_count: int + confirmation_id: str + + +# --------------------------------------------------------------------------- +# Activities +# --------------------------------------------------------------------------- +# Each activity annotates its input parameter with a pydantic model type. The +# worker passes that type to ``PydanticDataConverter.deserialize``, so the +# activity body receives a fully validated model instance (attribute access), +# never a raw dict. 
+ + +def calculate_total(ctx: task.ActivityContext, order: Order) -> float: + """Return the order's total cost from the validated model.""" + return round(sum(item.quantity * item.unit_price for item in order.items), 2) + + +def charge_payment(ctx: task.ActivityContext, amount: float) -> str: + """Pretend to charge a payment processor and return a confirmation ID.""" + if amount <= 0: + raise ValueError("payment amount must be positive") + return f"PAY-{int(amount * 100)}" + + +# --------------------------------------------------------------------------- +# Orchestrator +# --------------------------------------------------------------------------- + + +def process_order(ctx: task.OrchestrationContext, order: Order) -> Generator[task.Task[Any], Any, Receipt]: + """Validate, total, and charge an order, returning a typed :class:`Receipt`. + + The orchestrator's ``order`` parameter is reconstructed as a validated + ``Order`` by the custom converter. The ``input=order`` passed to ``call_activity`` + and the ``Receipt`` return value likewise round-trip through pydantic. + """ + total: float = yield ctx.call_activity(calculate_total, input=order) + confirmation_id: str = yield ctx.call_activity(charge_payment, input=total) + + return Receipt( + customer=order.customer, + total=total, + item_count=sum(item.quantity for item in order.items), + confirmation_id=confirmation_id, + ) + + +def sample_order() -> Order: + """A valid sample order used by both the runnable app and the tests.""" + return Order( + customer="Contoso", + placed_at=datetime.now(timezone.utc), + items=[ + OrderItem(name="Widget", quantity=3, unit_price=25.0), + OrderItem(name="Gadget", quantity=1, unit_price=99.99), + ], + ) diff --git a/examples/custom_data_converter/test/__init__.py b/examples/custom_data_converter/test/__init__.py new file mode 100644 index 00000000..59e481eb --- /dev/null +++ b/examples/custom_data_converter/test/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License. diff --git a/examples/custom_data_converter/test/test_custom_converter.py b/examples/custom_data_converter/test/test_custom_converter.py new file mode 100644 index 00000000..ed96d3f5 --- /dev/null +++ b/examples/custom_data_converter/test/test_custom_converter.py @@ -0,0 +1,148 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""End-to-end proof that the pydantic ``DataConverter`` plugs into the SDK. + +These tests run entirely in-process against the in-memory backend -- no +sidecar, emulator, or Azure resources. They prove that: + +1. A pydantic model passed as orchestration input round-trips through the wire + and arrives at the orchestrator/activity as a *validated model instance*. +2. The orchestration's pydantic result is reconstructed, typed, on the client + via ``state.get_output(Receipt)``. +3. The wire payload is real pydantic JSON (``model_dump_json``), confirming the + custom converter -- not the default codec -- handled the model. +4. An input that violates a pydantic constraint fails the orchestration.
+ +Run from the example root (custom_data_converter/): + pytest test/ +""" + +from __future__ import annotations + +import json + +import pytest + +from durabletask import client, worker +from durabletask.testing import create_test_backend + +from src.converter import PydanticDataConverter +from src.workflows import ( + Order, + OrderItem, + Receipt, + calculate_total, + charge_payment, + process_order, + sample_order, +) + +HOST = "localhost:50071" +PORT = 50071 + + +@pytest.fixture(autouse=True) +def backend(): + """Start and stop the in-memory backend for each test.""" + b = create_test_backend(port=PORT) + yield b + b.stop() + b.reset() + + +def _worker(converter: PydanticDataConverter) -> worker.TaskHubGrpcWorker: + w = worker.TaskHubGrpcWorker(host_address=HOST, data_converter=converter) + w.add_orchestrator(process_order) + w.add_activity(calculate_total) + w.add_activity(charge_payment) + return w + + +def _run(order: Order): + """Run ``process_order`` with the pydantic converter wired into both sides.""" + converter = PydanticDataConverter() + with _worker(converter) as w: + w.start() + with client.TaskHubGrpcClient(host_address=HOST, data_converter=converter) as c: + instance_id = c.schedule_new_orchestration(process_order, input=order) + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + return state + + +def test_pydantic_model_round_trips_and_output_is_typed(): + """A valid order completes and the client reads back a typed Receipt.""" + order = sample_order() + state = _run(order) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + + # The output is reconstructed as a validated pydantic model, not a dict. 
+ receipt = state.get_output(Receipt) + assert isinstance(receipt, Receipt) + assert receipt.customer == "Contoso" + assert receipt.item_count == 4 # 3 widgets + 1 gadget + assert receipt.total == pytest.approx(174.99) + assert receipt.confirmation_id == "PAY-17499" + + +def test_input_is_validated_pydantic_model_on_the_wire(): + """The serialized input is genuine pydantic JSON, proving the converter ran.""" + converter = PydanticDataConverter() + order = sample_order() + + serialized = converter.serialize(order) + assert serialized is not None + # ``model_dump_json`` emits the model fields; ``placed_at`` is an ISO string. + payload = json.loads(serialized) + assert payload["customer"] == "Contoso" + assert isinstance(payload["placed_at"], str) + assert payload["items"][0]["name"] == "Widget" + + # And the SDK reconstructs it back into a validated model given the type. + restored = converter.deserialize(serialized, Order) + assert isinstance(restored, Order) + assert isinstance(restored.items[0], OrderItem) + assert restored.items[0].quantity == 3 + + +def test_invalid_order_input_fails_validation_in_orchestration(): + """An order that violates a pydantic constraint surfaces as a failure. + + The activity/orchestrator receives the input via the converter, so the + pydantic ``ValidationError`` raised while reconstructing the model fails the + orchestration rather than silently passing bad data through. + """ + # ``quantity`` must be > 0; construct the bad payload past pydantic's own + # constructor guard via ``model_construct`` so the failure happens on the + # wire (during reconstruction) inside the orchestration. 
+ bad_order = Order.model_construct( + customer="Contoso", + placed_at=sample_order().placed_at, + items=[OrderItem.model_construct(name="Widget", quantity=-1, unit_price=25.0)], + ) + + state = _run(bad_order) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.FAILED + assert state.failure_details is not None + assert "validation" in (state.failure_details.message or "").lower() \ + or "quantity" in (state.failure_details.message or "").lower() + + +def test_default_converter_cannot_serialize_pydantic_model(): + """Contrast: the default codec has no idea how to serialize a pydantic model. + + This is what motivates the custom converter. Pydantic models are not + dataclasses and expose no ``to_json()`` hook, so the default + ``JsonDataConverter`` raises ``TypeError`` when asked to serialize one. The + ``PydanticDataConverter`` is what makes serialization work (via + ``model_dump_json``). + """ + from durabletask.serialization import JsonDataConverter + + default = JsonDataConverter() + with pytest.raises(TypeError): + default.serialize(sample_order()) diff --git a/tests/durabletask/test_activity_executor.py b/tests/durabletask/test_activity_executor.py index 408815b1..ff775adc 100644 --- a/tests/durabletask/test_activity_executor.py +++ b/tests/durabletask/test_activity_executor.py @@ -6,6 +6,7 @@ from typing import Any from durabletask import task, worker +from durabletask.serialization import JsonDataConverter logging.basicConfig( format='%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s', @@ -53,5 +54,5 @@ def test_activity(ctx: task.ActivityContext, _): def _get_activity_executor(fn: task.Activity) -> tuple[worker._ActivityExecutor, str]: registry = worker._Registry() name = registry.add_activity(fn) - executor = worker._ActivityExecutor(registry, TEST_LOGGER) + executor = worker._ActivityExecutor(registry, TEST_LOGGER, JsonDataConverter()) return executor, name diff --git 
a/tests/durabletask/test_data_converter_roundtrip.py b/tests/durabletask/test_data_converter_roundtrip.py new file mode 100644 index 00000000..465aeb3a --- /dev/null +++ b/tests/durabletask/test_data_converter_roundtrip.py @@ -0,0 +1,584 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Comprehensive round-trip tests for the default ``JsonDataConverter``. + +These exercise ``serialize`` -> ``deserialize(..., target_type)`` through the +*public* converter API (not the private codec) across a broad matrix of object +shapes a user might reasonably hand to an orchestrator/activity/entity boundary: +plain dataclasses, deeply nested dataclasses, dataclasses with ``to_json`` / +``from_json`` hooks (including nested), nested non-dataclass custom classes, +enums, containers (``list`` / ``dict`` / ``tuple``), ``Optional`` / ``Union``, +recursive structures, and a set of types the SDK intentionally does **not** +auto-serialize. + +The intent is to lock in the "secure, minimal-effort, intuitive object +round-tripping" contract and to document -- via xfail/raises -- exactly where a +user must supply a ``to_json`` hook or a custom ``DataConverter``. 
+""" + +from __future__ import annotations + +import enum +import typing +from dataclasses import dataclass, field +from datetime import datetime +from decimal import Decimal + +import pytest + +from durabletask.serialization import JsonDataConverter + + +@pytest.fixture +def conv() -> JsonDataConverter: + return JsonDataConverter() + + +def _round_trip(conv: JsonDataConverter, value, target_type): + """Serialize then deserialize ``value`` back into ``target_type``.""" + return conv.deserialize(conv.serialize(value), target_type) + + +# --------------------------------------------------------------------------- +# Plain and nested dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class Address: + street: str + city: str + + +@dataclass +class Person: + name: str + age: int + address: Address | None = None + + +def test_plain_dataclass(conv): + value = Address("1 Main St", "Redmond") + assert _round_trip(conv, value, Address) == value + + +def test_nested_dataclass(conv): + value = Person("Ada", 30, Address("1 Main St", "Redmond")) + result = _round_trip(conv, value, Person) + assert result == value + assert isinstance(result.address, Address) + + +def test_optional_dataclass_field_present(conv): + value = Person("Ada", 30, Address("1 Main St", "Redmond")) + assert _round_trip(conv, value, Person).address == Address("1 Main St", "Redmond") + + +def test_optional_dataclass_field_absent(conv): + value = Person("Ada", 30, None) + assert _round_trip(conv, value, Person).address is None + + +@dataclass +class _L3: + v: int + + +@dataclass +class _L2: + l3: _L3 + + +@dataclass +class _L1: + l2: _L2 + + +def test_three_level_nested_dataclass(conv): + value = _L1(_L2(_L3(7))) + result = _round_trip(conv, value, _L1) + assert result == value + assert isinstance(result.l2.l3, _L3) + + +def test_frozen_dataclass(conv): + @dataclass(frozen=True) + class Frozen: + x: int + y: int + + value = Frozen(1, 2) + assert 
_round_trip(conv, value, Frozen) == value + + +def test_empty_dataclass(conv): + @dataclass + class Empty: + pass + + assert isinstance(_round_trip(conv, Empty(), Empty), Empty) + + +def test_dataclass_with_builtin_shadowing_field_names(conv): + @dataclass + class Shadow: + type: str + id: int + + value = Shadow("widget", 1) + assert _round_trip(conv, value, Shadow) == value + + +def test_dataclass_default_factory_list_of_dataclasses(conv): + @dataclass + class Bag: + items: list[Address] = field(default_factory=list) + + value = Bag([Address("a", "b")]) + result = _round_trip(conv, value, Bag) + assert result == value + assert isinstance(result.items[0], Address) + + +# --------------------------------------------------------------------------- +# Dataclasses with to_json / from_json hooks (incl. nested) +# --------------------------------------------------------------------------- + + +@dataclass +class Temperature: + """Dataclass whose JSON shape differs from its field layout.""" + + celsius: float + + def to_json(self) -> dict: + return {"fahrenheit": self.celsius * 9 / 5 + 32} + + @classmethod + def from_json(cls, data: dict) -> "Temperature": + return cls((data["fahrenheit"] - 32) * 5 / 9) + + def __eq__(self, other: object) -> bool: + return isinstance(other, Temperature) and other.celsius == self.celsius + + +def test_dataclass_with_hooks(conv): + value = Temperature(100.0) + encoded = conv.serialize(value) + assert encoded == '{"fahrenheit": 212.0}' # hook shape, not raw field + assert conv.deserialize(encoded, Temperature) == value + + +def test_nested_dataclass_hook_inside_dataclass(conv): + @dataclass + class Reading: + temp: Temperature + note: str + + value = Reading(Temperature(37.0), "ok") + encoded = conv.serialize(value) + assert '"fahrenheit"' in encoded + result = conv.deserialize(encoded, Reading) + assert result.temp == Temperature(37.0) + assert result.note == "ok" + + +def test_list_of_hooked_dataclasses(conv): + value = [Temperature(0.0), 
Temperature(100.0)] + result = _round_trip(conv, value, list[Temperature]) + assert result == value + + +def test_dict_of_hooked_dataclasses(conv): + value = {"a": Temperature(0.0)} + result = _round_trip(conv, value, dict[str, Temperature]) + assert result["a"] == Temperature(0.0) + + +# --------------------------------------------------------------------------- +# Nested non-dataclass custom classes (with hooks) +# --------------------------------------------------------------------------- + + +class Money: + """A field type that is not JSON-serializable on its own.""" + + def __init__(self, cents: int): + self.cents = cents + + def to_json(self) -> dict: + return {"cents": self.cents} + + @classmethod + def from_json(cls, data: dict) -> "Money": + return cls(data["cents"]) + + def __eq__(self, other: object) -> bool: + return isinstance(other, Money) and other.cents == self.cents + + +def test_dataclass_with_non_serializable_field_uses_hook(conv): + @dataclass + class Invoice: + amount: Money + + def __eq__(self, other: object) -> bool: + return isinstance(other, Invoice) and other.amount == self.amount + + value = Invoice(Money(500)) + result = _round_trip(conv, value, Invoice) + assert result == value + assert isinstance(result.amount, Money) + + +def test_standalone_custom_class_with_hooks(conv): + value = Money(250) + assert _round_trip(conv, value, Money) == value + + +def test_hook_returning_scalar_string(conv): + class Tag: + def __init__(self, name: str): + self.name = name + + def to_json(self) -> str: + return self.name + + @classmethod + def from_json(cls, data: str) -> "Tag": + return cls(data) + + def __eq__(self, other: object) -> bool: + return isinstance(other, Tag) and other.name == self.name + + value = Tag("vip") + assert conv.serialize(value) == '"vip"' + assert _round_trip(conv, value, Tag) == value + + +def test_dict_of_custom_hooked_class(conv): + value = {"usd": Money(100), "eur": Money(200)} + result = _round_trip(conv, value, dict[str, 
Money]) + assert result == value + + +# --------------------------------------------------------------------------- +# Converter-aware from_json (recursive reconstruction of nested typed values) +# --------------------------------------------------------------------------- + + +class Ledger: + """A type whose from_json uses the converter to rebuild a nested value.""" + + def __init__(self, owner: str, balance: Money): + self.owner = owner + self.balance = balance + + def to_json(self) -> dict: + return {"owner": self.owner, "balance": self.balance} + + @classmethod + def from_json(cls, data: dict, converter) -> "Ledger": + return cls(data["owner"], converter.coerce(data["balance"], Money)) + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, Ledger) + and other.owner == self.owner + and other.balance == self.balance + ) + + +def test_converter_aware_from_json(conv): + value = Ledger("ada", Money(999)) + result = _round_trip(conv, value, Ledger) + assert result == value + assert isinstance(result.balance, Money) + + +def test_converter_aware_from_json_nested_in_dataclass(conv): + @dataclass + class Account: + ledgers: list[Ledger] + + value = Account([Ledger("ada", Money(1))]) + result = _round_trip(conv, value, Account) + assert isinstance(result.ledgers[0], Ledger) + assert isinstance(result.ledgers[0].balance, Money) + + +# --------------------------------------------------------------------------- +# Enums +# --------------------------------------------------------------------------- + + +class Color(enum.IntEnum): + RED = 1 + GREEN = 2 + + +class Status(enum.Enum): + OPEN = "open" + CLOSED = "closed" + + +def test_int_enum_round_trip(conv): + assert conv.serialize(Color.RED) == "1" + assert conv.deserialize("1", Color) is Color.RED + + +def test_str_enum_round_trip(conv): + assert conv.serialize(Status.OPEN) == '"open"' + assert conv.deserialize('"open"', Status) is Status.OPEN + + +def test_enum_as_dataclass_field(conv): + @dataclass + 
class Ticket: + status: Status + color: Color + + value = Ticket(Status.OPEN, Color.RED) + result = _round_trip(conv, value, Ticket) + assert result.status is Status.OPEN + assert result.color is Color.RED + + +def test_list_of_enums(conv): + value = [Status.OPEN, Status.CLOSED] + result = _round_trip(conv, value, list[Status]) + assert result == value + + +def test_optional_enum_field(conv): + @dataclass + class MaybeStatus: + status: Status | None = None + + assert _round_trip(conv, MaybeStatus(Status.CLOSED), MaybeStatus).status is Status.CLOSED + assert _round_trip(conv, MaybeStatus(None), MaybeStatus).status is None + + +# --------------------------------------------------------------------------- +# Containers +# --------------------------------------------------------------------------- + + +def test_list_of_dataclasses(conv): + value = [Address("a", "b"), Address("c", "d")] + result = _round_trip(conv, value, list[Address]) + assert result == value + + +def test_dict_of_dataclasses(conv): + value = {"home": Address("a", "b")} + result = _round_trip(conv, value, dict[str, Address]) + assert isinstance(result["home"], Address) + + +def test_dict_of_optional_dataclasses(conv): + value = {"a": Address("x", "y"), "b": None} + result = _round_trip(conv, value, dict[str, typing.Optional[Address]]) + assert isinstance(result["a"], Address) + assert result["b"] is None + + +def test_fixed_length_tuple(conv): + value = (Address("a", "b"), 5) + result = _round_trip(conv, value, tuple[Address, int]) + assert isinstance(result, tuple) + assert result[0] == Address("a", "b") + assert result[1] == 5 + + +def test_homogeneous_tuple(conv): + value = (Address("a", "b"), Address("c", "d")) + result = _round_trip(conv, value, tuple[Address, ...]) + assert isinstance(result, tuple) + assert all(isinstance(item, Address) for item in result) + + +def test_tuple_field_in_dataclass(conv): + @dataclass + class Pair: + pair: tuple[Address, Address] + + value = Pair((Address("a", 
"b"), Address("c", "d"))) + result = _round_trip(conv, value, Pair) + assert isinstance(result.pair, tuple) + assert result.pair[0] == Address("a", "b") + + +def test_dataclass_with_list_of_nested_dataclasses(conv): + @dataclass + class Team: + members: list[Person] + + value = Team([Person("A", 1, Address("x", "y"))]) + result = _round_trip(conv, value, Team) + assert isinstance(result.members[0], Person) + assert isinstance(result.members[0].address, Address) + + +def test_dataclass_with_dict_of_custom_class(conv): + @dataclass + class Wallet: + funds: dict[str, Money] + + value = Wallet({"usd": Money(100)}) + result = _round_trip(conv, value, Wallet) + assert isinstance(result.funds["usd"], Money) + assert result.funds["usd"] == Money(100) + + +# --------------------------------------------------------------------------- +# Recursive structures +# --------------------------------------------------------------------------- + + +@dataclass +class TreeNode: + value: int + children: list["TreeNode"] = field(default_factory=list) + + +def test_recursive_tree_dataclass(conv): + value = TreeNode(1, [TreeNode(2), TreeNode(3, [TreeNode(4)])]) + result = _round_trip(conv, value, TreeNode) + assert result == value + assert isinstance(result.children[1].children[0], TreeNode) + + +# --------------------------------------------------------------------------- +# Documented limitation: nested typed reconstruction needs resolvable hints. +# +# Type-directed reconstruction relies on ``typing.get_type_hints`` to read a +# dataclass's nested field types. A dataclass defined inside a function whose +# fields reference *other function-local* types cannot be resolved (the names +# are not in the module globals), so nested coercion is skipped and the field +# is returned as parsed JSON. Module-level dataclasses (the normal case) work. 
+# --------------------------------------------------------------------------- + + +def test_function_local_nested_dataclass_hint_is_not_resolvable(conv): + @dataclass + class Inner: + v: int + + @dataclass + class Outer: + inner: Inner + + result = _round_trip(conv, Outer(Inner(7)), Outer) + # The outer dataclass is rebuilt, but the inner field stays a plain dict + # because its type hint can't be resolved from a function scope. + assert isinstance(result, Outer) + assert result.inner == {"v": 7} + + +# --------------------------------------------------------------------------- +# Builtins / primitives +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("value,target", [ + (42, int), + ("hello", str), + (3.14, float), + (True, bool), + (None, type(None)), +]) +def test_primitive_round_trip(conv, value, target): + assert conv.deserialize(conv.serialize(value), target) == value + + +def test_list_of_primitives(conv): + assert _round_trip(conv, [1, 2, 3], list[int]) == [1, 2, 3] + + +def test_dict_of_primitives(conv): + assert _round_trip(conv, {"a": 1, "b": 2}, dict[str, int]) == {"a": 1, "b": 2} + + +def test_no_target_type_returns_plain_json(conv): + # Without a target type, custom objects come back as plain JSON. + encoded = conv.serialize(Address("a", "b")) + assert conv.deserialize(encoded) == {"street": "a", "city": "b"} + + +# --------------------------------------------------------------------------- +# Multi-member Union: documented limitation (not force-coerced) +# --------------------------------------------------------------------------- + + +def test_multi_member_union_field_is_not_force_coerced(conv): + @dataclass + class A: + a: int + + @dataclass + class B: + b: int + + @dataclass + class HasUnion: + val: typing.Union[A, B] + + # A multi-member Union cannot be disambiguated from the payload alone, so the + # value is intentionally left as parsed JSON rather than guessing a member. 
+ # This documents the current, secure behavior. + value = HasUnion(A(1)) + result = _round_trip(conv, value, HasUnion) + assert result.val == {"a": 1} + + +# --------------------------------------------------------------------------- +# Types the SDK intentionally does NOT auto-serialize. +# +# These require a user-supplied ``to_json`` hook or a custom ``DataConverter``. +# The tests pin the current behavior (a clear TypeError) so a future change that +# adds support is a deliberate, reviewed decision. +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("value", [ + datetime(2020, 1, 1, 12, 0, 0), + Decimal("1.5"), + {1, 2, 3}, + frozenset({1, 2}), + b"bytes", +]) +def test_unsupported_types_raise_clear_typeerror(conv, value): + with pytest.raises(TypeError) as exc_info: + conv.serialize(value) + assert type(value).__name__ in str(exc_info.value) + + +def test_plain_object_without_hook_raises(conv): + class Plain: + def __init__(self): + self.a = 1 + + with pytest.raises(TypeError): + conv.serialize(Plain()) + + +def test_custom_datetime_via_to_json_hook(conv): + # The supported way to round-trip a datetime is a to_json/from_json hook. + class Timestamp: + def __init__(self, dt: datetime): + self.dt = dt + + def to_json(self) -> str: + return self.dt.isoformat() + + @classmethod + def from_json(cls, data: str) -> "Timestamp": + return cls(datetime.fromisoformat(data)) + + def __eq__(self, other: object) -> bool: + return isinstance(other, Timestamp) and other.dt == self.dt + + value = Timestamp(datetime(2020, 1, 1, 12, 0, 0)) + assert _round_trip(conv, value, Timestamp) == value diff --git a/tests/durabletask/test_entity_executor.py b/tests/durabletask/test_entity_executor.py index e853f13f..42774050 100644 --- a/tests/durabletask/test_entity_executor.py +++ b/tests/durabletask/test_entity_executor.py @@ -2,10 +2,12 @@ # Licensed under the MIT License. 
"""Unit tests for the _EntityExecutor class in durabletask.worker.""" +import json import logging from durabletask import entities from durabletask.internal.entity_state_shim import StateShim +from durabletask.serialization import JsonDataConverter from durabletask.worker import _EntityExecutor, _Registry @@ -14,13 +16,13 @@ def _make_executor(*entity_args) -> _EntityExecutor: registry = _Registry() for entity in entity_args: registry.add_entity(entity) - return _EntityExecutor(registry, logging.getLogger("test")) + return _EntityExecutor(registry, logging.getLogger("test"), JsonDataConverter()) def _execute(executor, entity_name, operation, encoded_input=None): """Helper to execute an entity operation.""" entity_id = entities.EntityInstanceId(entity_name, "test-key") - state = StateShim(None) + state = StateShim(None, JsonDataConverter()) return executor.execute("test-orchestration", entity_id, operation, state, encoded_input) @@ -73,7 +75,7 @@ def get(self): entity_id = entities.EntityInstanceId("Counter", "test-key") # set requires input - state = StateShim(None) + state = StateShim(None, JsonDataConverter()) executor.execute("test-orch", entity_id, "set", state, "10") state.commit() @@ -125,7 +127,7 @@ def counter(ctx: entities.EntityContext, input): executor = _make_executor(counter) entity_id = entities.EntityInstanceId("counter", "test-key") - state = StateShim(None) + state = StateShim(None, JsonDataConverter()) executor.execute("test-orch", entity_id, "set", state, "42") state.commit() @@ -138,19 +140,19 @@ class TestStateShimCoercion: """Tests for StateShim.get_state type coercion via the data converter.""" def test_get_state_none_returns_default(self): - state = StateShim(None) + state = StateShim(None, JsonDataConverter()) assert state.get_state(int, 0) == 0 def test_get_state_none_without_default_returns_none(self): - state = StateShim(None) + state = StateShim(None, JsonDataConverter()) assert state.get_state(int) is None def 
test_get_state_passes_through_matching_type(self): - state = StateShim(5) + state = StateShim(5, JsonDataConverter()) assert state.get_state(int) == 5 def test_get_state_constructor_coercion(self): - state = StateShim("5") + state = StateShim("5", JsonDataConverter()) assert state.get_state(int) == 5 def test_get_state_coerces_dataclass(self): @@ -161,7 +163,7 @@ class Counter: value: int # State is stored as a plain dict (as it would be after from_json). - state = StateShim({"value": 7}) + state = StateShim({"value": 7}, JsonDataConverter()) result = state.get_state(Counter) assert isinstance(result, Counter) assert result.value == 7 @@ -175,7 +177,7 @@ def __init__(self, n: int): def from_json(cls, data): return cls(data["n"]) - state = StateShim({"n": 3}) + state = StateShim({"n": 3}, JsonDataConverter()) result = state.get_state(Wrapped) assert isinstance(result, Wrapped) assert result.n == 3 @@ -185,6 +187,105 @@ def test_get_state_invalid_coercion_raises(self): # restoring the pre-existing strict contract for entity state access. import pytest - state = StateShim("not-an-int") + state = StateShim("not-an-int", JsonDataConverter()) with pytest.raises(TypeError): state.get_state(int) + + +class TestStateShimDeferredDeserialization: + """Tests for StateShim deferring deserialization of the raw wire payload.""" + + def test_constructor_does_not_deserialize_serialized_state(self): + # A serialized payload is held verbatim until read, not eagerly parsed. 
+ state = StateShim('{"value": 7}', JsonDataConverter(), is_serialized=True) + assert state._current_state == '{"value": 7}' + + def test_get_state_defers_deserialization_with_type(self): + from dataclasses import dataclass + + @dataclass + class Counter: + value: int + + state = StateShim('{"value": 7}', JsonDataConverter(), is_serialized=True) + result = state.get_state(Counter) + assert isinstance(result, Counter) + assert result.value == 7 + + def test_get_state_no_type_returns_parsed_value(self): + state = StateShim('{"value": 7}', JsonDataConverter(), is_serialized=True) + assert state.get_state() == {"value": 7} + + def test_deferred_deserialization_passes_raw_string_to_converter(self): + # A custom converter receives the original serialized string together + # with the requested type, rather than a pre-parsed value. + from typing import Any + + from durabletask.serialization import DataConverter + + seen: dict[str, Any] = {} + + class RecordingConverter(DataConverter): + def serialize(self, value: Any) -> str | None: + return None if value is None else json.dumps(value) + + def deserialize(self, data, target_type=None): + seen["data"] = data + seen["target_type"] = target_type + return {"parsed": True} + + def coerce(self, value, target_type=None): + seen["coerced"] = True + return value + + state = StateShim('{"x": 1}', RecordingConverter(), is_serialized=True) + state.get_state(dict) + assert seen["data"] == '{"x": 1}' + assert seen["target_type"] is dict + assert "coerced" not in seen + + def test_encode_state_passes_through_unmodified_payload(self): + # An unread/unmodified serialized payload is returned verbatim, never + # re-serialized (which would double-encode the JSON string). 
+ state = StateShim('{"value": 7}', JsonDataConverter(), is_serialized=True) + assert state.encode_state() == '{"value": 7}' + + def test_reading_does_not_trigger_double_encoding(self): + state = StateShim('{"value": 7}', JsonDataConverter(), is_serialized=True) + # Reading (even with a type) must not turn the payload into a live value + # that would be re-serialized into a JSON-encoded string. + state.get_state() + state.get_state(dict) + encoded = state.encode_state() + assert encoded == '{"value": 7}' + assert json.loads(encoded) == {"value": 7} + + def test_encode_state_serializes_live_value_after_set_state(self): + state = StateShim('{"value": 7}', JsonDataConverter(), is_serialized=True) + state.set_state({"value": 8}) + encoded = state.encode_state() + assert json.loads(encoded) == {"value": 8} + + def test_encode_state_none_when_state_is_none(self): + state = StateShim(None, JsonDataConverter(), is_serialized=True) + assert state.encode_state() is None + + def test_commit_preserves_unmodified_payload(self): + state = StateShim('{"value": 7}', JsonDataConverter(), is_serialized=True) + state.commit() + # After commit, the (unmodified) state still round-trips without + # double-encoding. + assert state.encode_state() == '{"value": 7}' + + def test_rollback_restores_unmodified_payload(self): + state = StateShim('{"value": 7}', JsonDataConverter(), is_serialized=True) + state.commit() + state.set_state({"value": 99}) + state.rollback() + assert state.encode_state() == '{"value": 7}' + + def test_falsy_serialized_state_is_not_dropped(self): + # A serialized falsy value (e.g. 0) is preserved, not treated as cleared. 
+ state = StateShim("0", JsonDataConverter(), is_serialized=True) + assert state.get_state(int) == 0 + assert state.encode_state() == "0" diff --git a/tests/durabletask/test_orchestration_executor.py b/tests/durabletask/test_orchestration_executor.py index 790f8e23..e95c4613 100644 --- a/tests/durabletask/test_orchestration_executor.py +++ b/tests/durabletask/test_orchestration_executor.py @@ -10,6 +10,7 @@ import durabletask.internal.helpers as helpers import durabletask.internal.orchestrator_service_pb2 as pb from durabletask import task, worker, entities +from durabletask.serialization import JsonDataConverter logging.basicConfig( format='%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s', @@ -36,7 +37,7 @@ def orchestrator(ctx: task.OrchestrationContext, my_input: int): helpers.new_orchestrator_started_event(start_time), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=json.dumps(test_input)), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result.actions @@ -58,7 +59,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): name = registry.add_orchestrator(empty_orchestrator) new_events = [helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result.actions @@ -73,7 +74,7 @@ def test_orchestrator_not_registered(): registry = worker._Registry() name = "Bogus" new_events = [helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) 
result = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result.actions @@ -100,7 +101,7 @@ def delay_orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(start_time), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result.actions @@ -133,7 +134,7 @@ def delay_orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_timer_fired_event(1, expected_fire_at)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -160,7 +161,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(start_time), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result.actions @@ -180,7 +181,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): registry = worker._Registry() name = registry.add_orchestrator(orchestrator) - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) start_time = datetime(2020, 1, 1, 12, 0, 0) t1 = start_time + timedelta(days=3) @@ -277,7 +278,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): registry = worker._Registry() name = registry.add_orchestrator(orchestrator) - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) 
+ executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) start_time = datetime(2020, 1, 1, 12, 0, 0) first_chunk_fire_at = start_time + timedelta(days=3) @@ -325,7 +326,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): registry = worker._Registry() name = registry.add_orchestrator(orchestrator) - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) start_time = datetime(2020, 1, 1, 12, 0, 0) timeout_fire_at = start_time + timedelta(hours=1) @@ -363,7 +364,7 @@ def test_only_cancellable_tasks_expose_cancel(): def dummy_activity(ctx, _): pass - ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, worker._Registry()) + ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, worker._Registry(), JsonDataConverter()) timer_task = ctx.create_timer(timedelta(minutes=5)) external_event_task = ctx.wait_for_external_event("approval") @@ -408,7 +409,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result.actions @@ -441,7 +442,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): encoded_output = json.dumps("done!") new_events = [helpers.new_task_completed_event(1, encoded_output)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -479,7 +480,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = 
[helpers.new_task_completed_event(1, json.dumps({"message": "hi"}))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) complete_action = get_and_validate_complete_orchestration_action_list(1, result.actions) @@ -517,7 +518,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [helpers.new_task_completed_event(1, json.dumps({"message": "hi"}))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) assert captured["type"] == "Result" @@ -556,7 +557,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [helpers.new_task_completed_event(1, json.dumps({"value": "x"}))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) assert captured["type"] == "Override" @@ -587,7 +588,7 @@ def orchestrator(ctx: task.OrchestrationContext, args: StartArgs): helpers.new_orchestrator_started_event(), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=encoded_input)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, [], new_events) assert captured["type"] == "StartArgs" @@ -615,7 +616,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): ex = Exception("Kah-BOOOOM!!!") new_events = [helpers.new_task_failed_event(1, ex)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, 
JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -661,7 +662,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(timestamp=current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 1 @@ -675,7 +676,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(2, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 2 @@ -688,7 +689,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 3 @@ -702,7 +703,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(3, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = 
executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 3 @@ -715,7 +716,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 4 @@ -729,7 +730,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(4, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 4 @@ -742,7 +743,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 5 @@ -756,7 +757,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(5, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = 
executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 5 @@ -770,7 +771,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 6 @@ -784,7 +785,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(6, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 6 @@ -796,7 +797,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 7 @@ -840,7 +841,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(timestamp=current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = 
executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 1 @@ -854,7 +855,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(2, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 2 @@ -866,7 +867,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 3 @@ -880,7 +881,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(3, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 3 @@ -890,7 +891,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = 
executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 4 @@ -929,7 +930,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(timestamp=current_timestamp), helpers.new_task_failed_event(1, ValueError("Fail!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 1 @@ -942,7 +943,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(2, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) expected_fire_at = current_timestamp + timedelta(seconds=5) @@ -950,7 +951,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Fail!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 3 @@ -978,7 +979,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): registry = worker._Registry() name = registry.add_orchestrator(orchestrator) - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) start = datetime.utcnow() t1 = start + timedelta(days=3) @@ -1077,7 +1078,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): 
helpers.new_timer_created_event(1, fire_at)] new_events = [helpers.new_timer_fired_event(timer_id=1, fire_at=fire_at)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1105,7 +1106,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [helpers.new_task_completed_event(1)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1135,7 +1136,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [helpers.new_task_completed_event(1)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1166,7 +1167,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [helpers.new_task_completed_event(1)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1200,7 +1201,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_sub_orchestration_completed_event(1, encoded_output="42")] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1230,7 +1231,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): ex = Exception("Kah-BOOOOM!!!") 
new_events = [helpers.new_sub_orchestration_failed_event(1, ex)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1261,7 +1262,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_sub_orchestration_completed_event(1, encoded_output="42")] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1291,7 +1292,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_sub_orchestration_completed_event(1, encoded_output="42")] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1318,7 +1319,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): # Execute the orchestration until it is waiting for an external event. The result # should be an empty list of actions because the orchestration didn't schedule any work. - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 0 @@ -1327,7 +1328,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): # the orchestration should complete. 
old_events = new_events new_events = [helpers.new_event_raised_event("my_event", encoded_input="42")] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions complete_action = get_and_validate_complete_orchestration_action_list(1, actions) @@ -1352,7 +1353,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_event_raised_event("my_event", encoded_input="42")] # Execute the orchestration. It should be in a running state waiting for the timer to fire - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 1 @@ -1363,7 +1364,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): timer_due_time = datetime.utcnow() + timedelta(days=1) old_events = new_events + [helpers.new_timer_created_event(1, timer_due_time)] new_events = [helpers.new_timer_fired_event(1, timer_due_time)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions complete_action = get_and_validate_complete_orchestration_action_list(1, actions) @@ -1390,7 +1391,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): # Execute the orchestration. It should remain in a running state because it was suspended prior # to processing the event raised event. 
- executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 0 @@ -1398,7 +1399,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): # Resume the orchestration. It should complete successfully. old_events = old_events + new_events new_events = [helpers.new_resume_event()] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions complete_action = get_and_validate_complete_orchestration_action_list(1, actions) @@ -1424,7 +1425,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_event_raised_event("my_event", encoded_input="42")] # Execute the orchestration. It should be in a running state waiting for an external event - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions complete_action = get_and_validate_complete_orchestration_action_list(1, actions) @@ -1453,7 +1454,7 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): helpers.new_timer_fired_event(1, datetime.utcnow() + timedelta(days=1))] # Execute the orchestration. 
It should be in a running state waiting for the timer to fire - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions complete_action = get_and_validate_complete_orchestration_action_list(1, actions) @@ -1489,7 +1490,7 @@ def orchestrator(ctx: task.OrchestrationContext, count: int): helpers.new_orchestrator_started_event(), helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input="10")] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1531,13 +1532,13 @@ def orchestrator(ctx: task.OrchestrationContext, _): # First, test with only the first 5 events. We expect the orchestration to be running # but return zero actions since its still waiting for the other 5 tasks to complete. - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events[:5]) actions = result.actions assert len(actions) == 0 # Now test with the full set of new events. We expect the orchestration to complete. - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1579,7 +1580,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events.append(helpers.new_task_failed_event(6, ex)) # Now test with the full set of new events. We expect the orchestration to complete. 
- executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1611,7 +1612,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): # to return two actions: one to schedule the "Tokyo" task and one to schedule the "Seattle" task. old_events = [] new_events = [helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 2 @@ -1628,7 +1629,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): # Test 2: Complete the "Tokyo" task. We expect the orchestration to complete with output "Hello, Tokyo!" encoded_output = json.dumps(hello(None, "Tokyo")) new_events = [helpers.new_task_completed_event(1, encoded_output)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions complete_action = get_and_validate_complete_orchestration_action_list(1, actions) @@ -1638,7 +1639,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): # Test 3: Complete the "Seattle" task. We expect the orchestration to complete with output "Hello, Seattle!" 
encoded_output = json.dumps(hello(None, "Seattle")) new_events = [helpers.new_task_completed_event(2, encoded_output)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions complete_action = get_and_validate_complete_orchestration_action_list(1, actions) @@ -1684,7 +1685,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_orchestrator_started_event(datetime.now()), helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, [], new_events) assert result.actions # should have scheduled the activity @@ -1700,7 +1701,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): ] encoded_output = json.dumps(say_hello(None, "World")) new_events = [helpers.new_task_completed_event(1, encoded_output)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) complete_action = get_and_validate_complete_orchestration_action_list(1, result.actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED @@ -1743,7 +1744,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_orchestrator_started_event(datetime.now()), helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, [], new_events) complete_action = 
get_and_validate_complete_orchestration_action_list(1, result.actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED @@ -1877,7 +1878,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(timestamp=current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 1 @@ -1891,7 +1892,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(3, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 2 @@ -1904,7 +1905,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 3 @@ -1915,7 +1916,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): # Complete the "Seattle" task. We expect the orchestration to complete with output "Hello, Seattle!" 
encoded_output = json.dumps(dummy_activity(None, "Seattle")) new_events = [helpers.new_task_completed_event(2, encoded_output)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions complete_action = get_and_validate_complete_orchestration_action_list(3, actions) @@ -1959,7 +1960,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(timestamp=current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 1 @@ -1973,7 +1974,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(3, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 2 @@ -1986,7 +1987,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 3 @@ -2000,7 +2001,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): old_events = 
old_events + new_events new_events = [helpers.new_task_completed_event(2, encoded_output), helpers.new_timer_fired_event(4, expected_fire_at)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 3 @@ -2014,7 +2015,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions complete_action = get_and_validate_complete_orchestration_action_list(4, actions) @@ -2038,7 +2039,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): helpers.new_orchestrator_started_event(), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input), helpers.new_orchestrator_completed_event()] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result.actions @@ -2066,7 +2067,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_execution_started_event(name, TEST_INSTANCE_ID, None), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result1 = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result1.actions assert len(actions) == 1 diff --git a/tests/durabletask/test_serialization.py b/tests/durabletask/test_serialization.py index 33003c21..ae797c2f 100644 --- 
a/tests/durabletask/test_serialization.py +++ b/tests/durabletask/test_serialization.py @@ -1,16 +1,26 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""Unit tests for the JSON serialization codec in durabletask.internal.json_codec.""" +"""Unit tests for the private JSON codec in durabletask.serialization. + +These exercise the low-level encode/decode mechanism directly. The supported, +public surface is the ``DataConverter`` abstraction (see +``test_data_converter.py``). +""" import json from collections import namedtuple from dataclasses import dataclass from types import SimpleNamespace +from typing import List, Optional, Union, get_args import pytest -from durabletask.internal import json_codec +from durabletask.serialization import _AUTO_SERIALIZED as AUTO_SERIALIZED +from durabletask.serialization import _coerce_to_type as coerce_to_type +from durabletask.serialization import _from_json as from_json +from durabletask.serialization import _resolve_forward_refs as resolve_forward_refs +from durabletask.serialization import _to_json as to_json # ----- Test fixtures ----- @@ -78,58 +88,58 @@ def __eq__(self, other: object) -> bool: def test_to_json_static_hook_receives_instance(): # type(obj).to_json(obj) must invoke the @staticmethod with the instance. - assert json_codec.to_json(StaticWidget("gizmo")) == '"gizmo"' + assert to_json(StaticWidget("gizmo")) == '"gizmo"' def test_static_hook_round_trips_with_expected_type(): - encoded = json_codec.to_json(StaticWidget("gizmo")) - result = json_codec.from_json(encoded, StaticWidget) + encoded = to_json(StaticWidget("gizmo")) + result = from_json(encoded, StaticWidget) assert isinstance(result, StaticWidget) assert result == StaticWidget("gizmo") def test_instance_to_json_hook_receives_instance(): # The same type(obj).to_json(obj) path works for plain instance methods. 
- assert json.loads(json_codec.to_json(Widget("gear", 5))) == {"label": "gear", "size": 5} + assert json.loads(to_json(Widget("gear", 5))) == {"label": "gear", "size": 5} # ----- to_json ----- def test_to_json_builtins_are_plain_json(): - assert json_codec.to_json({"a": 1, "b": [1, 2, 3]}) == json.dumps({"a": 1, "b": [1, 2, 3]}) - assert json_codec.to_json("hello") == '"hello"' - assert json_codec.to_json(42) == "42" + assert to_json({"a": 1, "b": [1, 2, 3]}) == json.dumps({"a": 1, "b": [1, 2, 3]}) + assert to_json("hello") == '"hello"' + assert to_json(42) == "42" def test_to_json_dataclass_emits_plain_dict_without_marker(): - encoded = json_codec.to_json(Address("1 Main St", "Redmond")) + encoded = to_json(Address("1 Main St", "Redmond")) parsed = json.loads(encoded) assert parsed == {"street": "1 Main St", "city": "Redmond"} - assert json_codec.AUTO_SERIALIZED not in encoded + assert AUTO_SERIALIZED not in encoded def test_to_json_nested_dataclass_has_no_marker(): - encoded = json_codec.to_json(Person("Ada", 30, Address("1 Main St", "Redmond"))) - assert json_codec.AUTO_SERIALIZED not in encoded + encoded = to_json(Person("Ada", 30, Address("1 Main St", "Redmond"))) + assert AUTO_SERIALIZED not in encoded parsed = json.loads(encoded) assert parsed["address"] == {"street": "1 Main St", "city": "Redmond"} def test_to_json_simplenamespace_emits_plain_dict(): - encoded = json_codec.to_json(SimpleNamespace(a=1, b="two")) - assert json_codec.AUTO_SERIALIZED not in encoded + encoded = to_json(SimpleNamespace(a=1, b="two")) + assert AUTO_SERIALIZED not in encoded assert json.loads(encoded) == {"a": 1, "b": "two"} def test_to_json_custom_object_uses_to_json_hook(): - encoded = json_codec.to_json(Widget("gear", 5)) + encoded = to_json(Widget("gear", 5)) assert json.loads(encoded) == {"label": "gear", "size": 5} def test_to_json_namedtuple_serializes_as_array(): # Without an expected_type the field names are not preserved on the wire. 
- assert json_codec.to_json(Point(1, 2)) == "[1, 2]" + assert to_json(Point(1, 2)) == "[1, 2]" def test_to_json_unserializable_raises_typeerror_with_cause(): @@ -137,33 +147,123 @@ class NotSerializable: pass with pytest.raises(TypeError) as exc_info: - json_codec.to_json(NotSerializable()) + to_json(NotSerializable()) assert "NotSerializable" in str(exc_info.value) assert exc_info.value.__cause__ is not None +# ----- to_json: hook precedence and nested-hook recursion (PR #154 follow-up) ----- + + +@dataclass +class Temperature: + """Dataclass whose JSON shape differs from its field layout.""" + + celsius: float + + def to_json(self) -> dict: + return {"fahrenheit": self.celsius * 9 / 5 + 32} + + @classmethod + def from_json(cls, data: dict) -> "Temperature": + return cls((data["fahrenheit"] - 32) * 5 / 9) + + def __eq__(self, other: object) -> bool: + return isinstance(other, Temperature) and other.celsius == self.celsius + + +class Money: + """A field type that is not JSON-serializable on its own.""" + + def __init__(self, cents: int): + self.cents = cents + + def __eq__(self, other: object) -> bool: + return isinstance(other, Money) and other.cents == self.cents + + +@dataclass +class Invoice: + """Dataclass with a non-serializable field, rescued by a to_json hook.""" + + amount: Money + + def to_json(self) -> dict: + return {"amount_cents": self.amount.cents} + + @classmethod + def from_json(cls, data: dict) -> "Invoice": + return cls(Money(data["amount_cents"])) + + def __eq__(self, other: object) -> bool: + return isinstance(other, Invoice) and other.amount == self.amount + + +def test_dataclass_to_json_hook_takes_precedence_over_fields(): + # The dataclass defines to_json, so its output -- not the raw fields -- wins. 
+ assert json.loads(to_json(Temperature(100.0))) == {"fahrenheit": 212.0} + + +def test_dataclass_with_non_serializable_field_uses_to_json_hook(): + # Previously the dataclass branch ran first and asdict() hit the + # non-serializable Money field, raising even though to_json was defined. + assert json.loads(to_json(Invoice(Money(500)))) == {"amount_cents": 500} + + +def test_dataclass_with_non_serializable_field_round_trips(): + encoded = to_json(Invoice(Money(500))) + assert from_json(encoded, Invoice) == Invoice(Money(500)) + + +def test_nested_dataclass_hook_is_honored(): + # The nested Temperature must serialize through its own to_json, not be + # flattened to its raw fields by dataclasses.asdict. + @dataclass + class Reading: + temp: Temperature + + encoded = to_json(Reading(Temperature(100.0))) + assert json.loads(encoded) == {"temp": {"fahrenheit": 212.0}} + + +def test_nested_dataclass_hook_round_trips(): + @dataclass + class Reading: + temp: Temperature + + encoded = to_json(Reading(Temperature(37.0))) + result = from_json(encoded, Reading) + assert isinstance(result, Reading) + assert result.temp == Temperature(37.0) + + +def test_simplenamespace_without_hook_still_emits_fields(): + # SimpleNamespace has no to_json, so it falls through to vars(). 
+ assert json.loads(to_json(SimpleNamespace(x=1))) == {"x": 1} + + # ----- from_json without expected_type (loose mode) ----- def test_from_json_returns_raw_without_expected_type(): - assert json_codec.from_json('{"a": 1}') == {"a": 1} - assert json_codec.from_json("[1, 2, 3]") == [1, 2, 3] - assert json_codec.from_json("42") == 42 + assert from_json('{"a": 1}') == {"a": 1} + assert from_json("[1, 2, 3]") == [1, 2, 3] + assert from_json("42") == 42 def test_from_json_legacy_marker_decodes_to_simplenamespace(): - legacy = json.dumps({"a": 1, "b": 2, json_codec.AUTO_SERIALIZED: True}) - result = json_codec.from_json(legacy) + legacy = json.dumps({"a": 1, "b": 2, AUTO_SERIALIZED: True}) + result = from_json(legacy) assert isinstance(result, SimpleNamespace) assert result.a == 1 assert result.b == 2 def test_legacy_simplenamespace_reserializes_cleanly(): - legacy = json.dumps({"a": 1, json_codec.AUTO_SERIALIZED: True}) - ns = json_codec.from_json(legacy) - reencoded = json_codec.to_json(ns) - assert json_codec.AUTO_SERIALIZED not in reencoded + legacy = json.dumps({"a": 1, AUTO_SERIALIZED: True}) + ns = from_json(legacy) + reencoded = to_json(ns) + assert AUTO_SERIALIZED not in reencoded assert json.loads(reencoded) == {"a": 1} @@ -171,58 +271,112 @@ def test_legacy_simplenamespace_reserializes_cleanly(): def test_from_json_coerces_to_dataclass(): - encoded = json_codec.to_json(Address("1 Main St", "Redmond")) - result = json_codec.from_json(encoded, Address) + encoded = to_json(Address("1 Main St", "Redmond")) + result = from_json(encoded, Address) assert isinstance(result, Address) assert result == Address("1 Main St", "Redmond") def test_from_json_coerces_nested_dataclass(): - encoded = json_codec.to_json(Person("Ada", 30, Address("1 Main St", "Redmond"))) - result = json_codec.from_json(encoded, Person) + encoded = to_json(Person("Ada", 30, Address("1 Main St", "Redmond"))) + result = from_json(encoded, Person) assert isinstance(result, Person) assert 
isinstance(result.address, Address) assert result.address.city == "Redmond" def test_from_json_coerces_optional_dataclass_when_present(): - result = json_codec.from_json('{"name": "Ada", "age": 30, "address": null}', Person) + result = from_json('{"name": "Ada", "age": 30, "address": null}', Person) assert isinstance(result, Person) assert result.address is None def test_from_json_coerces_list_of_dataclasses(): - encoded = json_codec.to_json([Address("a", "b"), Address("c", "d")]) - result = json_codec.from_json(encoded, list[Address]) + encoded = to_json([Address("a", "b"), Address("c", "d")]) + result = from_json(encoded, list[Address]) assert all(isinstance(item, Address) for item in result) assert result[1] == Address("c", "d") def test_from_json_uses_from_json_hook(): - encoded = json_codec.to_json(Widget("gear", 5)) - result = json_codec.from_json(encoded, Widget) + encoded = to_json(Widget("gear", 5)) + result = from_json(encoded, Widget) assert isinstance(result, Widget) assert result == Widget("gear", 5) def test_from_json_primitive_passthrough_with_expected_type(): - assert json_codec.from_json("42", int) == 42 - assert json_codec.from_json('"hi"', str) == "hi" + assert from_json("42", int) == 42 + assert from_json('"hi"', str) == "hi" def test_from_json_legacy_marker_with_expected_type_strips_and_builds(): # A payload persisted by an older SDK version (with the marker) must still # decode when the caller now passes an expected_type. 
legacy = json.dumps( - {"street": "1 Main St", "city": "Redmond", json_codec.AUTO_SERIALIZED: True} + {"street": "1 Main St", "city": "Redmond", AUTO_SERIALIZED: True} ) - result = json_codec.from_json(legacy, Address) + result = from_json(legacy, Address) assert isinstance(result, Address) assert result == Address("1 Main St", "Redmond") def test_from_json_none_payload_with_expected_type(): - assert json_codec.from_json("null", Address) is None + assert from_json("null", Address) is None + + +# ----- from_json container recursion (PR #154 follow-up) ----- + + +def test_from_json_coerces_dict_of_dataclasses(): + encoded = to_json({"home": Address("1 Main St", "Redmond")}) + result = from_json(encoded, dict[str, Address]) + assert isinstance(result["home"], Address) + assert result["home"] == Address("1 Main St", "Redmond") + + +def test_from_json_coerces_dict_of_from_json_types(): + encoded = to_json({"a": Widget("gear", 5), "b": Widget("cog", 7)}) + result = from_json(encoded, dict[str, Widget]) + assert result["a"] == Widget("gear", 5) + assert result["b"] == Widget("cog", 7) + + +def test_from_json_dict_without_value_type_is_passthrough(): + encoded = to_json({"a": Address("1 Main St", "Redmond")}) + # Bare ``dict`` (no args) leaves values as parsed JSON. 
+ result = from_json(encoded, dict) + assert result == {"a": {"street": "1 Main St", "city": "Redmond"}} + + +def test_from_json_coerces_fixed_length_tuple(): + encoded = to_json([Address("a", "b"), 5]) + result = from_json(encoded, tuple[Address, int]) + assert isinstance(result, tuple) + assert result[0] == Address("a", "b") + assert result[1] == 5 + + +def test_from_json_coerces_homogeneous_tuple(): + encoded = to_json([Address("a", "b"), Address("c", "d")]) + result = from_json(encoded, tuple[Address, ...]) + assert isinstance(result, tuple) + assert all(isinstance(item, Address) for item in result) + assert result[1] == Address("c", "d") + + +def test_coerce_to_type_fixed_length_tuple_too_long_raises(): + # A JSON array longer than the fixed-length tuple type must fail fast rather + # than silently dropping the trailing element(s). + with pytest.raises(TypeError): + coerce_to_type([1, 2, 3], tuple[int, int]) + + +def test_coerce_to_type_fixed_length_tuple_too_short_raises(): + # A JSON array shorter than the fixed-length tuple type must fail fast rather + # than silently producing a short tuple. 
+ with pytest.raises(TypeError): + coerce_to_type([1], tuple[int, int]) # ----- coerce_to_type ----- @@ -230,28 +384,25 @@ def test_from_json_none_payload_with_expected_type(): def test_coerce_to_type_none_type_returns_value(): value = {"a": 1} - assert json_codec.coerce_to_type(value, None) is value + assert coerce_to_type(value, None) is value def test_coerce_to_type_already_correct_type(): addr = Address("a", "b") - assert json_codec.coerce_to_type(addr, Address) is addr + assert coerce_to_type(addr, Address) is addr def test_coerce_to_type_invalid_conversion_raises(): with pytest.raises(TypeError): - json_codec.coerce_to_type("not-a-number", int) + coerce_to_type("not-a-number", int) def test_coerce_optional_dataclass_coerces_member(): - from typing import Optional - result = json_codec.coerce_to_type({"street": "a", "city": "b"}, Optional[Address]) + result = coerce_to_type({"street": "a", "city": "b"}, Optional[Address]) assert isinstance(result, Address) def test_coerce_genuine_union_leaves_unmatched_value_untouched(): - from typing import Union - @dataclass class A: x: int @@ -263,4 +414,133 @@ class B: # A dict matching neither A nor B by isinstance must be returned unchanged, # not force-coerced into the first union member. 
value = {"z": 1} - assert json_codec.coerce_to_type(value, Union[A, B]) == {"z": 1} + assert coerce_to_type(value, Union[A, B]) == {"z": 1} + + +def test_coerce_dict_values_recursively(): + result = coerce_to_type({"home": {"street": "a", "city": "b"}}, dict[str, Address]) + assert isinstance(result["home"], Address) + + +# ----- from_json converter hook (PR #154 follow-up) ----- + + +class ConverterAwareWidget: + """A type whose from_json hook reconstructs a nested typed value using the + converter passed to it by the DataConverter.""" + + def __init__(self, address: "Address"): + self.address = address + + def to_json(self) -> dict: + return {"address": self.address} + + @classmethod + def from_json(cls, data: dict, converter) -> "ConverterAwareWidget": + nested = converter.coerce(data["address"], Address) + return cls(nested) + + +def test_from_json_hook_receives_converter_for_nested_reconstruction(): + from durabletask.serialization import JsonDataConverter + + conv = JsonDataConverter() + encoded = conv.serialize(ConverterAwareWidget(Address("1 Main St", "Redmond"))) + result = conv.deserialize(encoded, ConverterAwareWidget) + assert isinstance(result, ConverterAwareWidget) + # The nested value was rebuilt into an Address by the hook via the converter. + assert isinstance(result.address, Address) + assert result.address.city == "Redmond" + + +def test_single_arg_from_json_hook_is_not_passed_converter(): + # A classic single-argument from_json must keep working unchanged. + from durabletask.serialization import JsonDataConverter + + conv = JsonDataConverter() + encoded = conv.serialize(Widget("gear", 5)) + result = conv.deserialize(encoded, Widget) + assert result == Widget("gear", 5) + + +def test_converter_threaded_through_nested_containers(): + # A from_json hook nested inside a dataclass field's list still receives the + # converter so it can recurse. 
+ from durabletask.serialization import JsonDataConverter + + @dataclass + class Catalog: + widgets: list[ConverterAwareWidget] + + conv = JsonDataConverter() + encoded = conv.serialize( + Catalog([ConverterAwareWidget(Address("a", "b"))]) + ) + result = conv.deserialize(encoded, Catalog) + assert isinstance(result.widgets[0], ConverterAwareWidget) + assert isinstance(result.widgets[0].address, Address) + + +def test_coerce_to_type_without_converter_calls_single_arg_hook(): + # Calling the private helper without a converter must not pass one, even if + # the hook could accept it (defensive: no converter available). + result = coerce_to_type({"label": "gear", "size": 5}, Widget) + assert result == Widget("gear", 5) + + +def test_from_json_hook_unrelated_second_param_is_not_treated_as_converter(): + # A ``from_json`` whose second parameter is *not* named ``converter`` (here a + # ``strict`` flag with a default) must not be mistaken for a converter-aware + # hook -- otherwise the DataConverter would be bound to ``strict``. + from durabletask.serialization import JsonDataConverter + + @dataclass + class Flagged: + label: str + + @classmethod + def from_json(cls, value, strict=False): + # If the converter were passed here, ``strict`` would be truthy. + assert strict is False + return cls(value["label"]) + + conv = JsonDataConverter() + result = conv.deserialize('{"label": "ok"}', Flagged) + assert result == Flagged("ok") + + +# ----- forward-reference resolution (Python 3.10 get_type_hints parity) ----- +# +# On Python 3.10, ``typing.get_type_hints`` does not deep-resolve forward +# references nested inside container args (e.g. the element type of +# ``list["TreeNode"]`` on a self-referential dataclass), leaving a bare string +# or ``ForwardRef``. ``_resolve_forward_refs`` restores the 3.11+ behavior so +# nested coercion still fires. These tests run on every supported version. 
+ + +def test_resolve_forward_refs_resolves_string_element_type(): + # ``list["Address"]`` evaluates to a generic whose arg is the raw string + # "Address" -- exactly what 3.10 leaves behind. + resolved = resolve_forward_refs(list["Address"], {"Address": Address}) + assert get_args(resolved) == (Address,) + + +def test_resolve_forward_refs_resolves_forwardref_element_type(): + resolved = resolve_forward_refs(List["Address"], {"Address": Address}) + assert get_args(resolved) == (Address,) + + +def test_resolve_forward_refs_then_coerce_reconstructs_nested_dataclass(): + # The full 3.10 path: resolve the unresolved element type, then coerce the + # contained dicts into the target dataclass. + field_type = resolve_forward_refs(list["Address"], {"Address": Address}) + result = coerce_to_type([{"street": "a", "city": "b"}], field_type) + assert isinstance(result[0], Address) + assert result[0].city == "b" + + +def test_resolve_forward_refs_leaves_unresolvable_name_untouched(): + # An unknown name is left as a string so coercion harmlessly falls back to + # the parsed JSON rather than raising. + resolved = resolve_forward_refs(list["DoesNotExist"], {}) + assert get_args(resolved) == ("DoesNotExist",) diff --git a/tests/durabletask/test_shared.py b/tests/durabletask/test_shared.py new file mode 100644 index 00000000..7cf12413 --- /dev/null +++ b/tests/durabletask/test_shared.py @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for the deprecated serialization shims in durabletask.internal.shared. + +The JSON codec now lives in ``durabletask.serialization`` and its functions are +private; the supported surface is the pluggable ``DataConverter``. The thin +``shared.to_json`` / ``shared.from_json`` wrappers remain for backwards +compatibility but emit a ``DeprecationWarning``. 
+""" + +import json +import warnings + +import pytest + +from durabletask.internal import shared + + +def test_shared_to_json_warns_and_serializes(): + with pytest.warns(DeprecationWarning, match="to_json"): + result = shared.to_json({"a": 1}) + assert json.loads(result) == {"a": 1} + + +def test_shared_from_json_warns_and_deserializes(): + with pytest.warns(DeprecationWarning, match="from_json"): + result = shared.from_json('{"a": 1}') + assert result == {"a": 1} + + +def test_shared_from_json_still_accepts_expected_type(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + assert shared.from_json("42", int) == 42 + + +def test_shared_auto_serialized_marker_still_exported(): + # The legacy marker constant remains importable for back-compat. + assert shared.AUTO_SERIALIZED == "__durabletask_autoobject__" diff --git a/tests/durabletask/test_tracing.py b/tests/durabletask/test_tracing.py index 9969afab..81159191 100644 --- a/tests/durabletask/test_tracing.py +++ b/tests/durabletask/test_tracing.py @@ -17,6 +17,7 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import StatusCode +from durabletask.serialization import JsonDataConverter import durabletask.internal.helpers as helpers import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.tracing as tracing @@ -263,10 +264,10 @@ def simple_orchestrator(ctx: task.OrchestrationContext, _): registry = worker._Registry() registry.add_orchestrator(simple_orchestrator) - ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry) + ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry, JsonDataConverter()) assert ctx._parent_trace_context is None - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) # Create an executionStarted event with parentTraceContext event = 
pb.HistoryEvent( @@ -290,8 +291,8 @@ def simple_orchestrator(ctx: task.OrchestrationContext, _): registry = worker._Registry() registry.add_orchestrator(simple_orchestrator) - ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry) - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry, JsonDataConverter()) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) event = pb.HistoryEvent( eventId=-1, @@ -320,10 +321,10 @@ def simple_orchestrator(ctx: task.OrchestrationContext, _): registry = worker._Registry() registry.add_orchestrator(simple_orchestrator) - ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry) + ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry, JsonDataConverter()) assert ctx._orchestration_trace_context is None - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) event = pb.HistoryEvent( eventId=-1, executionStarted=pb.ExecutionStartedEvent( @@ -359,9 +360,10 @@ def simple_orchestrator(ctx: task.OrchestrationContext, _): registry = worker._Registry() registry.add_orchestrator(simple_orchestrator) - ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry) + ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry, JsonDataConverter()) executor = worker._OrchestrationExecutor( registry, TEST_LOGGER, + data_converter=JsonDataConverter(), persisted_orch_span_id=persisted_span_id, ) @@ -867,7 +869,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_task_completed_event(2, json.dumps(20)), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) client_spans = self._get_client_spans(otel_setup) @@ 
-904,7 +906,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_task_completed_event(2, json.dumps("ok")), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) client_spans = self._get_client_spans(otel_setup) @@ -942,7 +944,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_timer_fired_event(2, fire_at_2), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) client_spans = self._get_client_spans(otel_setup) @@ -976,7 +978,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_sub_orchestration_completed_event(2, encoded_output=json.dumps("r2")), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) client_spans = self._get_client_spans(otel_setup) @@ -1013,7 +1015,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_sub_orchestration_completed_event(2, encoded_output=json.dumps("ok")), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) client_spans = self._get_client_spans(otel_setup) @@ -1732,7 +1734,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): self._make_traced_task_completed_event(1, json.dumps("result"), timestamp_seconds=200), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) 
client_spans = [ @@ -1783,7 +1785,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): self._make_traced_task_failed_event(1, "boom", timestamp_seconds=250), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) client_spans = [ @@ -1848,7 +1850,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): ), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) client_spans = [ @@ -1895,7 +1897,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_task_completed_event(2, json.dumps(20)), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) client_spans = [ diff --git a/tests/durabletask/test_type_discovery.py b/tests/durabletask/test_type_discovery.py index 0f7a6ec6..065e964d 100644 --- a/tests/durabletask/test_type_discovery.py +++ b/tests/durabletask/test_type_discovery.py @@ -11,6 +11,7 @@ from durabletask import entities, task, worker from durabletask.internal import type_discovery from durabletask.internal.entity_state_shim import StateShim +from durabletask.serialization import JsonDataConverter TEST_LOGGER = logging.getLogger("tests") @@ -36,29 +37,78 @@ def from_json(cls, data: dict[str, Any]) -> "Money": return cls(data["amount"]) -# ----- type_discovery helper ----- +# ----- DataConverter.can_reconstruct ----- class TestIsReconstructable: + """Reconstructability is now a DataConverter responsibility; the default + JsonDataConverter recognizes dataclasses and from_json-capable types.""" + + @property + def conv(self) -> JsonDataConverter: + return JsonDataConverter() 
+ def test_dataclass_is_reconstructable(self): - assert type_discovery.is_reconstructable(Order) is True + assert self.conv.can_reconstruct(Order) is True def test_from_json_type_is_reconstructable(self): - assert type_discovery.is_reconstructable(Money) is True + assert self.conv.can_reconstruct(Money) is True def test_builtins_are_not_reconstructable(self): - assert type_discovery.is_reconstructable(int) is False - assert type_discovery.is_reconstructable(str) is False - assert type_discovery.is_reconstructable(dict) is False + assert self.conv.can_reconstruct(int) is False + assert self.conv.can_reconstruct(str) is False + assert self.conv.can_reconstruct(dict) is False def test_optional_dataclass_is_reconstructable(self): - assert type_discovery.is_reconstructable(Optional[Order]) is True + assert self.conv.can_reconstruct(Optional[Order]) is True def test_list_of_dataclass_is_reconstructable(self): - assert type_discovery.is_reconstructable(list[Order]) is True + assert self.conv.can_reconstruct(list[Order]) is True def test_list_of_builtin_is_not_reconstructable(self): - assert type_discovery.is_reconstructable(list[int]) is False + assert self.conv.can_reconstruct(list[int]) is False + + +class TestCustomConverterReconstructable: + """A custom converter can extend reconstructability to its own types, and the + default's Optional/list recursion consults the override for element types.""" + + def test_override_recognizes_custom_type_and_recurses(self): + class Widget: + pass + + class WidgetConverter(JsonDataConverter): + def can_reconstruct(self, target_type: Any) -> bool: + if isinstance(target_type, type) and issubclass(target_type, Widget): + return True + return super().can_reconstruct(target_type) + + conv = WidgetConverter() + assert conv.can_reconstruct(Widget) is True + # The base Optional/list recursion goes through ``self``, so the override + # is consulted for the element type. 
+ assert conv.can_reconstruct(list[Widget]) is True + assert conv.can_reconstruct(Optional[Widget]) is True + # Builtins remain excluded. + assert conv.can_reconstruct(int) is False + + def test_discovery_uses_supplied_converter(self): + class Widget: + pass + + class WidgetConverter(JsonDataConverter): + def can_reconstruct(self, target_type: Any) -> bool: + if isinstance(target_type, type) and issubclass(target_type, Widget): + return True + return super().can_reconstruct(target_type) + + def act(ctx, w: Widget): + ... + + # The default converter does not recognize Widget... + assert type_discovery.activity_input_type(act) is None + # ...but the custom converter does, so discovery surfaces the type. + assert type_discovery.activity_input_type(act, WidgetConverter()) is Widget class TestInputTypeDiscovery: @@ -155,7 +205,7 @@ def test_string_name_returns_none(self): def _activity_executor(fn) -> tuple[worker._ActivityExecutor, str]: registry = worker._Registry() name = registry.add_activity(fn) - return worker._ActivityExecutor(registry, TEST_LOGGER), name + return worker._ActivityExecutor(registry, TEST_LOGGER, JsonDataConverter()), name def test_activity_input_coerced_to_dataclass(): @@ -227,9 +277,9 @@ def store(ctx: entities.EntityContext, order: Order): registry = worker._Registry() registry.add_entity(store) - executor = worker._EntityExecutor(registry, TEST_LOGGER) + executor = worker._EntityExecutor(registry, TEST_LOGGER, JsonDataConverter()) entity_id = entities.EntityInstanceId("store", "k1") - state = StateShim(None) + state = StateShim(None, JsonDataConverter()) executor.execute("orch1", entity_id, "save", state, json.dumps({"item": "book", "quantity": 2})) assert seen["type"] == "Order" assert seen["item"] == "book" @@ -245,9 +295,9 @@ def save(self, order: Order): registry = worker._Registry() registry.add_entity(Store) - executor = worker._EntityExecutor(registry, TEST_LOGGER) + executor = worker._EntityExecutor(registry, TEST_LOGGER, 
JsonDataConverter()) entity_id = entities.EntityInstanceId("store", "k1") - state = StateShim(None) + state = StateShim(None, JsonDataConverter()) executor.execute("orch1", entity_id, "save", state, json.dumps({"item": "book", "quantity": 2})) assert seen["type"] == "Order" assert seen["item"] == "book"
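Review note on `TestCustomConverterReconstructable` above: the tests rely on the base `can_reconstruct` recursing through `self` for `Optional`/`list` element types, so that a subclass override is consulted. The sketch below shows one way that dispatch could work. It is a minimal, self-contained illustration and not the durabletask implementation: `SketchConverter`, `WidgetConverter`, and the local `Order` and `Widget` classes are hypothetical names, and the real `JsonDataConverter` may handle more hints (for example `dict` or `tuple`) than this sketch does.

```python
import dataclasses
from typing import Any, Optional, Union, get_args, get_origin


class SketchConverter:
    """Hypothetical stand-in for illustration; not durabletask's JsonDataConverter."""

    def can_reconstruct(self, target_type: Any) -> bool:
        origin = get_origin(target_type)
        if origin is Union or origin is list:
            # Recurse through ``self`` (not a module-level helper) so that a
            # subclass override is consulted for each element/member type.
            return any(
                self.can_reconstruct(arg)
                for arg in get_args(target_type)
                if arg is not type(None)
            )
        if origin is None and isinstance(target_type, type):
            # Plain classes qualify if they are dataclasses or expose from_json.
            return dataclasses.is_dataclass(target_type) or callable(
                getattr(target_type, "from_json", None)
            )
        return False


@dataclasses.dataclass
class Order:
    item: str
    quantity: int


class Widget:
    """A plain class that the base sketch does not recognize."""


class WidgetConverter(SketchConverter):
    def can_reconstruct(self, target_type: Any) -> bool:
        # Only test real classes; parameterized generics fall through to the base.
        if isinstance(target_type, type) and get_origin(target_type) is None:
            if issubclass(target_type, Widget):
                return True
        return super().can_reconstruct(target_type)
```

With this shape, `WidgetConverter().can_reconstruct(list[Widget])` reaches the override for the element type, while the base sketch alone rejects `Widget`. If the recursion called a free function instead of `self.can_reconstruct`, the override would be bypassed for nested types, which is the regression the test class appears designed to catch.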