From 2df2b209728503f53acf073d7163102e0efe0f23 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Jun 2026 14:15:41 -0600 Subject: [PATCH 1/9] Fix custom serialization gaps from #154 Close several round-tripping gaps left by the type-aware custom serialization work in #154, without introducing new breaking changes versus 1.5.0 or any serialization-related security concerns. Serialize side: - Prefer a to_json() hook over the built-in dataclass / SimpleNamespace handling so a dataclass (or namespace) with a non-serializable field can opt in, mirroring the decode side which already prefers from_json(). - Encode dataclasses via a shallow field mapping instead of dataclasses.asdict(), so nested to_json() hooks are honored and leaf values are not deep-copied. - Serialize enum.Enum values to their underlying .value so non-int enums round-trip (IntEnum already serialized as integers). Deserialize side: - Recurse type-directed reconstruction into dict/Mapping values and tuple elements, in addition to the existing list / Optional / Union / dataclass recursion. - Optionally pass the active DataConverter to a from_json(cls, value, converter) hook so it can rebuild nested typed values the built-in recursion does not cover. Entity state: - Defer deserialization of an entity's wire state until get_state() is called, so the caller's requested type reaches the converter together with the raw payload. Track whether the held value is still the raw serialized string and pass it back through unchanged on persist to avoid double-encoding. - Replace a redundant serialize/deserialize round-trip in the legacy entity event path with converter.coerce(). Module structure / deprecation: - Merge the internal json_codec module into durabletask.serialization and make the codec functions private; the supported surface is the pluggable DataConverter. - Deprecate durabletask.internal.shared.to_json / from_json with a DeprecationWarning; they continue to work for backwards compatibility. Adds a comprehensive JsonDataConverter round-trip test suite plus targeted tests for each fix, and documents intentional limitations (multi-member Union, types needing a custom converter such as datetime/Decimal/set). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.md | 41 +- durabletask/internal/entity_state_shim.py | 68 +- durabletask/internal/json_codec.py | 183 ------ durabletask/internal/shared.py | 54 +- durabletask/serialization.py | 329 +++++++++- durabletask/worker.py | 17 +- .../test_data_converter_roundtrip.py | 584 ++++++++++++++++++ tests/durabletask/test_entity_executor.py | 100 +++ tests/durabletask/test_serialization.py | 305 +++++++-- tests/durabletask/test_shared.py | 40 ++ 10 files changed, 1446 insertions(+), 275 deletions(-) delete mode 100644 durabletask/internal/json_codec.py create mode 100644 tests/durabletask/test_data_converter_roundtrip.py create mode 100644 tests/durabletask/test_shared.py diff --git a/CHANGELOG.md b/CHANGELOG.md index a962b3a..a5aa112 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,10 +20,10 @@ ADDED `call_entity` accept an optional `return_type`, and `wait_for_external_event` accepts an optional `data_type`. When provided, the result/event payload is reconstructed as that type (dataclasses — including nested dataclass, - `Optional`, and `list` fields — and `from_json()`-capable types) and the - returned task is typed accordingly (e.g. `call_activity(..., return_type=Foo)` - yields `CompletableTask[Foo]`). When omitted, the raw deserialized JSON is - returned as before. + `Optional`, `list`, `dict`/`Mapping`, and `tuple` fields — and + `from_json()`-capable types) and the returned task is typed accordingly (e.g. + `call_activity(..., return_type=Foo)` yields `CompletableTask[Foo]`). When + omitted, the raw deserialized JSON is returned as before. - Inbound payloads are reconstructed from function type annotations. When an orchestrator, activity, or entity operation annotates its input parameter (or an activity its return value) with a dataclass or `from_json()`-capable type, @@ -37,6 +37,15 @@ ADDED retained. - Objects exposing a `to_json()` method are now JSON-serializable when passed as activity/orchestrator inputs or outputs. +- `enum.Enum` values now serialize (to their underlying `.value`) and, when a + target type is supplied, deserialize back to the enum member. This covers + string-valued and other non-`int` enums as activity/orchestrator/entity inputs + and outputs, including as dataclass fields and inside `list` / `dict` / + `tuple` containers. (`IntEnum` / `IntFlag` already serialized as integers.) +- A `from_json()` classmethod may now optionally accept the active + `DataConverter` as a second parameter (`from_json(cls, value, converter)`), + letting it reconstruct nested typed values via `converter.coerce(...)` / + `converter.deserialize(...)`. The single-argument form remains supported. - Added `EntityMetadata.get_typed_state(intended_type=...)`, which deserializes the entity's persisted state and reconstructs dataclasses and `from_json()`-capable types. The existing `get_state()` is unchanged: with no @@ -63,11 +72,35 @@ CHANGED FIXED +- A dataclass or `SimpleNamespace` that defines a `to_json()` hook now uses it + when serialized. Previously the built-in dataclass / `SimpleNamespace` + handling ran first, so the hook was ignored — and a dataclass with a field + that was not JSON-serializable on its own would fail to serialize even when it + provided a `to_json()` hook to handle that field. The serialize side now + prefers `to_json()`, mirroring the deserialize side, which already prefers + `from_json()`. +- Nested `to_json()` hooks are now honored when an object is serialized inside a + dataclass. Custom objects (including nested dataclasses with their own + `to_json()`) are now encoded recursively instead of being flattened to their + raw fields, so values that reshape themselves via `to_json()` round-trip + correctly. +- Type-directed deserialization now recurses into `dict`/`Mapping` values and + `tuple` elements, in addition to the existing `list`, `Optional`/`Union`, and + dataclass-field recursion. A `dict[str, Foo]` or `tuple[Foo, ...]` hint now + reconstructs the contained `Foo` values. - Falsy entity states (`0`, `""`, `[]`, `{}`) are no longer dropped when an entity batch is persisted. Previously a falsy current state was treated as "no state" and written as `None`, effectively deleting it; only an actual `None` state now clears the persisted entity state. +DEPRECATED + +- `durabletask.internal.shared.to_json` and `durabletask.internal.shared.from_json` + are deprecated and now emit a `DeprecationWarning`. Use a + `durabletask.serialization.DataConverter` (for example the default + `JsonDataConverter`) instead. The functions continue to work for backwards + compatibility. + BREAKING CHANGES (type-level only — no runtime impact for typical users) These changes do not alter runtime behavior, but because the package ships diff --git a/durabletask/internal/entity_state_shim.py b/durabletask/internal/entity_state_shim.py index 99f2801..eeab858 100644 --- a/durabletask/internal/entity_state_shim.py +++ b/durabletask/internal/entity_state_shim.py @@ -12,9 +12,31 @@ class StateShim: - def __init__(self, start_state: Any, data_converter: "DataConverter | None" = None): + """In-memory view of an entity's state during a batch. + + The state arriving from the wire is held as its raw serialized JSON string + and is **not** deserialized in the constructor: deserialization is deferred + until :meth:`get_state` is called, so the caller's requested type reaches the + data converter together with the original payload (a custom converter can + then deserialize the string directly into the target type). Once the state + has been read into a Python value or replaced via :meth:`set_state`, it is + held as that live object instead. + + Tracking whether the current value is still the raw serialized string also + lets :meth:`encode_state` pass an unmodified payload straight back to the + wire instead of re-serializing it, which would double-encode the JSON. + """ + + def __init__(self, start_state: Any, data_converter: "DataConverter | None" = None, + *, is_serialized: bool = False): + # ``is_serialized`` marks ``start_state`` as a raw serialized payload + # (the value off the wire) whose deserialization should be deferred. A + # ``None`` state is never treated as serialized. + serialized = is_serialized and start_state is not None self._current_state: Any = start_state + self._current_is_serialized: bool = serialized self._checkpoint_state: Any = start_state + self._checkpoint_is_serialized: bool = serialized self._operation_actions: list[pb.OperationAction] = [] self._actions_checkpoint_state: int = 0 if data_converter is None: @@ -35,13 +57,19 @@ def get_state(self, intended_type: None = None, default: Any = None) -> Any: ... def get_state(self, intended_type: type[TState] | None = None, default: TState | None = None) -> TState | Any | None: - if self._current_state is None and default is not None: + if self._current_state is None: return default - if intended_type is None: - return self._current_state - - coerced = self._data_converter.coerce(self._current_state, intended_type) + if self._current_is_serialized: + # Deferred deserialization: the converter receives the raw payload + # together with the requested type. + if intended_type is None: + return self._data_converter.deserialize(self._current_state) + result = self._data_converter.deserialize(self._current_state, intended_type) + else: + if intended_type is None: + return self._current_state + result = self._data_converter.coerce(self._current_state, intended_type) # An explicit ``intended_type`` is a request to receive that type. The # default converter is best-effort and would silently return the raw @@ -49,17 +77,33 @@ def get_state(self, intended_type: type[TState] | None = None, default: TState | # raising when a non-None state could not be coerced to a concrete type. # ``intended_type`` may be a typing generic (e.g. ``list[int]``) at # runtime, which is not a ``type`` instance, so the guard is required. - if (self._current_state is not None - and isinstance(intended_type, type) # pyright: ignore[reportUnnecessaryIsInstance] - and not isinstance(coerced, intended_type)): + if (isinstance(intended_type, type) # pyright: ignore[reportUnnecessaryIsInstance] + and not isinstance(result, intended_type)): raise TypeError( f"Could not convert state of type '{type(self._current_state).__name__}' to '{intended_type.__name__}'" ) - return coerced + return result def set_state(self, state: Any) -> None: + # A value set in-process is a live Python object, not a serialized payload. self._current_state = state + self._current_is_serialized = False + + def encode_state(self) -> str | None: + """Serialize the current state for persistence back to the wire. + + Returns ``None`` only when the state is actually ``None`` (which clears + the persisted entity state). When the current value is still the raw + serialized payload (the state was never modified), it is returned + unchanged to avoid double-encoding; otherwise the live value is + serialized. + """ + if self._current_state is None: + return None + if self._current_is_serialized: + return self._current_state + return self._data_converter.serialize(self._current_state) def add_operation_action(self, action: pb.OperationAction) -> None: self._operation_actions.append(action) @@ -69,14 +113,18 @@ def get_operation_actions(self) -> list[pb.OperationAction]: def commit(self) -> None: self._checkpoint_state = self._current_state + self._checkpoint_is_serialized = self._current_is_serialized self._actions_checkpoint_state = len(self._operation_actions) def rollback(self) -> None: self._current_state = self._checkpoint_state + self._current_is_serialized = self._checkpoint_is_serialized self._operation_actions = self._operation_actions[:self._actions_checkpoint_state] def reset(self) -> None: self._current_state = None + self._current_is_serialized = False self._checkpoint_state = None + self._checkpoint_is_serialized = False self._operation_actions = [] self._actions_checkpoint_state = 0 diff --git a/durabletask/internal/json_codec.py b/durabletask/internal/json_codec.py deleted file mode 100644 index 8fda0ea..0000000 --- a/durabletask/internal/json_codec.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -"""Internal JSON codec for Durable Task payloads. - -This module holds the low-level serialization *mechanism* -- the JSON string -encode/decode primitives and the value-level type coercion used to reconstruct -custom objects. Serialization *policy* (the public, pluggable strategy) lives in -:mod:`durabletask.serialization`; the default ``JsonDataConverter`` is the only -production consumer of ``to_json`` / ``from_json``, while ``coerce_to_type`` is -also used directly by entity state accessors that already hold a parsed value. -""" - -from __future__ import annotations - -import dataclasses -import json -import types -import typing -from collections.abc import Sequence -from types import SimpleNamespace -from typing import Any, cast - -# Marker formerly added to JSON payloads to flag objects for automatic -# deserialization into a SimpleNamespace. New code no longer emits this marker -# (objects are serialized as plain JSON), but the decoder still recognizes it so -# that orchestration histories produced by older SDK versions continue to replay. -AUTO_SERIALIZED = "__durabletask_autoobject__" - - -def to_json(obj: Any) -> str: - """Serialize a value to a JSON string. - - Builtins serialize to plain JSON. Dataclasses, ``SimpleNamespace`` - instances, and objects exposing a ``to_json()`` method are serialized to - plain JSON as well (without any type marker); custom objects can be - reconstructed on the receiving side by passing ``expected_type`` to - :func:`from_json`. - """ - try: - return json.dumps(obj, default=_encode_custom_object) - except TypeError as e: - # Preserve the original error as the cause so serialization failures are - # easier to diagnose, while naming the offending top-level type. - raise TypeError( - f"Failed to serialize object of type '{type(obj).__name__}' to JSON: {e}" - ) from e - - -def from_json(json_str: str | bytes | bytearray, expected_type: type | None = None) -> Any: - """Deserialize a JSON string, optionally coercing the result to a type. - - When ``expected_type`` is ``None`` (the default) the raw parsed JSON is - returned. For backwards compatibility, payloads carrying the legacy - :data:`AUTO_SERIALIZED` marker are reconstructed as ``SimpleNamespace`` - instances so that in-flight orchestrations produced by older SDK versions - continue to replay. - - When ``expected_type`` is provided, the legacy marker (if present) is - stripped and the parsed value is coerced to ``expected_type`` -- dataclasses - are constructed from their dict payloads, types exposing a ``from_json()`` - classmethod are reconstructed via that hook, and ``Optional``/``Union`` and - ``list`` type hints are honored recursively. The destination type is always - supplied by the caller; it is never read from the payload. - """ - if expected_type is None: - return json.loads(json_str, object_hook=_legacy_object_hook) - raw = json.loads(json_str, object_hook=_strip_legacy_marker) - return coerce_to_type(raw, expected_type) - - -def _encode_custom_object(o: Any) -> Any: - """``default`` hook for :func:`json.dumps` that emits plain JSON. - - Called only for values the JSON encoder cannot natively serialize. Note that - namedtuples are handled natively by the encoder (serialized as JSON arrays) - and never reach this hook. - """ - if dataclasses.is_dataclass(o) and not isinstance(o, type): - return dataclasses.asdict(o) - if isinstance(o, SimpleNamespace): - return vars(o) - # Custom objects may opt in via a ``to_json`` hook. It is resolved off the - # type and called with the instance (``type(o).to_json(o)``) so that both - # instance methods and ``@staticmethod`` hooks work -- matching the calling - # convention used by ``azure-functions-durable``. The hook returns a - # JSON-serializable value (a structure or a string), not a JSON document. - to_json_hook = getattr(cast(Any, type(o)), "to_json", None) - if callable(to_json_hook): - return to_json_hook(o) - # This will raise a TypeError describing the unsupported type. - raise TypeError(f"Object of type '{type(o).__name__}' is not JSON serializable") - - -def _legacy_object_hook(d: dict[str, Any]) -> Any: - # If the object carries the legacy marker, deserialize it as a SimpleNamespace. - if d.pop(AUTO_SERIALIZED, False): - return SimpleNamespace(**d) - return d - - -def _strip_legacy_marker(d: dict[str, Any]) -> dict[str, Any]: - # Discard the legacy marker so typed coercion sees a plain dict. - d.pop(AUTO_SERIALIZED, None) - return d - - -def coerce_to_type(value: Any, expected_type: Any) -> Any: - """Coerce an already-parsed JSON value to ``expected_type``. - - Handles ``None``/``Optional``/``Union`` and ``list`` type hints recursively, - types exposing a ``from_json()`` classmethod, and dataclasses (including - nested dataclass fields). The destination type is always caller-supplied and - never derived from the payload, keeping deserialization secure. - """ - if expected_type is None or value is None: - return value - - origin = typing.get_origin(expected_type) - if origin is not None: - return _coerce_generic(value, expected_type, origin) - - if not isinstance(expected_type, type): - # Not a concrete, instantiable type (e.g. a typing special form we don't - # special-case) -- return the value unchanged. - return value - - if isinstance(value, expected_type): - return value - - from_json_hook = getattr(expected_type, "from_json", None) - if callable(from_json_hook): - return from_json_hook(value) - - if dataclasses.is_dataclass(expected_type) and isinstance(value, dict): - return _build_dataclass(expected_type, cast(dict[str, Any], value)) - - type_ctor = cast(Any, expected_type) - try: - return type_ctor(value) - except Exception as e: - type_name = getattr(type_ctor, "__name__", None) or str(type_ctor) - raise TypeError( - f"Could not coerce value of type '{type(value).__name__}' to " - f"'{type_name}'" - ) from e - - -def _coerce_generic(value: Any, expected_type: Any, origin: Any) -> Any: - args = typing.get_args(expected_type) - if origin is typing.Union or origin is types.UnionType: - # If the value already matches a member type, keep it as-is. - non_none = [a for a in args if a is not type(None)] - for arg in non_none: - if isinstance(arg, type) and isinstance(value, arg): - return value - # ``Optional[T]`` (exactly one non-None member): coerce to that member. - # For a genuine multi-member ``Union`` where the value matched none of - # the members, leave it untouched rather than guessing the first arg -- - # forcing a coercion there can silently mis-construct the wrong type. - if len(non_none) == 1: - return coerce_to_type(value, non_none[0]) - return value - if origin in (list, Sequence) and isinstance(value, list): - elem_type = args[0] if args else None - return [coerce_to_type(item, elem_type) for item in cast(list[Any], value)] - # Other generics (dict, tuple, ...) are returned as parsed JSON. - return value - - -def _build_dataclass(cls: Any, data: dict[str, Any]) -> Any: - """Construct a dataclass from its dict payload, recursing into typed fields.""" - try: - hints = typing.get_type_hints(cls) - except Exception: - hints = {} - kwargs: dict[str, Any] = {} - for field in dataclasses.fields(cls): - if field.name not in data: - continue - field_type = hints.get(field.name) - kwargs[field.name] = coerce_to_type(data[field.name], field_type) - return cls(**kwargs) diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index 9ef136d..3708e49 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -2,22 +2,58 @@ # Licensed under the MIT License. import logging +import warnings from collections.abc import Sequence -from typing import TypeAlias +from typing import Any, TypeAlias import grpc import grpc.aio -# Backwards-compatibility re-exports. The JSON codec moved to -# ``durabletask.internal.json_codec``; these aliases keep older imports from -# ``durabletask.internal.shared`` working. -from durabletask.internal.json_codec import ( # noqa: F401 - AUTO_SERIALIZED as AUTO_SERIALIZED, - from_json as from_json, - to_json as to_json, -) +# Backwards-compatibility shims. The JSON codec moved into +# ``durabletask.serialization`` and its functions are now private; the supported +# surface is the pluggable ``DataConverter`` (and the default +# ``JsonDataConverter``). These thin wrappers keep older imports from +# ``durabletask.internal.shared`` working while steering callers to the new API. +# They deliberately reach into the now-private serialization mechanism. +from durabletask import serialization as _serialization from durabletask.grpc_options import GrpcChannelOptions +# Legacy marker constant, re-exported for backwards compatibility. +AUTO_SERIALIZED = _serialization._AUTO_SERIALIZED # pyright: ignore[reportPrivateUsage] + +_SERIALIZATION_DEPRECATION = ( + "durabletask.internal.shared.{name} is deprecated and will be removed in a " + "future release. Use a durabletask.serialization.DataConverter (e.g. the " + "default JsonDataConverter) instead." +) + + +def to_json(obj: Any) -> str: + """Deprecated. Use a ``durabletask.serialization.DataConverter`` instead.""" + warnings.warn( + _SERIALIZATION_DEPRECATION.format(name="to_json"), + DeprecationWarning, + stacklevel=2, + ) + return _serialization._to_json(obj) # pyright: ignore[reportPrivateUsage] + + +def from_json(json_str: str | bytes | bytearray, expected_type: type | None = None) -> Any: + """Deprecated. Use a ``durabletask.serialization.DataConverter`` instead. + + This legacy shim does not thread a ``DataConverter`` into reconstruction, so + a converter-aware ``from_json(cls, value, converter)`` hook is invoked + without the converter (its single-argument form). Call + ``JsonDataConverter().deserialize(...)`` to get the converter-aware path. + """ + warnings.warn( + _SERIALIZATION_DEPRECATION.format(name="from_json"), + DeprecationWarning, + stacklevel=2, + ) + return _serialization._from_json(json_str, expected_type) # pyright: ignore[reportPrivateUsage] + + ClientInterceptor: TypeAlias = ( grpc.UnaryUnaryClientInterceptor | grpc.UnaryStreamClientInterceptor diff --git a/durabletask/serialization.py b/durabletask/serialization.py index b1b469b..e9a72b1 100644 --- a/durabletask/serialization.py +++ b/durabletask/serialization.py @@ -10,10 +10,10 @@ values become JSON on the wire and how they are reconstructed on the way back. The default :class:`JsonDataConverter` preserves the SDK's built-in behavior: -builtins serialize as plain JSON, dataclasses / ``SimpleNamespace`` instances -and objects exposing a ``to_json()`` hook serialize to plain JSON structures, -and a caller-supplied ``target_type`` drives reconstruction on the read side -(the destination type is never read from the payload). +builtins serialize as plain JSON; objects exposing a ``to_json()`` hook, then +dataclasses, then ``SimpleNamespace`` instances serialize to plain JSON +structures; and a caller-supplied ``target_type`` drives reconstruction on the +read side (the destination type is never read from the payload). To customize serialization -- for example to validate with pydantic, encode custom ``datetime`` / ``Decimal`` formats, or integrate another model framework @@ -22,18 +22,39 @@ converter = MyDataConverter() worker = TaskHubGrpcWorker(data_converter=converter) client = TaskHubGrpcClient(data_converter=converter) + +This module is the single home for both serialization *policy* (the public, +pluggable :class:`DataConverter` strategy) and the low-level JSON codec +*mechanism* (the private ``_to_json`` / ``_from_json`` / ``_coerce_to_type`` +helpers). The mechanism is intentionally private: the supported, stable surface +is the :class:`DataConverter` abstraction. """ from __future__ import annotations +import dataclasses +import enum +import functools +import inspect +import json import logging +import types +import typing from abc import ABC, abstractmethod -from typing import Any - -from durabletask.internal import json_codec +from collections.abc import Mapping, Sequence +from types import SimpleNamespace +from typing import Any, cast logger = logging.getLogger("durabletask") +# Marker formerly added to JSON payloads to flag objects for automatic +# deserialization into a SimpleNamespace. New code no longer emits this marker +# (objects are serialized as plain JSON), but the decoder still recognizes it so +# that orchestration histories produced by older SDK versions continue to +# replay. Internal detail; re-exported from ``durabletask.internal.shared`` only +# for backwards compatibility. +_AUTO_SERIALIZED = "__durabletask_autoobject__" + class DataConverter(ABC): """Strategy for serializing and deserializing Durable Task payloads. @@ -89,36 +110,54 @@ class JsonDataConverter(DataConverter): classmethod used during type-directed reconstruction. This matches the ``to_json`` / ``from_json`` convention used by ``azure-functions-durable``. + The ``to_json`` hook takes precedence over the built-in dataclass / + ``SimpleNamespace`` handling, and nested values are encoded recursively, so + a dataclass (or any object) with a ``to_json`` hook -- including one nested + inside another dataclass, ``list``, or ``dict`` -- round-trips through its + own hooks. + Deserialization (and value-level :meth:`coerce`) is **best-effort**: when a ``target_type`` is supplied and the value cannot be coerced to it, the raw value is returned (and a debug message is logged) rather than raising. This keeps the core SDK permissive; a stricter, validating converter can be supplied for callers who want coercion failures to surface as errors. + + > [!NOTE] + > Type-directed reconstruction recurses through dataclass fields, + > ``list``/``Sequence``, ``dict``/``Mapping`` values, ``tuple`` elements, + > and ``Optional`` / ``Union`` hints. A type that exposes a custom + > ``from_json()`` classmethod is responsible for reconstructing its own + > nested values. To help, the converter passes itself to ``from_json`` when + > the hook declares a second parameter -- ``from_json(cls, value, converter)`` + > -- so the hook can call ``converter.coerce(nested_value, NestedType)`` (or + > ``converter.deserialize(...)``) to reconstruct nested values the built-in + > recursion does not cover. The SDK never infers nested types from the + > payload; the destination type is always supplied by the caller. """ def serialize(self, value: Any) -> str | None: if value is None: return None - return json_codec.to_json(value) + return _to_json(value) def deserialize(self, data: str | None, target_type: type | None = None) -> Any: if data is None or data == "": return None if target_type is None: - return json_codec.from_json(data) + return _from_json(data) try: - return json_codec.from_json(data, target_type) + return _from_json(data, target_type, converter=self) except Exception as e: # Best-effort: fall back to the raw deserialized value rather than # failing the operation. Logged so the mismatch remains discoverable. self._log_coercion_fallback(target_type, e) - return json_codec.from_json(data) + return _from_json(data) def coerce(self, value: Any, target_type: type | None = None) -> Any: if target_type is None or value is None: return value try: - return json_codec.coerce_to_type(value, target_type) + return _coerce_to_type(value, target_type, converter=self) except Exception as e: self._log_coercion_fallback(target_type, e) return value @@ -134,3 +173,269 @@ def _log_coercion_fallback(target_type: type, error: Exception) -> None: # Shared default instance used when no converter is supplied. DEFAULT_DATA_CONVERTER: DataConverter = JsonDataConverter() + + +# --------------------------------------------------------------------------- +# Private JSON codec mechanism. +# +# These are the low-level encode/decode primitives and the value-level type +# coercion used to reconstruct custom objects. They are deliberately private: +# the supported, pluggable surface is :class:`DataConverter`. ``_coerce_to_type`` +# is also used internally by entity state accessors that already hold a parsed +# value. +# --------------------------------------------------------------------------- + + +def _to_json(obj: Any) -> str: + """Serialize a value to a JSON string. + + Builtins serialize to plain JSON. Objects exposing a ``to_json()`` method, + dataclasses, and ``SimpleNamespace`` instances are serialized to plain JSON + as well (without any type marker), recursively. Custom objects can be + reconstructed on the receiving side by passing ``expected_type`` to + :func:`_from_json`. + """ + try: + return json.dumps(obj, default=_encode_custom_object) + except TypeError as e: + # Preserve the original error as the cause so serialization failures are + # easier to diagnose, while naming the offending top-level type. + raise TypeError( + f"Failed to serialize object of type '{type(obj).__name__}' to JSON: {e}" + ) from e + + +def _from_json( + json_str: str | bytes | bytearray, + expected_type: type | None = None, + converter: DataConverter | None = None, +) -> Any: + """Deserialize a JSON string, optionally coercing the result to a type. + + When ``expected_type`` is ``None`` (the default) the raw parsed JSON is + returned. For backwards compatibility, payloads carrying the legacy + :data:`_AUTO_SERIALIZED` marker are reconstructed as ``SimpleNamespace`` + instances so that in-flight orchestrations produced by older SDK versions + continue to replay. + + When ``expected_type`` is provided, the legacy marker (if present) is + stripped and the parsed value is coerced to ``expected_type`` -- dataclasses + are constructed from their dict payloads, types exposing a ``from_json()`` + classmethod are reconstructed via that hook, and ``Optional``/``Union``, + ``list``, ``dict``, and ``tuple`` type hints are honored recursively. The + destination type is always supplied by the caller; it is never read from the + payload. + + ``converter`` is threaded through to ``from_json`` hooks that declare a + second parameter, letting them recursively reconstruct nested values. + """ + if expected_type is None: + return json.loads(json_str, object_hook=_legacy_object_hook) + raw = json.loads(json_str, object_hook=_strip_legacy_marker) + return _coerce_to_type(raw, expected_type, converter) + + +def _encode_custom_object(o: Any) -> Any: + """``default`` hook for :func:`json.dumps` that emits plain JSON. + + Called only for values the JSON encoder cannot natively serialize. Note that + namedtuples are handled natively by the encoder (serialized as JSON arrays) + and never reach this hook. + """ + # Custom objects may opt in via a ``to_json`` hook. It is resolved off the + # type and called with the instance (``type(o).to_json(o)``) so that both + # instance methods and ``@staticmethod`` hooks work -- matching the calling + # convention used by ``azure-functions-durable``. The hook returns a + # JSON-serializable value (a structure or a string), not a JSON document. + # + # The hook is checked *before* the dataclass / ``SimpleNamespace`` branches + # so a dataclass (or namespace) that needs custom handling -- for example + # one with a field that is not JSON-serializable on its own -- can opt in. + # Resolving off the type (rather than the instance) avoids mistaking a data + # attribute named ``to_json`` for a hook, and mirrors the decode side, which + # prefers ``from_json`` over the dataclass branch. + to_json_hook = getattr(cast(Any, type(o)), "to_json", None) + if callable(to_json_hook): + return to_json_hook(o) + if isinstance(o, enum.Enum): + # Emit the member's underlying value (a primitive). Reconstruction is + # type-directed: passing the enum type to ``deserialize`` rebuilds the + # member via ``EnumType(value)``. ``IntEnum`` / ``IntFlag`` members are + # ints and serialize natively without reaching this hook; this branch + # covers string- and other-valued enums so they round-trip too. Emitting + # only the value (never the member name or type) keeps the wire format a + # plain primitive and avoids leaking the Python type into the payload. + return o.value + if dataclasses.is_dataclass(o) and not isinstance(o, type): + # Return a *shallow* mapping of the dataclass's fields rather than + # ``dataclasses.asdict``. asdict recursively converts nested dataclasses + # to plain dicts (bypassing their ``to_json`` hooks) and deep-copies + # every leaf value. Emitting the live field values lets ``json.dumps`` + # recurse so each nested value is encoded through this same hook, + # honoring nested ``to_json`` hooks and avoiding the deep copy. + return {f.name: getattr(o, f.name) for f in dataclasses.fields(o)} + if isinstance(o, SimpleNamespace): + return vars(o) + # This will raise a TypeError describing the unsupported type. + raise TypeError(f"Object of type '{type(o).__name__}' is not JSON serializable") + + +def _legacy_object_hook(d: dict[str, Any]) -> Any: + # If the object carries the legacy marker, deserialize it as a SimpleNamespace. + if d.pop(_AUTO_SERIALIZED, False): + return SimpleNamespace(**d) + return d + + +def _strip_legacy_marker(d: dict[str, Any]) -> dict[str, Any]: + # Discard the legacy marker so typed coercion sees a plain dict. + d.pop(_AUTO_SERIALIZED, None) + return d + + +def _coerce_to_type(value: Any, expected_type: Any, converter: DataConverter | None = None) -> Any: + """Coerce an already-parsed JSON value to ``expected_type``. + + Handles ``None``/``Optional``/``Union``, ``list``, ``dict``, and ``tuple`` + type hints recursively, types exposing a ``from_json()`` classmethod, and + dataclasses (including nested dataclass fields). The destination type is + always caller-supplied and never derived from the payload, keeping + deserialization secure. + + ``converter`` is passed to ``from_json`` hooks that declare a second + parameter so they can recursively reconstruct nested values. + """ + if expected_type is None or value is None: + return value + + origin = typing.get_origin(expected_type) + if origin is not None: + return _coerce_generic(value, expected_type, origin, converter) + + if not isinstance(expected_type, type): + # Not a concrete, instantiable type (e.g. a typing special form we don't + # special-case) -- return the value unchanged. + return value + + if isinstance(value, expected_type): + return value + + from_json_hook = getattr(expected_type, "from_json", None) + if callable(from_json_hook): + return _invoke_from_json(from_json_hook, value, converter) + + if dataclasses.is_dataclass(expected_type) and isinstance(value, dict): + return _build_dataclass(expected_type, cast(dict[str, Any], value), converter) + + type_ctor = cast(Any, expected_type) + try: + return type_ctor(value) + except Exception as e: + type_name = getattr(type_ctor, "__name__", None) or str(type_ctor) + raise TypeError( + f"Could not coerce value of type '{type(value).__name__}' to " + f"'{type_name}'" + ) from e + + +def _invoke_from_json(hook: Any, value: Any, converter: DataConverter | None) -> Any: + """Call a ``from_json`` hook, passing the converter when the hook accepts it. + + A hook declared as ``from_json(cls, value)`` (the common, single-argument + convention) is called with just the value. A hook declared as + ``from_json(cls, value, converter)`` additionally receives the active + :class:`DataConverter`, letting it reconstruct nested values recursively via + ``converter.coerce(...)`` / ``converter.deserialize(...)``. + + > [!NOTE] + > ``from_json`` must be a ``@classmethod`` or ``@staticmethod`` -- the hook + > is resolved off the *type* (no instance exists yet during reconstruction). + > A plain instance method would have ``self`` consume the value and is + > unsupported regardless of the converter-arity detection below. + """ + if converter is not None and _hook_accepts_converter(hook): + return hook(value, converter) + return hook(value) + + +@functools.lru_cache(maxsize=2048) +def _hook_accepts_converter(hook: Any) -> bool: + """Return True if a bound ``from_json`` hook can accept a second argument. + + The hook is inspected as accessed off the type (``cls``/``self`` already + bound), so a classmethod ``from_json(cls, value, converter)`` presents as + ``(value, converter)``. Results are cached because reconstruction runs on + hot paths; bound classmethods hash equal across attribute accesses, so the + cache stays effective and bounded. + """ + try: + sig = inspect.signature(hook) + except (TypeError, ValueError): + return False + positional = 0 + for param in sig.parameters.values(): + if param.kind in (inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD): + positional += 1 + elif param.kind is inspect.Parameter.VAR_POSITIONAL: + return True + return positional >= 2 + + +def _coerce_generic(value: Any, expected_type: Any, origin: Any, + converter: DataConverter | None = None) -> Any: + args = typing.get_args(expected_type) + if origin is typing.Union or origin is types.UnionType: + # If the value already matches a member type, keep it as-is. + non_none = [a for a in args if a is not type(None)] + for arg in non_none: + if isinstance(arg, type) and isinstance(value, arg): + return value + # ``Optional[T]`` (exactly one non-None member): coerce to that member. + # For a genuine multi-member ``Union`` where the value matched none of + # the members, leave it untouched rather than guessing the first arg -- + # forcing a coercion there can silently mis-construct the wrong type. + if len(non_none) == 1: + return _coerce_to_type(value, non_none[0], converter) + return value + if origin in (list, Sequence) and isinstance(value, list): + elem_type = args[0] if args else None + return [_coerce_to_type(item, elem_type, converter) for item in cast(list[Any], value)] + if origin in (dict, Mapping) and isinstance(value, dict): + # JSON object keys are always strings, so only the values can carry a + # reconstructable type. Keys are passed through unchanged. + mapping = cast("dict[Any, Any]", value) + val_type = args[1] if len(args) == 2 else None + if val_type is None: + return mapping + return {k: _coerce_to_type(v, val_type, converter) for k, v in mapping.items()} + if origin is tuple and isinstance(value, list): + # Tuples serialize to JSON arrays, so the parsed value is a list. + if not args: + return tuple(cast(list[Any], value)) + # Homogeneous ``tuple[T, ...]``. + if len(args) == 2 and args[1] is Ellipsis: + return tuple(_coerce_to_type(item, args[0], converter) for item in cast(list[Any], value)) + # Fixed-length ``tuple[T1, T2, ...]`` -- coerce element-wise by position. + return tuple( + _coerce_to_type(item, t, converter) + for item, t in zip(cast(list[Any], value), args) + ) + # Other generics are returned as parsed JSON. + return value + + +def _build_dataclass(cls: Any, data: dict[str, Any], + converter: DataConverter | None = None) -> Any: + """Construct a dataclass from its dict payload, recursing into typed fields.""" + try: + hints = typing.get_type_hints(cls) + except Exception: + hints = {} + kwargs: dict[str, Any] = {} + for field in dataclasses.fields(cls): + if field.name not in data: + continue + field_type = hints.get(field.name) + kwargs[field.name] = _coerce_to_type(data[field.name], field_type, converter) + return cls(**kwargs) diff --git a/durabletask/worker.py b/durabletask/worker.py index d3a3d86..6fda92d 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -1317,8 +1317,9 @@ def _execute_entity_batch( payload_helpers.deexternalize_payloads(req, self._payload_store) entity_state = StateShim( - self._data_converter.deserialize(req.entityState.value) if req.entityState.value else None, - self._data_converter) + req.entityState.value if req.entityState.value else None, + self._data_converter, + is_serialized=True) instance_id = req.instanceId try: @@ -1376,7 +1377,7 @@ def _execute_entity_batch( batch_result = pb.EntityBatchResult( results=results, actions=entity_state.get_operation_actions(), - entityState=helpers.get_string_value(self._data_converter.serialize(entity_state._current_state)) if entity_state._current_state is not None else None, # pyright: ignore[reportPrivateUsage] + entityState=helpers.get_string_value(entity_state.encode_state()), failureDetails=None, completionToken=completionToken, operationInfos=operation_infos, @@ -2772,12 +2773,12 @@ def _handle_entity_event_raised(self, if not ph.is_empty(event.eventRaised.input): # TODO: Investigate why the event result is wrapped in a dict with "result" key # The expected type applies to the unwrapped result value, not the - # transport wrapper. Unwrap first, then route the inner value back - # through the converter so custom converters and the expected type - # both apply. + # transport wrapper. Unwrap first, then coerce the already-parsed + # inner value to the expected type via the converter (no redundant + # re-serialization round-trip). unwrapped = self._data_converter.deserialize(event.eventRaised.input.value)["result"] - result = self._data_converter.deserialize( - self._data_converter.serialize(unwrapped), + result = self._data_converter.coerce( + unwrapped, entity_task._expected_type, # pyright: ignore[reportPrivateUsage] ) if is_lock_event: diff --git a/tests/durabletask/test_data_converter_roundtrip.py b/tests/durabletask/test_data_converter_roundtrip.py new file mode 100644 index 0000000..465aeb3 --- /dev/null +++ b/tests/durabletask/test_data_converter_roundtrip.py @@ -0,0 +1,584 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Comprehensive round-trip tests for the default ``JsonDataConverter``. + +These exercise ``serialize`` -> ``deserialize(..., target_type)`` through the +*public* converter API (not the private codec) across a broad matrix of object +shapes a user might reasonably hand to an orchestrator/activity/entity boundary: +plain dataclasses, deeply nested dataclasses, dataclasses with ``to_json`` / +``from_json`` hooks (including nested), nested non-dataclass custom classes, +enums, containers (``list`` / ``dict`` / ``tuple``), ``Optional`` / ``Union``, +recursive structures, and a set of types the SDK intentionally does **not** +auto-serialize. + +The intent is to lock in the "secure, minimal-effort, intuitive object +round-tripping" contract and to document -- via xfail/raises -- exactly where a +user must supply a ``to_json`` hook or a custom ``DataConverter``. +""" + +from __future__ import annotations + +import enum +import typing +from dataclasses import dataclass, field +from datetime import datetime +from decimal import Decimal + +import pytest + +from durabletask.serialization import JsonDataConverter + + +@pytest.fixture +def conv() -> JsonDataConverter: + return JsonDataConverter() + + +def _round_trip(conv: JsonDataConverter, value, target_type): + """Serialize then deserialize ``value`` back into ``target_type``.""" + return conv.deserialize(conv.serialize(value), target_type) + + +# --------------------------------------------------------------------------- +# Plain and nested dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class Address: + street: str + city: str + + +@dataclass +class Person: + name: str + age: int + address: Address | None = None + + +def test_plain_dataclass(conv): + value = Address("1 Main St", "Redmond") + assert _round_trip(conv, value, Address) == value + + +def test_nested_dataclass(conv): + value = Person("Ada", 30, Address("1 Main St", "Redmond")) + result = _round_trip(conv, value, Person) + assert result == value + assert isinstance(result.address, Address) + + +def test_optional_dataclass_field_present(conv): + value = Person("Ada", 30, Address("1 Main St", "Redmond")) + assert _round_trip(conv, value, Person).address == Address("1 Main St", "Redmond") + + +def test_optional_dataclass_field_absent(conv): + value = Person("Ada", 30, None) + assert _round_trip(conv, value, Person).address is None + + +@dataclass +class _L3: + v: int + + +@dataclass +class _L2: + l3: _L3 + + +@dataclass +class _L1: + l2: _L2 + + +def test_three_level_nested_dataclass(conv): + value = _L1(_L2(_L3(7))) + result = _round_trip(conv, value, _L1) + assert result == value + assert isinstance(result.l2.l3, _L3) + + +def test_frozen_dataclass(conv): + @dataclass(frozen=True) + class Frozen: + x: int + y: int + + value = Frozen(1, 2) + assert _round_trip(conv, value, Frozen) == value + + +def test_empty_dataclass(conv): + @dataclass + class Empty: + pass + + assert isinstance(_round_trip(conv, Empty(), Empty), Empty) + + +def test_dataclass_with_builtin_shadowing_field_names(conv): + @dataclass + class Shadow: + type: str + id: int + + value = Shadow("widget", 1) + assert _round_trip(conv, value, Shadow) == value + + +def test_dataclass_default_factory_list_of_dataclasses(conv): + @dataclass + class Bag: + items: list[Address] = field(default_factory=list) + + value = Bag([Address("a", "b")]) + result = _round_trip(conv, value, Bag) + assert result == value + assert isinstance(result.items[0], Address) + + +# --------------------------------------------------------------------------- +# Dataclasses with to_json / from_json hooks (incl. nested) +# --------------------------------------------------------------------------- + + +@dataclass +class Temperature: + """Dataclass whose JSON shape differs from its field layout.""" + + celsius: float + + def to_json(self) -> dict: + return {"fahrenheit": self.celsius * 9 / 5 + 32} + + @classmethod + def from_json(cls, data: dict) -> "Temperature": + return cls((data["fahrenheit"] - 32) * 5 / 9) + + def __eq__(self, other: object) -> bool: + return isinstance(other, Temperature) and other.celsius == self.celsius + + +def test_dataclass_with_hooks(conv): + value = Temperature(100.0) + encoded = conv.serialize(value) + assert encoded == '{"fahrenheit": 212.0}' # hook shape, not raw field + assert conv.deserialize(encoded, Temperature) == value + + +def test_nested_dataclass_hook_inside_dataclass(conv): + @dataclass + class Reading: + temp: Temperature + note: str + + value = Reading(Temperature(37.0), "ok") + encoded = conv.serialize(value) + assert '"fahrenheit"' in encoded + result = conv.deserialize(encoded, Reading) + assert result.temp == Temperature(37.0) + assert result.note == "ok" + + +def test_list_of_hooked_dataclasses(conv): + value = [Temperature(0.0), Temperature(100.0)] + result = _round_trip(conv, value, list[Temperature]) + assert result == value + + +def test_dict_of_hooked_dataclasses(conv): + value = {"a": Temperature(0.0)} + result = _round_trip(conv, value, dict[str, Temperature]) + assert result["a"] == Temperature(0.0) + + +# --------------------------------------------------------------------------- +# Nested non-dataclass custom classes (with hooks) +# --------------------------------------------------------------------------- + + +class Money: + """A field type that is not JSON-serializable on its own.""" + + def __init__(self, cents: int): + self.cents = cents + + def to_json(self) -> dict: + return {"cents": self.cents} + + @classmethod + def from_json(cls, data: dict) -> "Money": + return cls(data["cents"]) + + def __eq__(self, other: object) -> bool: + return isinstance(other, Money) and other.cents == self.cents + + +def test_dataclass_with_non_serializable_field_uses_hook(conv): + @dataclass + class Invoice: + amount: Money + + def __eq__(self, other: object) -> bool: + return isinstance(other, Invoice) and other.amount == self.amount + + value = Invoice(Money(500)) + result = _round_trip(conv, value, Invoice) + assert result == value + assert isinstance(result.amount, Money) + + +def test_standalone_custom_class_with_hooks(conv): + value = Money(250) + assert _round_trip(conv, value, Money) == value + + +def test_hook_returning_scalar_string(conv): + class Tag: + def __init__(self, name: str): + self.name = name + + def to_json(self) -> str: + return self.name + + @classmethod + def from_json(cls, data: str) -> "Tag": + return cls(data) + + def __eq__(self, other: object) -> bool: + return isinstance(other, Tag) and other.name == self.name + + value = Tag("vip") + assert conv.serialize(value) == '"vip"' + assert _round_trip(conv, value, Tag) == value + + +def test_dict_of_custom_hooked_class(conv): + value = {"usd": Money(100), "eur": Money(200)} + result = _round_trip(conv, value, dict[str, Money]) + assert result == value + + +# --------------------------------------------------------------------------- +# Converter-aware from_json (recursive reconstruction of nested typed values) +# --------------------------------------------------------------------------- + + +class Ledger: + """A type whose from_json uses the converter to rebuild a nested value.""" + + def __init__(self, owner: str, balance: Money): + self.owner = owner + self.balance = balance + + def to_json(self) -> dict: + return {"owner": self.owner, "balance": self.balance} + + @classmethod + def from_json(cls, data: dict, converter) -> "Ledger": + return cls(data["owner"], converter.coerce(data["balance"], Money)) + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, Ledger) + and other.owner == self.owner + and other.balance == self.balance + ) + + +def test_converter_aware_from_json(conv): + value = Ledger("ada", Money(999)) + result = _round_trip(conv, value, Ledger) + assert result == value + assert isinstance(result.balance, Money) + + +def test_converter_aware_from_json_nested_in_dataclass(conv): + @dataclass + class Account: + ledgers: list[Ledger] + + value = Account([Ledger("ada", Money(1))]) + result = _round_trip(conv, value, Account) + assert isinstance(result.ledgers[0], Ledger) + assert isinstance(result.ledgers[0].balance, Money) + + +# --------------------------------------------------------------------------- +# Enums +# --------------------------------------------------------------------------- + + +class Color(enum.IntEnum): + RED = 1 + GREEN = 2 + + +class Status(enum.Enum): + OPEN = "open" + CLOSED = "closed" + + +def test_int_enum_round_trip(conv): + assert conv.serialize(Color.RED) == "1" + assert conv.deserialize("1", Color) is Color.RED + + +def test_str_enum_round_trip(conv): + assert conv.serialize(Status.OPEN) == '"open"' + assert conv.deserialize('"open"', Status) is Status.OPEN + + +def test_enum_as_dataclass_field(conv): + @dataclass + class Ticket: + status: Status + color: Color + + value = Ticket(Status.OPEN, Color.RED) + result = _round_trip(conv, value, Ticket) + assert result.status is Status.OPEN + assert result.color is Color.RED + + +def test_list_of_enums(conv): + value = [Status.OPEN, Status.CLOSED] + result = _round_trip(conv, value, list[Status]) + assert result == value + + +def test_optional_enum_field(conv): + @dataclass + class MaybeStatus: + status: Status | None = None + + assert _round_trip(conv, MaybeStatus(Status.CLOSED), MaybeStatus).status is Status.CLOSED + assert _round_trip(conv, MaybeStatus(None), MaybeStatus).status is None + + +# --------------------------------------------------------------------------- +# Containers +# --------------------------------------------------------------------------- + + +def test_list_of_dataclasses(conv): + value = [Address("a", "b"), Address("c", "d")] + result = _round_trip(conv, value, list[Address]) + assert result == value + + +def test_dict_of_dataclasses(conv): + value = {"home": Address("a", "b")} + result = _round_trip(conv, value, dict[str, Address]) + assert isinstance(result["home"], Address) + + +def test_dict_of_optional_dataclasses(conv): + value = {"a": Address("x", "y"), "b": None} + result = _round_trip(conv, value, dict[str, typing.Optional[Address]]) + assert isinstance(result["a"], Address) + assert result["b"] is None + + +def test_fixed_length_tuple(conv): + value = (Address("a", "b"), 5) + result = _round_trip(conv, value, tuple[Address, int]) + assert isinstance(result, tuple) + assert result[0] == Address("a", "b") + assert result[1] == 5 + + +def test_homogeneous_tuple(conv): + value = (Address("a", "b"), Address("c", "d")) + result = _round_trip(conv, value, tuple[Address, ...]) + assert isinstance(result, tuple) + assert all(isinstance(item, Address) for item in result) + + +def test_tuple_field_in_dataclass(conv): + @dataclass + class Pair: + pair: tuple[Address, Address] + + value = Pair((Address("a", "b"), Address("c", "d"))) + result = _round_trip(conv, value, Pair) + assert isinstance(result.pair, tuple) + assert result.pair[0] == Address("a", "b") + + +def test_dataclass_with_list_of_nested_dataclasses(conv): + @dataclass + class Team: + members: list[Person] + + value = Team([Person("A", 1, Address("x", "y"))]) + result = _round_trip(conv, value, Team) + assert isinstance(result.members[0], Person) + assert isinstance(result.members[0].address, Address) + + +def test_dataclass_with_dict_of_custom_class(conv): + @dataclass + class Wallet: + funds: dict[str, Money] + + value = Wallet({"usd": Money(100)}) + result = _round_trip(conv, value, Wallet) + assert isinstance(result.funds["usd"], Money) + assert result.funds["usd"] == Money(100) + + +# --------------------------------------------------------------------------- +# Recursive structures +# --------------------------------------------------------------------------- + + +@dataclass +class TreeNode: + value: int + children: list["TreeNode"] = field(default_factory=list) + + +def test_recursive_tree_dataclass(conv): + value = TreeNode(1, [TreeNode(2), TreeNode(3, [TreeNode(4)])]) + result = _round_trip(conv, value, TreeNode) + assert result == value + assert isinstance(result.children[1].children[0], TreeNode) + + +# --------------------------------------------------------------------------- +# Documented limitation: nested typed reconstruction needs resolvable hints. +# +# Type-directed reconstruction relies on ``typing.get_type_hints`` to read a +# dataclass's nested field types. A dataclass defined inside a function whose +# fields reference *other function-local* types cannot be resolved (the names +# are not in the module globals), so nested coercion is skipped and the field +# is returned as parsed JSON. Module-level dataclasses (the normal case) work. +# --------------------------------------------------------------------------- + + +def test_function_local_nested_dataclass_hint_is_not_resolvable(conv): + @dataclass + class Inner: + v: int + + @dataclass + class Outer: + inner: Inner + + result = _round_trip(conv, Outer(Inner(7)), Outer) + # The outer dataclass is rebuilt, but the inner field stays a plain dict + # because its type hint can't be resolved from a function scope. + assert isinstance(result, Outer) + assert result.inner == {"v": 7} + + +# --------------------------------------------------------------------------- +# Builtins / primitives +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("value,target", [ + (42, int), + ("hello", str), + (3.14, float), + (True, bool), + (None, type(None)), +]) +def test_primitive_round_trip(conv, value, target): + assert conv.deserialize(conv.serialize(value), target) == value + + +def test_list_of_primitives(conv): + assert _round_trip(conv, [1, 2, 3], list[int]) == [1, 2, 3] + + +def test_dict_of_primitives(conv): + assert _round_trip(conv, {"a": 1, "b": 2}, dict[str, int]) == {"a": 1, "b": 2} + + +def test_no_target_type_returns_plain_json(conv): + # Without a target type, custom objects come back as plain JSON. + encoded = conv.serialize(Address("a", "b")) + assert conv.deserialize(encoded) == {"street": "a", "city": "b"} + + +# --------------------------------------------------------------------------- +# Multi-member Union: documented limitation (not force-coerced) +# --------------------------------------------------------------------------- + + +def test_multi_member_union_field_is_not_force_coerced(conv): + @dataclass + class A: + a: int + + @dataclass + class B: + b: int + + @dataclass + class HasUnion: + val: typing.Union[A, B] + + # A multi-member Union cannot be disambiguated from the payload alone, so the + # value is intentionally left as parsed JSON rather than guessing a member. + # This documents the current, secure behavior. + value = HasUnion(A(1)) + result = _round_trip(conv, value, HasUnion) + assert result.val == {"a": 1} + + +# --------------------------------------------------------------------------- +# Types the SDK intentionally does NOT auto-serialize. +# +# These require a user-supplied ``to_json`` hook or a custom ``DataConverter``. +# The tests pin the current behavior (a clear TypeError) so a future change that +# adds support is a deliberate, reviewed decision. +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("value", [ + datetime(2020, 1, 1, 12, 0, 0), + Decimal("1.5"), + {1, 2, 3}, + frozenset({1, 2}), + b"bytes", +]) +def test_unsupported_types_raise_clear_typeerror(conv, value): + with pytest.raises(TypeError) as exc_info: + conv.serialize(value) + assert type(value).__name__ in str(exc_info.value) + + +def test_plain_object_without_hook_raises(conv): + class Plain: + def __init__(self): + self.a = 1 + + with pytest.raises(TypeError): + conv.serialize(Plain()) + + +def test_custom_datetime_via_to_json_hook(conv): + # The supported way to round-trip a datetime is a to_json/from_json hook. + class Timestamp: + def __init__(self, dt: datetime): + self.dt = dt + + def to_json(self) -> str: + return self.dt.isoformat() + + @classmethod + def from_json(cls, data: str) -> "Timestamp": + return cls(datetime.fromisoformat(data)) + + def __eq__(self, other: object) -> bool: + return isinstance(other, Timestamp) and other.dt == self.dt + + value = Timestamp(datetime(2020, 1, 1, 12, 0, 0)) + assert _round_trip(conv, value, Timestamp) == value diff --git a/tests/durabletask/test_entity_executor.py b/tests/durabletask/test_entity_executor.py index e853f13..116ead4 100644 --- a/tests/durabletask/test_entity_executor.py +++ b/tests/durabletask/test_entity_executor.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. """Unit tests for the _EntityExecutor class in durabletask.worker.""" +import json import logging from durabletask import entities @@ -188,3 +189,102 @@ def test_get_state_invalid_coercion_raises(self): state = StateShim("not-an-int") with pytest.raises(TypeError): state.get_state(int) + + +class TestStateShimDeferredDeserialization: + """Tests for StateShim deferring deserialization of the raw wire payload.""" + + def test_constructor_does_not_deserialize_serialized_state(self): + # A serialized payload is held verbatim until read, not eagerly parsed. + state = StateShim('{"value": 7}', is_serialized=True) + assert state._current_state == '{"value": 7}' + + def test_get_state_defers_deserialization_with_type(self): + from dataclasses import dataclass + + @dataclass + class Counter: + value: int + + state = StateShim('{"value": 7}', is_serialized=True) + result = state.get_state(Counter) + assert isinstance(result, Counter) + assert result.value == 7 + + def test_get_state_no_type_returns_parsed_value(self): + state = StateShim('{"value": 7}', is_serialized=True) + assert state.get_state() == {"value": 7} + + def test_deferred_deserialization_passes_raw_string_to_converter(self): + # A custom converter receives the original serialized string together + # with the requested type, rather than a pre-parsed value. + from typing import Any + + from durabletask.serialization import DataConverter + + seen: dict[str, Any] = {} + + class RecordingConverter(DataConverter): + def serialize(self, value: Any) -> str | None: + return None if value is None else json.dumps(value) + + def deserialize(self, data, target_type=None): + seen["data"] = data + seen["target_type"] = target_type + return {"parsed": True} + + def coerce(self, value, target_type=None): + seen["coerced"] = True + return value + + state = StateShim('{"x": 1}', RecordingConverter(), is_serialized=True) + state.get_state(dict) + assert seen["data"] == '{"x": 1}' + assert seen["target_type"] is dict + assert "coerced" not in seen + + def test_encode_state_passes_through_unmodified_payload(self): + # An unread/unmodified serialized payload is returned verbatim, never + # re-serialized (which would double-encode the JSON string). + state = StateShim('{"value": 7}', is_serialized=True) + assert state.encode_state() == '{"value": 7}' + + def test_reading_does_not_trigger_double_encoding(self): + state = StateShim('{"value": 7}', is_serialized=True) + # Reading (even with a type) must not turn the payload into a live value + # that would be re-serialized into a JSON-encoded string. + state.get_state() + state.get_state(dict) + encoded = state.encode_state() + assert encoded == '{"value": 7}' + assert json.loads(encoded) == {"value": 7} + + def test_encode_state_serializes_live_value_after_set_state(self): + state = StateShim('{"value": 7}', is_serialized=True) + state.set_state({"value": 8}) + encoded = state.encode_state() + assert json.loads(encoded) == {"value": 8} + + def test_encode_state_none_when_state_is_none(self): + state = StateShim(None, is_serialized=True) + assert state.encode_state() is None + + def test_commit_preserves_serialized_flag(self): + state = StateShim('{"value": 7}', is_serialized=True) + state.commit() + # After commit, the (unmodified) state still round-trips without + # double-encoding. + assert state.encode_state() == '{"value": 7}' + + def test_rollback_restores_serialized_flag(self): + state = StateShim('{"value": 7}', is_serialized=True) + state.commit() + state.set_state({"value": 99}) + state.rollback() + assert state.encode_state() == '{"value": 7}' + + def test_falsy_serialized_state_is_not_dropped(self): + # A serialized falsy value (e.g. 0) is preserved, not treated as cleared. + state = StateShim("0", is_serialized=True) + assert state.get_state(int) == 0 + assert state.encode_state() == "0" diff --git a/tests/durabletask/test_serialization.py b/tests/durabletask/test_serialization.py index 33003c2..9a9659a 100644 --- a/tests/durabletask/test_serialization.py +++ b/tests/durabletask/test_serialization.py @@ -1,16 +1,25 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""Unit tests for the JSON serialization codec in durabletask.internal.json_codec.""" +"""Unit tests for the private JSON codec in durabletask.serialization. + +These exercise the low-level encode/decode mechanism directly. The supported, +public surface is the ``DataConverter`` abstraction (see +``test_data_converter.py``). +""" import json from collections import namedtuple from dataclasses import dataclass from types import SimpleNamespace +from typing import Optional, Union import pytest -from durabletask.internal import json_codec +from durabletask.serialization import _AUTO_SERIALIZED as AUTO_SERIALIZED +from durabletask.serialization import _coerce_to_type as coerce_to_type +from durabletask.serialization import _from_json as from_json +from durabletask.serialization import _to_json as to_json # ----- Test fixtures ----- @@ -78,58 +87,58 @@ def __eq__(self, other: object) -> bool: def test_to_json_static_hook_receives_instance(): # type(obj).to_json(obj) must invoke the @staticmethod with the instance. - assert json_codec.to_json(StaticWidget("gizmo")) == '"gizmo"' + assert to_json(StaticWidget("gizmo")) == '"gizmo"' def test_static_hook_round_trips_with_expected_type(): - encoded = json_codec.to_json(StaticWidget("gizmo")) - result = json_codec.from_json(encoded, StaticWidget) + encoded = to_json(StaticWidget("gizmo")) + result = from_json(encoded, StaticWidget) assert isinstance(result, StaticWidget) assert result == StaticWidget("gizmo") def test_instance_to_json_hook_receives_instance(): # The same type(obj).to_json(obj) path works for plain instance methods. - assert json.loads(json_codec.to_json(Widget("gear", 5))) == {"label": "gear", "size": 5} + assert json.loads(to_json(Widget("gear", 5))) == {"label": "gear", "size": 5} # ----- to_json ----- def test_to_json_builtins_are_plain_json(): - assert json_codec.to_json({"a": 1, "b": [1, 2, 3]}) == json.dumps({"a": 1, "b": [1, 2, 3]}) - assert json_codec.to_json("hello") == '"hello"' - assert json_codec.to_json(42) == "42" + assert to_json({"a": 1, "b": [1, 2, 3]}) == json.dumps({"a": 1, "b": [1, 2, 3]}) + assert to_json("hello") == '"hello"' + assert to_json(42) == "42" def test_to_json_dataclass_emits_plain_dict_without_marker(): - encoded = json_codec.to_json(Address("1 Main St", "Redmond")) + encoded = to_json(Address("1 Main St", "Redmond")) parsed = json.loads(encoded) assert parsed == {"street": "1 Main St", "city": "Redmond"} - assert json_codec.AUTO_SERIALIZED not in encoded + assert AUTO_SERIALIZED not in encoded def test_to_json_nested_dataclass_has_no_marker(): - encoded = json_codec.to_json(Person("Ada", 30, Address("1 Main St", "Redmond"))) - assert json_codec.AUTO_SERIALIZED not in encoded + encoded = to_json(Person("Ada", 30, Address("1 Main St", "Redmond"))) + assert AUTO_SERIALIZED not in encoded parsed = json.loads(encoded) assert parsed["address"] == {"street": "1 Main St", "city": "Redmond"} def test_to_json_simplenamespace_emits_plain_dict(): - encoded = json_codec.to_json(SimpleNamespace(a=1, b="two")) - assert json_codec.AUTO_SERIALIZED not in encoded + encoded = to_json(SimpleNamespace(a=1, b="two")) + assert AUTO_SERIALIZED not in encoded assert json.loads(encoded) == {"a": 1, "b": "two"} def test_to_json_custom_object_uses_to_json_hook(): - encoded = json_codec.to_json(Widget("gear", 5)) + encoded = to_json(Widget("gear", 5)) assert json.loads(encoded) == {"label": "gear", "size": 5} def test_to_json_namedtuple_serializes_as_array(): # Without an expected_type the field names are not preserved on the wire. - assert json_codec.to_json(Point(1, 2)) == "[1, 2]" + assert to_json(Point(1, 2)) == "[1, 2]" def test_to_json_unserializable_raises_typeerror_with_cause(): @@ -137,33 +146,123 @@ class NotSerializable: pass with pytest.raises(TypeError) as exc_info: - json_codec.to_json(NotSerializable()) + to_json(NotSerializable()) assert "NotSerializable" in str(exc_info.value) assert exc_info.value.__cause__ is not None +# ----- to_json: hook precedence and nested-hook recursion (PR #154 follow-up) ----- + + +@dataclass +class Temperature: + """Dataclass whose JSON shape differs from its field layout.""" + + celsius: float + + def to_json(self) -> dict: + return {"fahrenheit": self.celsius * 9 / 5 + 32} + + @classmethod + def from_json(cls, data: dict) -> "Temperature": + return cls((data["fahrenheit"] - 32) * 5 / 9) + + def __eq__(self, other: object) -> bool: + return isinstance(other, Temperature) and other.celsius == self.celsius + + +class Money: + """A field type that is not JSON-serializable on its own.""" + + def __init__(self, cents: int): + self.cents = cents + + def __eq__(self, other: object) -> bool: + return isinstance(other, Money) and other.cents == self.cents + + +@dataclass +class Invoice: + """Dataclass with a non-serializable field, rescued by a to_json hook.""" + + amount: Money + + def to_json(self) -> dict: + return {"amount_cents": self.amount.cents} + + @classmethod + def from_json(cls, data: dict) -> "Invoice": + return cls(Money(data["amount_cents"])) + + def __eq__(self, other: object) -> bool: + return isinstance(other, Invoice) and other.amount == self.amount + + +def test_dataclass_to_json_hook_takes_precedence_over_fields(): + # The dataclass defines to_json, so its output -- not the raw fields -- wins. + assert json.loads(to_json(Temperature(100.0))) == {"fahrenheit": 212.0} + + +def test_dataclass_with_non_serializable_field_uses_to_json_hook(): + # Previously the dataclass branch ran first and asdict() hit the + # non-serializable Money field, raising even though to_json was defined. + assert json.loads(to_json(Invoice(Money(500)))) == {"amount_cents": 500} + + +def test_dataclass_with_non_serializable_field_round_trips(): + encoded = to_json(Invoice(Money(500))) + assert from_json(encoded, Invoice) == Invoice(Money(500)) + + +def test_nested_dataclass_hook_is_honored(): + # The nested Temperature must serialize through its own to_json, not be + # flattened to its raw fields by dataclasses.asdict. + @dataclass + class Reading: + temp: Temperature + + encoded = to_json(Reading(Temperature(100.0))) + assert json.loads(encoded) == {"temp": {"fahrenheit": 212.0}} + + +def test_nested_dataclass_hook_round_trips(): + @dataclass + class Reading: + temp: Temperature + + encoded = to_json(Reading(Temperature(37.0))) + result = from_json(encoded, Reading) + assert isinstance(result, Reading) + assert result.temp == Temperature(37.0) + + +def test_simplenamespace_without_hook_still_emits_fields(): + # SimpleNamespace has no to_json, so it falls through to vars(). + assert json.loads(to_json(SimpleNamespace(x=1))) == {"x": 1} + + # ----- from_json without expected_type (loose mode) ----- def test_from_json_returns_raw_without_expected_type(): - assert json_codec.from_json('{"a": 1}') == {"a": 1} - assert json_codec.from_json("[1, 2, 3]") == [1, 2, 3] - assert json_codec.from_json("42") == 42 + assert from_json('{"a": 1}') == {"a": 1} + assert from_json("[1, 2, 3]") == [1, 2, 3] + assert from_json("42") == 42 def test_from_json_legacy_marker_decodes_to_simplenamespace(): - legacy = json.dumps({"a": 1, "b": 2, json_codec.AUTO_SERIALIZED: True}) - result = json_codec.from_json(legacy) + legacy = json.dumps({"a": 1, "b": 2, AUTO_SERIALIZED: True}) + result = from_json(legacy) assert isinstance(result, SimpleNamespace) assert result.a == 1 assert result.b == 2 def test_legacy_simplenamespace_reserializes_cleanly(): - legacy = json.dumps({"a": 1, json_codec.AUTO_SERIALIZED: True}) - ns = json_codec.from_json(legacy) - reencoded = json_codec.to_json(ns) - assert json_codec.AUTO_SERIALIZED not in reencoded + legacy = json.dumps({"a": 1, AUTO_SERIALIZED: True}) + ns = from_json(legacy) + reencoded = to_json(ns) + assert AUTO_SERIALIZED not in reencoded assert json.loads(reencoded) == {"a": 1} @@ -171,58 +270,98 @@ def test_legacy_simplenamespace_reserializes_cleanly(): def test_from_json_coerces_to_dataclass(): - encoded = json_codec.to_json(Address("1 Main St", "Redmond")) - result = json_codec.from_json(encoded, Address) + encoded = to_json(Address("1 Main St", "Redmond")) + result = from_json(encoded, Address) assert isinstance(result, Address) assert result == Address("1 Main St", "Redmond") def test_from_json_coerces_nested_dataclass(): - encoded = json_codec.to_json(Person("Ada", 30, Address("1 Main St", "Redmond"))) - result = json_codec.from_json(encoded, Person) + encoded = to_json(Person("Ada", 30, Address("1 Main St", "Redmond"))) + result = from_json(encoded, Person) assert isinstance(result, Person) assert isinstance(result.address, Address) assert result.address.city == "Redmond" def test_from_json_coerces_optional_dataclass_when_present(): - result = json_codec.from_json('{"name": "Ada", "age": 30, "address": null}', Person) + result = from_json('{"name": "Ada", "age": 30, "address": null}', Person) assert isinstance(result, Person) assert result.address is None def test_from_json_coerces_list_of_dataclasses(): - encoded = json_codec.to_json([Address("a", "b"), Address("c", "d")]) - result = json_codec.from_json(encoded, list[Address]) + encoded = to_json([Address("a", "b"), Address("c", "d")]) + result = from_json(encoded, list[Address]) assert all(isinstance(item, Address) for item in result) assert result[1] == Address("c", "d") def test_from_json_uses_from_json_hook(): - encoded = json_codec.to_json(Widget("gear", 5)) - result = json_codec.from_json(encoded, Widget) + encoded = to_json(Widget("gear", 5)) + result = from_json(encoded, Widget) assert isinstance(result, Widget) assert result == Widget("gear", 5) def test_from_json_primitive_passthrough_with_expected_type(): - assert json_codec.from_json("42", int) == 42 - assert json_codec.from_json('"hi"', str) == "hi" + assert from_json("42", int) == 42 + assert from_json('"hi"', str) == "hi" def test_from_json_legacy_marker_with_expected_type_strips_and_builds(): # A payload persisted by an older SDK version (with the marker) must still # decode when the caller now passes an expected_type. legacy = json.dumps( - {"street": "1 Main St", "city": "Redmond", json_codec.AUTO_SERIALIZED: True} + {"street": "1 Main St", "city": "Redmond", AUTO_SERIALIZED: True} ) - result = json_codec.from_json(legacy, Address) + result = from_json(legacy, Address) assert isinstance(result, Address) assert result == Address("1 Main St", "Redmond") def test_from_json_none_payload_with_expected_type(): - assert json_codec.from_json("null", Address) is None + assert from_json("null", Address) is None + + +# ----- from_json container recursion (PR #154 follow-up) ----- + + +def test_from_json_coerces_dict_of_dataclasses(): + encoded = to_json({"home": Address("1 Main St", "Redmond")}) + result = from_json(encoded, dict[str, Address]) + assert isinstance(result["home"], Address) + assert result["home"] == Address("1 Main St", "Redmond") + + +def test_from_json_coerces_dict_of_from_json_types(): + encoded = to_json({"a": Widget("gear", 5), "b": Widget("cog", 7)}) + result = from_json(encoded, dict[str, Widget]) + assert result["a"] == Widget("gear", 5) + assert result["b"] == Widget("cog", 7) + + +def test_from_json_dict_without_value_type_is_passthrough(): + encoded = to_json({"a": Address("1 Main St", "Redmond")}) + # Bare ``dict`` (no args) leaves values as parsed JSON. + result = from_json(encoded, dict) + assert result == {"a": {"street": "1 Main St", "city": "Redmond"}} + + +def test_from_json_coerces_fixed_length_tuple(): + encoded = to_json([Address("a", "b"), 5]) + result = from_json(encoded, tuple[Address, int]) + assert isinstance(result, tuple) + assert result[0] == Address("a", "b") + assert result[1] == 5 + + +def test_from_json_coerces_homogeneous_tuple(): + encoded = to_json([Address("a", "b"), Address("c", "d")]) + result = from_json(encoded, tuple[Address, ...]) + assert isinstance(result, tuple) + assert all(isinstance(item, Address) for item in result) + assert result[1] == Address("c", "d") # ----- coerce_to_type ----- @@ -230,28 +369,25 @@ def test_from_json_none_payload_with_expected_type(): def test_coerce_to_type_none_type_returns_value(): value = {"a": 1} - assert json_codec.coerce_to_type(value, None) is value + assert coerce_to_type(value, None) is value def test_coerce_to_type_already_correct_type(): addr = Address("a", "b") - assert json_codec.coerce_to_type(addr, Address) is addr + assert coerce_to_type(addr, Address) is addr def test_coerce_to_type_invalid_conversion_raises(): with pytest.raises(TypeError): - json_codec.coerce_to_type("not-a-number", int) + coerce_to_type("not-a-number", int) def test_coerce_optional_dataclass_coerces_member(): - from typing import Optional - result = json_codec.coerce_to_type({"street": "a", "city": "b"}, Optional[Address]) + result = coerce_to_type({"street": "a", "city": "b"}, Optional[Address]) assert isinstance(result, Address) def test_coerce_genuine_union_leaves_unmatched_value_untouched(): - from typing import Union - @dataclass class A: x: int @@ -263,4 +399,75 @@ class B: # A dict matching neither A nor B by isinstance must be returned unchanged, # not force-coerced into the first union member. value = {"z": 1} - assert json_codec.coerce_to_type(value, Union[A, B]) == {"z": 1} + assert coerce_to_type(value, Union[A, B]) == {"z": 1} + + +def test_coerce_dict_values_recursively(): + result = coerce_to_type({"home": {"street": "a", "city": "b"}}, dict[str, Address]) + assert isinstance(result["home"], Address) + + +# ----- from_json converter hook (PR #154 follow-up) ----- + + +class ConverterAwareWidget: + """A type whose from_json hook reconstructs a nested typed value using the + converter passed to it by the DataConverter.""" + + def __init__(self, address: "Address"): + self.address = address + + def to_json(self) -> dict: + return {"address": self.address} + + @classmethod + def from_json(cls, data: dict, converter) -> "ConverterAwareWidget": + nested = converter.coerce(data["address"], Address) + return cls(nested) + + +def test_from_json_hook_receives_converter_for_nested_reconstruction(): + from durabletask.serialization import JsonDataConverter + + conv = JsonDataConverter() + encoded = conv.serialize(ConverterAwareWidget(Address("1 Main St", "Redmond"))) + result = conv.deserialize(encoded, ConverterAwareWidget) + assert isinstance(result, ConverterAwareWidget) + # The nested value was rebuilt into an Address by the hook via the converter. + assert isinstance(result.address, Address) + assert result.address.city == "Redmond" + + +def test_single_arg_from_json_hook_is_not_passed_converter(): + # A classic single-argument from_json must keep working unchanged. + from durabletask.serialization import JsonDataConverter + + conv = JsonDataConverter() + encoded = conv.serialize(Widget("gear", 5)) + result = conv.deserialize(encoded, Widget) + assert result == Widget("gear", 5) + + +def test_converter_threaded_through_nested_containers(): + # A from_json hook nested inside a dataclass field's list still receives the + # converter so it can recurse. + from durabletask.serialization import JsonDataConverter + + @dataclass + class Catalog: + widgets: list[ConverterAwareWidget] + + conv = JsonDataConverter() + encoded = conv.serialize( + Catalog([ConverterAwareWidget(Address("a", "b"))]) + ) + result = conv.deserialize(encoded, Catalog) + assert isinstance(result.widgets[0], ConverterAwareWidget) + assert isinstance(result.widgets[0].address, Address) + + +def test_coerce_to_type_without_converter_calls_single_arg_hook(): + # Calling the private helper without a converter must not pass one, even if + # the hook could accept it (defensive: no converter available). + result = coerce_to_type({"label": "gear", "size": 5}, Widget) + assert result == Widget("gear", 5) diff --git a/tests/durabletask/test_shared.py b/tests/durabletask/test_shared.py new file mode 100644 index 0000000..7cf1241 --- /dev/null +++ b/tests/durabletask/test_shared.py @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for the deprecated serialization shims in durabletask.internal.shared. + +The JSON codec now lives in ``durabletask.serialization`` and its functions are +private; the supported surface is the pluggable ``DataConverter``. The thin +``shared.to_json`` / ``shared.from_json`` wrappers remain for backwards +compatibility but emit a ``DeprecationWarning``. +""" + +import json +import warnings + +import pytest + +from durabletask.internal import shared + + +def test_shared_to_json_warns_and_serializes(): + with pytest.warns(DeprecationWarning, match="to_json"): + result = shared.to_json({"a": 1}) + assert json.loads(result) == {"a": 1} + + +def test_shared_from_json_warns_and_deserializes(): + with pytest.warns(DeprecationWarning, match="from_json"): + result = shared.from_json('{"a": 1}') + assert result == {"a": 1} + + +def test_shared_from_json_still_accepts_expected_type(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + assert shared.from_json("42", int) == 42 + + +def test_shared_auto_serialized_marker_still_exported(): + # The legacy marker constant remains importable for back-compat. + assert shared.AUTO_SERIALIZED == "__durabletask_autoobject__" From 3cd893912ffd093c4e2bd91ffd51268ddc02f280 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Jun 2026 15:12:04 -0600 Subject: [PATCH 2/9] Address PR feedback --- CHANGELOG.md | 8 ++ durabletask/internal/entity_state_shim.py | 101 +++++++++++----------- durabletask/serialization.py | 11 ++- tests/durabletask/test_entity_executor.py | 4 +- tests/durabletask/test_serialization.py | 14 +++ 5 files changed, 82 insertions(+), 56 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a5aa112..6f040f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -69,6 +69,14 @@ CHANGED across an upgrade. - JSON serialization failures now raise a `TypeError` that chains the original error (`__cause__`) and names the offending type. +- `EntityContext.get_state()` / `DurableEntity.get_state()` now return a freshly + reconstructed value on every call rather than a reference to a single cached + object. As a result, mutating a value returned by `get_state()` in place no + longer affects the persisted entity state — write the change back with + `set_state()` to persist it. The entity's state is also serialized eagerly at + `set_state()` time, so a value that cannot be serialized surfaces the error + inside the failing operation (which rolls back) instead of after the batch has + run. FIXED diff --git a/durabletask/internal/entity_state_shim.py b/durabletask/internal/entity_state_shim.py index eeab858..9993876 100644 --- a/durabletask/internal/entity_state_shim.py +++ b/durabletask/internal/entity_state_shim.py @@ -14,35 +14,50 @@ class StateShim: """In-memory view of an entity's state during a batch. - The state arriving from the wire is held as its raw serialized JSON string - and is **not** deserialized in the constructor: deserialization is deferred - until :meth:`get_state` is called, so the caller's requested type reaches the - data converter together with the original payload (a custom converter can - then deserialize the string directly into the target type). Once the state - has been read into a Python value or replaced via :meth:`set_state`, it is - held as that live object instead. - - Tracking whether the current value is still the raw serialized string also - lets :meth:`encode_state` pass an unmodified payload straight back to the - wire instead of re-serializing it, which would double-encode the JSON. + The state is held internally as its serialized JSON string at all times. + The raw payload off the wire is stored verbatim; a live value supplied via + :meth:`set_state` (or as a non-serialized constructor argument) is + serialized immediately. Keeping a single, always-serialized representation + has two consequences worth noting: + + * Deserialization is deferred to :meth:`get_state`, so the caller's + requested type reaches the data converter together with the original + payload (a custom converter can deserialize the string directly into the + target type), and the unmodified wire payload is handed back by + :meth:`encode_state` without being re-encoded. + * Serialization errors surface inside the failing operation (at + :meth:`set_state`) rather than after the batch has run, so a bad write + rolls back just that operation. + + Because the held value is always the serialized form, :meth:`get_state` + returns a freshly reconstructed object on every call; it does **not** return + a reference to a stored live object. Mutating a value read from + :meth:`get_state` therefore has no effect on the persisted state unless it + is written back with :meth:`set_state`. """ def __init__(self, start_state: Any, data_converter: "DataConverter | None" = None, *, is_serialized: bool = False): - # ``is_serialized`` marks ``start_state`` as a raw serialized payload - # (the value off the wire) whose deserialization should be deferred. A - # ``None`` state is never treated as serialized. - serialized = is_serialized and start_state is not None - self._current_state: Any = start_state - self._current_is_serialized: bool = serialized - self._checkpoint_state: Any = start_state - self._checkpoint_is_serialized: bool = serialized - self._operation_actions: list[pb.OperationAction] = [] - self._actions_checkpoint_state: int = 0 if data_converter is None: from durabletask.serialization import JsonDataConverter data_converter = JsonDataConverter() self._data_converter = data_converter + # The state is normalized to its serialized string form. ``is_serialized`` + # marks ``start_state`` as a raw payload already off the wire (stored + # verbatim); otherwise a live value is serialized now. ``None`` stays + # ``None`` (no persisted state). + serialized_start = self._serialize(start_state, is_serialized) + self._current_state: str | None = serialized_start + self._checkpoint_state: str | None = serialized_start + self._operation_actions: list[pb.OperationAction] = [] + self._actions_checkpoint_state: int = 0 + + def _serialize(self, state: Any, is_serialized: bool = False) -> str | None: + if state is None: + return None + if is_serialized: + return state + return self._data_converter.serialize(state) @overload def get_state(self, intended_type: type[TState], default: TState) -> TState: @@ -60,16 +75,11 @@ def get_state(self, intended_type: type[TState] | None = None, default: TState | if self._current_state is None: return default - if self._current_is_serialized: - # Deferred deserialization: the converter receives the raw payload - # together with the requested type. - if intended_type is None: - return self._data_converter.deserialize(self._current_state) - result = self._data_converter.deserialize(self._current_state, intended_type) - else: - if intended_type is None: - return self._current_state - result = self._data_converter.coerce(self._current_state, intended_type) + # Deferred deserialization: the converter receives the raw payload + # together with the requested type. + if intended_type is None: + return self._data_converter.deserialize(self._current_state) + result = self._data_converter.deserialize(self._current_state, intended_type) # An explicit ``intended_type`` is a request to receive that type. The # default converter is best-effort and would silently return the raw @@ -80,30 +90,25 @@ def get_state(self, intended_type: type[TState] | None = None, default: TState | if (isinstance(intended_type, type) # pyright: ignore[reportUnnecessaryIsInstance] and not isinstance(result, intended_type)): raise TypeError( - f"Could not convert state of type '{type(self._current_state).__name__}' to '{intended_type.__name__}'" + f"Could not convert state of type '{type(result).__name__}' to '{intended_type.__name__}'" ) return result def set_state(self, state: Any) -> None: - # A value set in-process is a live Python object, not a serialized payload. - self._current_state = state - self._current_is_serialized = False + # Serialize eagerly so the held value is always the wire form and any + # serialization error surfaces here, inside the failing operation. + self._current_state = self._serialize(state) def encode_state(self) -> str | None: - """Serialize the current state for persistence back to the wire. + """Return the serialized current state for persistence back to the wire. - Returns ``None`` only when the state is actually ``None`` (which clears - the persisted entity state). When the current value is still the raw - serialized payload (the state was never modified), it is returned - unchanged to avoid double-encoding; otherwise the live value is - serialized. + The state is already held in serialized form, so this is the stored + value verbatim: ``None`` when there is no state (which clears the + persisted entity state), otherwise the JSON string. No re-encoding + occurs, so a payload that was never modified round-trips unchanged. """ - if self._current_state is None: - return None - if self._current_is_serialized: - return self._current_state - return self._data_converter.serialize(self._current_state) + return self._current_state def add_operation_action(self, action: pb.OperationAction) -> None: self._operation_actions.append(action) @@ -113,18 +118,14 @@ def get_operation_actions(self) -> list[pb.OperationAction]: def commit(self) -> None: self._checkpoint_state = self._current_state - self._checkpoint_is_serialized = self._current_is_serialized self._actions_checkpoint_state = len(self._operation_actions) def rollback(self) -> None: self._current_state = self._checkpoint_state - self._current_is_serialized = self._checkpoint_is_serialized self._operation_actions = self._operation_actions[:self._actions_checkpoint_state] def reset(self) -> None: self._current_state = None - self._current_is_serialized = False self._checkpoint_state = None - self._checkpoint_is_serialized = False self._operation_actions = [] self._actions_checkpoint_state = 0 diff --git a/durabletask/serialization.py b/durabletask/serialization.py index e9a72b1..623d695 100644 --- a/durabletask/serialization.py +++ b/durabletask/serialization.py @@ -417,10 +417,13 @@ def _coerce_generic(value: Any, expected_type: Any, origin: Any, if len(args) == 2 and args[1] is Ellipsis: return tuple(_coerce_to_type(item, args[0], converter) for item in cast(list[Any], value)) # Fixed-length ``tuple[T1, T2, ...]`` -- coerce element-wise by position. - return tuple( - _coerce_to_type(item, t, converter) - for item, t in zip(cast(list[Any], value), args) - ) + arr = cast(list[Any], value) + if len(arr) != len(args): + raise TypeError( + f"Could not coerce JSON array of length {len(arr)} to " + f"tuple of length {len(args)}" + ) + return tuple(_coerce_to_type(item, t, converter) for item, t in zip(arr, args)) # Other generics are returned as parsed JSON. return value diff --git a/tests/durabletask/test_entity_executor.py b/tests/durabletask/test_entity_executor.py index 116ead4..08d2827 100644 --- a/tests/durabletask/test_entity_executor.py +++ b/tests/durabletask/test_entity_executor.py @@ -269,14 +269,14 @@ def test_encode_state_none_when_state_is_none(self): state = StateShim(None, is_serialized=True) assert state.encode_state() is None - def test_commit_preserves_serialized_flag(self): + def test_commit_preserves_unmodified_payload(self): state = StateShim('{"value": 7}', is_serialized=True) state.commit() # After commit, the (unmodified) state still round-trips without # double-encoding. assert state.encode_state() == '{"value": 7}' - def test_rollback_restores_serialized_flag(self): + def test_rollback_restores_unmodified_payload(self): state = StateShim('{"value": 7}', is_serialized=True) state.commit() state.set_state({"value": 99}) diff --git a/tests/durabletask/test_serialization.py b/tests/durabletask/test_serialization.py index 9a9659a..d42674e 100644 --- a/tests/durabletask/test_serialization.py +++ b/tests/durabletask/test_serialization.py @@ -364,6 +364,20 @@ def test_from_json_coerces_homogeneous_tuple(): assert result[1] == Address("c", "d") +def test_coerce_to_type_fixed_length_tuple_too_long_raises(): + # A JSON array longer than the fixed-length tuple type must fail fast rather + # than silently dropping the trailing element(s). + with pytest.raises(TypeError): + coerce_to_type([1, 2, 3], tuple[int, int]) + + +def test_coerce_to_type_fixed_length_tuple_too_short_raises(): + # A JSON array shorter than the fixed-length tuple type must fail fast rather + # than silently producing a short tuple. + with pytest.raises(TypeError): + coerce_to_type([1], tuple[int, int]) + + # ----- coerce_to_type ----- From 384dafd94608aa9ffef40b40d3641dd0eb33cba6 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Jun 2026 15:16:18 -0600 Subject: [PATCH 3/9] Fix annotation discovery in 3.10 --- durabletask/serialization.py | 60 ++++++++++++++++++++++++- tests/durabletask/test_serialization.py | 40 ++++++++++++++++- 2 files changed, 98 insertions(+), 2 deletions(-) diff --git a/durabletask/serialization.py b/durabletask/serialization.py index 623d695..5413371 100644 --- a/durabletask/serialization.py +++ b/durabletask/serialization.py @@ -38,6 +38,7 @@ import inspect import json import logging +import sys import types import typing from abc import ABC, abstractmethod @@ -435,10 +436,67 @@ def _build_dataclass(cls: Any, data: dict[str, Any], hints = typing.get_type_hints(cls) except Exception: hints = {} + globalns = _type_namespace(cls) kwargs: dict[str, Any] = {} for field in dataclasses.fields(cls): if field.name not in data: continue - field_type = hints.get(field.name) + # ``get_type_hints`` on Python 3.10 does not deep-resolve forward + # references nested inside container args (e.g. the ``"TreeNode"`` in + # ``list["TreeNode"]`` on a self-referential dataclass), leaving a bare + # string or ``ForwardRef`` that the coercion below would skip. Resolve + # them against the class's defining module so reconstruction behaves the + # same as it does on 3.11+. + field_type = _resolve_forward_refs(hints.get(field.name), globalns) kwargs[field.name] = _coerce_to_type(data[field.name], field_type, converter) return cls(**kwargs) + + +def _type_namespace(cls: Any) -> dict[str, Any]: + """Build the namespace used to resolve forward references in ``cls``'s hints. + + Forward references in a class's annotations are resolved against the + module in which the class is defined, plus the class's own name (so a + self-referential type like ``list["TreeNode"]`` resolves). + """ + module = sys.modules.get(getattr(cls, "__module__", None) or "") + ns: dict[str, Any] = dict(getattr(module, "__dict__", {})) + name = getattr(cls, "__name__", None) + if name: + ns.setdefault(name, cls) + return ns + + +def _resolve_forward_refs(tp: Any, globalns: dict[str, Any]) -> Any: + """Resolve string / ``ForwardRef`` leaves in a type hint, recursing into args. + + Returns ``tp`` unchanged when it (or a nested name) cannot be resolved, so an + unresolvable hint simply falls back to "leave the value as parsed JSON" + rather than raising. Only the supported generic shapes (``Union``, ``list``, + ``dict``, ``tuple``, etc.) are rebuilt; the destination type is still + entirely caller-supplied, so this does not weaken the security model. + """ + if isinstance(tp, str): + try: + tp = eval(tp, globalns) # noqa: S307 - resolves caller-authored annotations + except Exception: + return tp + elif isinstance(tp, typing.ForwardRef): + try: + tp = eval(tp.__forward_arg__, globalns) # noqa: S307 + except Exception: + return tp + + origin = typing.get_origin(tp) + if origin is None: + return tp + args = typing.get_args(tp) + if not args: + return tp + resolved = [_resolve_forward_refs(a, globalns) for a in args] + if origin is typing.Union or origin is types.UnionType: + return typing.Union[tuple(resolved)] + try: + return origin[tuple(resolved) if len(resolved) > 1 else resolved[0]] + except TypeError: + return tp diff --git a/tests/durabletask/test_serialization.py b/tests/durabletask/test_serialization.py index d42674e..d5a9ecf 100644 --- a/tests/durabletask/test_serialization.py +++ b/tests/durabletask/test_serialization.py @@ -12,13 +12,14 @@ from collections import namedtuple from dataclasses import dataclass from types import SimpleNamespace -from typing import Optional, Union +from typing import List, Optional, Union, get_args import pytest from durabletask.serialization import _AUTO_SERIALIZED as AUTO_SERIALIZED from durabletask.serialization import _coerce_to_type as coerce_to_type from durabletask.serialization import _from_json as from_json +from durabletask.serialization import _resolve_forward_refs as resolve_forward_refs from durabletask.serialization import _to_json as to_json @@ -485,3 +486,40 @@ def test_coerce_to_type_without_converter_calls_single_arg_hook(): # the hook could accept it (defensive: no converter available). result = coerce_to_type({"label": "gear", "size": 5}, Widget) assert result == Widget("gear", 5) + + +# ----- forward-reference resolution (Python 3.10 get_type_hints parity) ----- +# +# On Python 3.10, ``typing.get_type_hints`` does not deep-resolve forward +# references nested inside container args (e.g. the element type of +# ``list["TreeNode"]`` on a self-referential dataclass), leaving a bare string +# or ``ForwardRef``. ``_resolve_forward_refs`` restores the 3.11+ behavior so +# nested coercion still fires. These tests run on every supported version. + + +def test_resolve_forward_refs_resolves_string_element_type(): + # ``list["Address"]`` evaluates to a generic whose arg is the raw string + # "Address" -- exactly what 3.10 leaves behind. + resolved = resolve_forward_refs(list["Address"], {"Address": Address}) + assert get_args(resolved) == (Address,) + + +def test_resolve_forward_refs_resolves_forwardref_element_type(): + resolved = resolve_forward_refs(List["Address"], {"Address": Address}) + assert get_args(resolved) == (Address,) + + +def test_resolve_forward_refs_then_coerce_reconstructs_nested_dataclass(): + # The full 3.10 path: resolve the unresolved element type, then coerce the + # contained dicts into the target dataclass. + field_type = resolve_forward_refs(list["Address"], {"Address": Address}) + result = coerce_to_type([{"street": "a", "city": "b"}], field_type) + assert isinstance(result[0], Address) + assert result[0].city == "b" + + +def test_resolve_forward_refs_leaves_unresolvable_name_untouched(): + # An unknown name is left as a string so coercion harmlessly falls back to + # the parsed JSON rather than raising. + resolved = resolve_forward_refs(list["DoesNotExist"], {}) + assert get_args(resolved) == ("DoesNotExist",) From 6be2af1082025b60ff8252512951bd955cd8cab5 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Jun 2026 15:48:32 -0600 Subject: [PATCH 4/9] Add pydantic example, fix reconstructibility concern --- CHANGELOG.md | 8 + durabletask/internal/type_discovery.py | 86 ++++----- durabletask/serialization.py | 50 ++++++ durabletask/worker.py | 8 +- examples/README.md | 6 + examples/custom_data_converter/README.md | 165 ++++++++++++++++++ .../custom_data_converter/requirements.txt | 5 + .../custom_data_converter/src/__init__.py | 2 + examples/custom_data_converter/src/app.py | 90 ++++++++++ .../custom_data_converter/src/converter.py | 106 +++++++++++ .../custom_data_converter/src/workflows.py | 119 +++++++++++++ .../custom_data_converter/test/__init__.py | 2 + .../test/test_custom_converter.py | 148 ++++++++++++++++ tests/durabletask/test_type_discovery.py | 68 +++++++- 14 files changed, 800 insertions(+), 63 deletions(-) create mode 100644 examples/custom_data_converter/README.md create mode 100644 examples/custom_data_converter/requirements.txt create mode 100644 examples/custom_data_converter/src/__init__.py create mode 100644 examples/custom_data_converter/src/app.py create mode 100644 examples/custom_data_converter/src/converter.py create mode 100644 examples/custom_data_converter/src/workflows.py create mode 100644 examples/custom_data_converter/test/__init__.py create mode 100644 examples/custom_data_converter/test/test_custom_converter.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f040f9..be82789 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,14 @@ ADDED `DataConverter` as a second parameter (`from_json(cls, value, converter)`), letting it reconstruct nested typed values via `converter.coerce(...)` / `converter.deserialize(...)`. The single-argument form remains supported. +- `DataConverter` now exposes an overridable `is_reconstructable(target_type)` + method that controls which annotated input/return types the SDK reconstructs + on the inbound path. A custom converter can override it to recognize its own + types (for example `pydantic.BaseModel` subclasses), so that orchestrator / + activity / entity inputs annotated with those types are reconstructed by the + converter instead of arriving as raw JSON. The default behavior is unchanged + (dataclasses and `from_json()`-capable types, plus `Optional` / `list` + wrappers, are reconstructable; builtins are not). - Added `EntityMetadata.get_typed_state(intended_type=...)`, which deserializes the entity's persisted state and reconstructs dataclasses and `from_json()`-capable types. The existing `get_state()` is unchanged: with no diff --git a/durabletask/internal/type_discovery.py b/durabletask/internal/type_discovery.py index 58fd6f8..cc92404 100644 --- a/durabletask/internal/type_discovery.py +++ b/durabletask/internal/type_discovery.py @@ -5,15 +5,18 @@ These helpers resolve the annotation of the *input* parameter of an orchestrator, activity, or entity function so that inbound payloads can be -reconstructed into the annotated custom type (a dataclass or a type exposing a -``from_json()`` classmethod) without the caller having to pass an explicit type. +reconstructed into the annotated custom type without the caller having to pass +an explicit type. Discovery is intentionally conservative: it only returns an annotation when the -target is a *reconstructable* custom type (a dataclass, a ``from_json()``-capable -type, or an ``Optional`` / ``list`` wrapping one). Primitive and unknown -annotations resolve to ``None`` so that existing payloads are passed through -unchanged -- inbound type discovery never invokes an arbitrary constructor on -untrusted data, and never alters the value for builtins. +active :class:`~durabletask.serialization.DataConverter` reports it as +*reconstructable* via :meth:`DataConverter.is_reconstructable`. The default +converter recognizes dataclasses, ``from_json()``-capable types, and ``Optional`` +/ ``list`` hints wrapping them; a custom converter can recognize its own types +(e.g. ``pydantic.BaseModel``). Primitive and unknown annotations resolve to +``None`` so that existing payloads are passed through unchanged -- inbound type +discovery never invokes an arbitrary constructor on untrusted data, and never +alters the value for builtins. All public helpers swallow exceptions and return ``None`` on failure; the caller treats ``None`` as "no type information available" and uses the raw payload. @@ -21,38 +24,17 @@ from __future__ import annotations -import collections.abc -import dataclasses import functools import inspect -import types import typing -from typing import Any, Callable, cast +from typing import Any, Callable +from durabletask.serialization import DEFAULT_DATA_CONVERTER, DataConverter -def is_reconstructable(annotation: Any) -> bool: - """Return True if ``annotation`` names a custom type we can rebuild. - Reconstructable targets are dataclasses, types exposing a callable - ``from_json``, and ``Optional`` / ``list`` hints wrapping such types. - Builtins (``int``, ``str``, ``dict``, ...) and unknown annotations are not - reconstructable and resolve to ``False``. - """ - origin = typing.get_origin(annotation) - if origin is not None: - args = typing.get_args(annotation) - if origin is typing.Union or origin is types.UnionType: - return any( - is_reconstructable(a) for a in args if a is not type(None) - ) - if origin in (list, collections.abc.Sequence): - return any(is_reconstructable(a) for a in args) - return False - if not isinstance(annotation, type): - return False - if dataclasses.is_dataclass(annotation): - return True - return callable(getattr(cast(Any, annotation), "from_json", None)) +def _resolve_converter(converter: DataConverter | None) -> DataConverter: + """Return the supplied converter, or the shared default when ``None``.""" + return converter if converter is not None else DEFAULT_DATA_CONVERTER # Bounded so a worker that registers dynamically-created functions or closures @@ -72,14 +54,15 @@ def _resolved_hints(fn: Callable[..., Any]) -> dict[str, Any] | None: return None -def _input_annotation(fn: Callable[..., Any], position: int) -> Any | None: +def _input_annotation(fn: Callable[..., Any], position: int, + converter: DataConverter | None = None) -> Any | None: """Return the resolved annotation of the positional parameter at ``position``. ``position`` is the zero-based index among positional parameters (so the ``input`` parameter of a ``(ctx, input)`` function is at position 1, and the ``input`` parameter of an unbound ``(self, input)`` entity method is also at position 1). Returns ``None`` when the parameter is absent, unannotated, or - its annotation is not a reconstructable custom type. + its annotation is not reconstructable by ``converter``. """ try: sig = inspect.signature(fn) @@ -105,27 +88,29 @@ def _input_annotation(fn: Callable[..., Any], position: int) -> Any | None: if annotation is inspect.Parameter.empty or annotation is Any: return None - return annotation if is_reconstructable(annotation) else None + return annotation if _resolve_converter(converter).is_reconstructable(annotation) else None -def orchestrator_input_type(fn: Callable[..., Any]) -> Any | None: +def orchestrator_input_type(fn: Callable[..., Any], + converter: DataConverter | None = None) -> Any | None: """Discover the input type of an orchestrator function ``(ctx, input)``.""" - return _input_annotation(fn, 1) + return _input_annotation(fn, 1, converter) -def activity_input_type(fn: Callable[..., Any]) -> Any | None: +def activity_input_type(fn: Callable[..., Any], + converter: DataConverter | None = None) -> Any | None: """Discover the input type of an activity function ``(ctx, input)``.""" - return _input_annotation(fn, 1) + return _input_annotation(fn, 1, converter) -def activity_output_type(fn: Any) -> Any | None: +def activity_output_type(fn: Any, converter: DataConverter | None = None) -> Any | None: """Discover the return type of an activity function. - Returns the resolved return annotation when it names a reconstructable - custom type (a dataclass or a ``from_json()``-capable type, optionally - wrapped in ``Optional`` / ``list``). Returns ``None`` for plain callables - that are not annotated with such a type, for string activity names, or when - the annotation cannot be resolved. + Returns the resolved return annotation when ``converter`` reports it as + reconstructable (the default converter recognizes a dataclass or a + ``from_json()``-capable type, optionally wrapped in ``Optional`` / ``list``). + Returns ``None`` for plain callables that are not annotated with such a type, + for string activity names, or when the annotation cannot be resolved. """ if not callable(fn): return None @@ -144,10 +129,11 @@ def activity_output_type(fn: Any) -> Any | None: if annotation is inspect.Signature.empty or annotation is Any or annotation is None: return None - return annotation if is_reconstructable(annotation) else None + return annotation if _resolve_converter(converter).is_reconstructable(annotation) else None -def entity_input_type(fn: Any, operation: str) -> Any | None: +def entity_input_type(fn: Any, operation: str, + converter: DataConverter | None = None) -> Any | None: """Discover the input type of an entity operation. For class-based entities (a ``DurableEntity`` subclass) the operation is a @@ -160,5 +146,5 @@ def entity_input_type(fn: Any, operation: str) -> Any | None: if method is None or not callable(method): return None # Unbound method includes ``self`` at position 0, so ``input`` is at 1. - return _input_annotation(method, 1) - return _input_annotation(fn, 1) + return _input_annotation(method, 1, converter) + return _input_annotation(fn, 1, converter) diff --git a/durabletask/serialization.py b/durabletask/serialization.py index 5413371..19cfb6d 100644 --- a/durabletask/serialization.py +++ b/durabletask/serialization.py @@ -101,6 +101,30 @@ def coerce(self, value: Any, target_type: type | None = None) -> Any: """ ... + def is_reconstructable(self, target_type: Any) -> bool: + """Return True if this converter can rebuild ``target_type`` from a payload. + + Inbound type-discovery calls this to decide whether a function's + annotated *input* type (or an activity's *return* annotation) should be + passed to :meth:`deserialize` / :meth:`coerce`. When it returns ``False`` + the SDK passes the raw deserialized payload through unchanged -- this + gate is what stops the SDK from invoking an arbitrary constructor on a + builtin or otherwise unrecognized annotation. + + The default recognizes the types the built-in codec can rebuild -- + dataclasses and ``from_json()``-capable types, plus ``Optional`` / + ``list`` / ``Sequence`` hints wrapping them -- and excludes builtins + (``int``, ``str``, ``dict``, ...) and unknown annotations. + + Override this to teach the SDK about a custom converter's own types (for + example ``pydantic.BaseModel`` subclasses) so that inputs annotated with + them are reconstructed instead of arriving as raw JSON. The default + implementation recurses through ``self.is_reconstructable``, so an + override is also consulted for the element types of ``Optional`` / + ``list`` hints (e.g. ``list[MyModel]``). + """ + return _is_reconstructable(self, target_type) + class JsonDataConverter(DataConverter): """Default :class:`DataConverter` backed by the SDK's JSON codec. @@ -187,6 +211,32 @@ def _log_coercion_fallback(target_type: type, error: Exception) -> None: # --------------------------------------------------------------------------- +def _is_reconstructable(converter: DataConverter, target_type: Any) -> bool: + """Default :meth:`DataConverter.is_reconstructable` policy. + + Recognizes dataclasses and ``from_json()``-capable types, plus ``Optional`` + / ``list`` / ``Sequence`` hints wrapping them; builtins and unknown + annotations are excluded. Recurses through ``converter.is_reconstructable`` + (not itself) so a subclass override participates in the element-type checks + of ``Optional`` / ``list`` hints. + """ + origin = typing.get_origin(target_type) + if origin is not None: + args = typing.get_args(target_type) + if origin is typing.Union or origin is types.UnionType: + return any( + converter.is_reconstructable(a) for a in args if a is not type(None) + ) + if origin in (list, Sequence): + return any(converter.is_reconstructable(a) for a in args) + return False + if not isinstance(target_type, type): + return False + if dataclasses.is_dataclass(target_type): + return True + return callable(getattr(cast(Any, target_type), "from_json", None)) + + def _to_json(obj: Any) -> str: """Serialize a value to a JSON string. diff --git a/durabletask/worker.py b/durabletask/worker.py index 6fda92d..de19329 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -1692,7 +1692,7 @@ def call_activity( # converter decides how a coercion failure is handled (the default is # best-effort). if return_type is None and not isinstance(activity, str): - return_type = type_discovery.activity_output_type(activity) + return_type = type_discovery.activity_output_type(activity, self._data_converter) self.call_activity_function_helper( id, activity, input=input, retry_policy=retry_policy, is_sub_orch=False, tags=tags, @@ -2211,7 +2211,7 @@ def process_event( if ( event.executionStarted.HasField("input") and event.executionStarted.input.value != "" ): - input_type = type_discovery.orchestrator_input_type(fn) + input_type = type_discovery.orchestrator_input_type(fn, self._data_converter) input = self._data_converter.deserialize( event.executionStarted.input.value, input_type) @@ -2856,7 +2856,7 @@ def execute( f"Activity function named '{name}' was not registered!" ) - input_type = type_discovery.activity_input_type(fn) if encoded_input else None + input_type = type_discovery.activity_input_type(fn, self._data_converter) if encoded_input else None activity_input = self._data_converter.deserialize(encoded_input, input_type) ctx = task.ActivityContext(orchestration_id, task_id) @@ -2897,7 +2897,7 @@ def execute( f"Entity function named '{entity_id.entity}' was not registered!" ) - input_type = type_discovery.entity_input_type(fn, operation) if encoded_input else None + input_type = type_discovery.entity_input_type(fn, operation, self._data_converter) if encoded_input else None entity_input = self._data_converter.deserialize(encoded_input, input_type) ctx = EntityContext(orchestration_id, operation, state, entity_id, self._data_converter) diff --git a/examples/README.md b/examples/README.md index 3a4ce41..7476e4d 100644 --- a/examples/README.md +++ b/examples/README.md @@ -169,6 +169,12 @@ python activity_sequence.py > account and an additional install step. See the > [large payload README](./large_payload/README.md) for details. +> [!NOTE] +> The `custom_data_converter/` example is a self-contained subproject that +> shows how to plug a third-party serialization library (pydantic) into the +> SDK via a custom `DataConverter`, with in-process tests that need no backend. +> See the [custom DataConverter README](./custom_data_converter/README.md). + ### Review Orchestration History and Status To access the Durable Task Scheduler Dashboard, follow these steps: diff --git a/examples/custom_data_converter/README.md b/examples/custom_data_converter/README.md new file mode 100644 index 0000000..4888d9f --- /dev/null +++ b/examples/custom_data_converter/README.md @@ -0,0 +1,165 @@ +# Custom DataConverter Example (pydantic) + +This example shows how to plug a **third-party serialization library** into the +Durable Task Python SDK by implementing a custom +[`DataConverter`](../../durabletask/serialization.py). It uses +[pydantic](https://docs.pydantic.dev/) to serialize and *validate* payloads, and +proves the integration works end-to-end with in-process tests that need no +sidecar, emulator, or Azure resources. + +You can copy this entire folder into a new directory and run it as a standalone +project. + +## Why a custom converter? + +By default the SDK serializes payloads with its built-in JSON codec +(`JsonDataConverter`), which understands builtins, dataclasses, and objects that +expose `to_json()` / `from_json()` hooks. It does **not** know about pydantic +models. Supplying a custom `DataConverter` lets you route serialization through +any library you like — pydantic, `attrs`, `marshmallow`, a schema registry, an +encryption layer, etc. — at every payload boundary the SDK touches. + +## How the seam works + +Both the worker and the client accept a `data_converter` argument, and the SDK +routes **every** payload boundary through it — orchestrator / activity / entity +inputs and outputs, external events, and custom status: + +```python +converter = PydanticDataConverter() + +worker = DurableTaskSchedulerWorker(..., data_converter=converter) +client = DurableTaskSchedulerClient(..., data_converter=converter) +``` + +> [!IMPORTANT] +> Pass an equivalent converter to **both** the worker and the client. A payload +> serialized by one side is reconstructed by the other, so they must agree on +> the format. + +A `DataConverter` implements three methods: + +| Method | Direction | Used when | +|---|---|---| +| `serialize(value)` | Python value → JSON string | Any value leaves the process | +| `deserialize(data, target_type)` | JSON string → Python value (optionally typed) | A value arrives and the SDK knows the target type (from a function annotation, `return_type=`, or a typed client accessor) | +| `coerce(value, target_type)` | already-parsed value → typed value | The SDK already holds a parsed value (e.g. entity state) | + +The converter in [src/converter.py](src/converter.py) recognizes +`pydantic.BaseModel` subclasses and uses pydantic for them, **delegating +everything else** to the default `JsonDataConverter`. This "handle my types, +delegate the rest" shape is the recommended pattern for a real converter — it +costs nothing for non-pydantic payloads. + +## Inbound inputs: `is_reconstructable` + +There is one extra detail for reconstructing **inbound** orchestrator/activity +*inputs*. Before the SDK hands an input to your converter, it asks the converter +whether the function's annotated input type is something it can rebuild, via +`DataConverter.is_reconstructable(target_type)`. The default implementation +recognizes dataclasses and `from_json()`-capable types (and `Optional` / `list` +wrappers) — it does **not** know about pydantic models, so without an override an +input annotated `order: Order` would arrive as a plain `dict`. + +The converter overrides `is_reconstructable` to also recognize +`pydantic.BaseModel` subclasses: + +```python +def is_reconstructable(self, target_type): + if _is_model_type(target_type): + return True + return super().is_reconstructable(target_type) # keep the defaults +``` + +Because the base implementation recurses through `self.is_reconstructable`, +`list[OrderItem]` and `Optional[Order]` are recognized too. Outbound values, +`return_type=` arguments, and typed client accessors (`state.get_output(Receipt)`) +don't depend on this hook — they pass the type to the converter directly. + + + +## Folder structure + +```text +custom_data_converter/ +├── README.md +├── requirements.txt +├── src/ +│ ├── __init__.py +│ ├── converter.py # PydanticDataConverter — the integration point +│ ├── workflows.py # pydantic models + orchestrator/activities +│ └── app.py # runs against a real DTS backend / emulator +└── test/ + ├── __init__.py + └── test_custom_converter.py # in-process proof using the in-memory backend +``` + +## What the example proves + +The models in [src/workflows.py](src/workflows.py) are plain +`pydantic.BaseModel` subclasses — *not* dataclasses — so they only round-trip +correctly **because** of the custom converter. The tests in +[test/test_custom_converter.py](test/test_custom_converter.py) verify: + +1. A pydantic `Order` passed as orchestration input arrives at the + orchestrator/activity as a **validated model instance** (attribute access), + not a raw dict. +2. The orchestration's pydantic `Receipt` result is reconstructed, typed, on + the client via `state.get_output(Receipt)`. +3. The wire payload is genuine pydantic JSON (`model_dump_json`), confirming the + custom converter — not the default codec — handled it. +4. An input that violates a pydantic constraint fails the orchestration with a + validation error, instead of passing bad data through. +5. For contrast, the default `JsonDataConverter` **cannot serialize** a pydantic + model at all (it raises `TypeError`) — which is exactly what motivates the + custom converter. + +## Getting started + +1. Copy this folder to a new location and `cd` into it: + + ```bash + cd custom_data_converter + ``` + +1. Create and activate a virtual environment: + + Bash: + + ```bash + python -m venv .venv + source .venv/bin/activate + ``` + + PowerShell: + + ```powershell + python -m venv .venv + .\.venv\Scripts\Activate.ps1 + ``` + +1. Install dependencies: + + ```bash + pip install -r requirements.txt + ``` + +## Running the tests (no backend required) + +From the `custom_data_converter/` directory: + +```bash +pytest test/ +``` + +This is the self-contained proof: it runs the full orchestration in-process +against the in-memory backend. + +## Running the app against the emulator + +Start the [DTS emulator](../README.md#running-with-the-emulator), then from the +`custom_data_converter/` directory: + +```bash +python -m src.app +``` diff --git a/examples/custom_data_converter/requirements.txt b/examples/custom_data_converter/requirements.txt new file mode 100644 index 0000000..a0c9a1b --- /dev/null +++ b/examples/custom_data_converter/requirements.txt @@ -0,0 +1,5 @@ +durabletask +durabletask-azuremanaged +azure-identity +pydantic>=2 +pytest diff --git a/examples/custom_data_converter/src/__init__.py b/examples/custom_data_converter/src/__init__.py new file mode 100644 index 0000000..59e481e --- /dev/null +++ b/examples/custom_data_converter/src/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/examples/custom_data_converter/src/app.py b/examples/custom_data_converter/src/app.py new file mode 100644 index 0000000..681efad --- /dev/null +++ b/examples/custom_data_converter/src/app.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Run the pydantic-backed workflow against a Durable Task Scheduler backend. + +The only thing that differs from an ordinary app is the ``data_converter=`` +argument passed to **both** the worker and the client. Pass the *same* +converter instance (or an equivalent one) to both sides so payloads serialized +by one are reconstructed by the other. + +Usage (emulator -- no env vars needed): + python -m src.app + +Usage (Azure): + export ENDPOINT=https://.durabletask.io + export TASKHUB= + python -m src.app + +For a self-contained proof that needs no backend at all, run the tests instead +(see ``test/test_custom_converter.py``). +""" + +import os + +from azure.identity import DefaultAzureCredential + +from durabletask import client +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +from src.converter import PydanticDataConverter +from src.workflows import Receipt, calculate_total, charge_payment, process_order, sample_order + + +def main() -> None: + taskhub_name = os.getenv("TASKHUB", "default") + endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + + print(f"Using taskhub: {taskhub_name}") + print(f"Using endpoint: {endpoint}") + + secure_channel = endpoint.startswith("https://") + credential = DefaultAzureCredential() if secure_channel else None + + # The single line that wires in the third-party converter. The same + # converter is handed to the worker and the client below. + converter = PydanticDataConverter() + + with DurableTaskSchedulerWorker( + host_address=endpoint, + secure_channel=secure_channel, + taskhub=taskhub_name, + token_credential=credential, + data_converter=converter, + ) as w: + w.add_orchestrator(process_order) + w.add_activity(calculate_total) + w.add_activity(charge_payment) + w.start() + + c = DurableTaskSchedulerClient( + host_address=endpoint, + secure_channel=secure_channel, + taskhub=taskhub_name, + token_credential=credential, + data_converter=converter, + ) + + # ``input`` is a pydantic ``Order``; the converter serializes it. + instance_id = c.schedule_new_orchestration(process_order, input=sample_order()) + print(f"Scheduled orchestration: {instance_id}") + + state = c.wait_for_orchestration_completion(instance_id, timeout=60) + if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: + # ``get_output(Receipt)`` reconstructs the typed, validated result. + receipt = state.get_output(Receipt) + print("Orchestration completed. Typed receipt:") + print(f" customer = {receipt.customer}") + print(f" total = {receipt.total}") + print(f" item_count = {receipt.item_count}") + print(f" confirmation_id = {receipt.confirmation_id}") + elif state: + print(f"Orchestration failed: {state.failure_details}") + else: + print("Orchestration timed out.") + + +if __name__ == "__main__": + main() diff --git a/examples/custom_data_converter/src/converter.py b/examples/custom_data_converter/src/converter.py new file mode 100644 index 0000000..ab142ba --- /dev/null +++ b/examples/custom_data_converter/src/converter.py @@ -0,0 +1,106 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Pydantic-backed :class:`DataConverter` for the Durable Task Python SDK. + +This is the heart of the example: a small, self-contained converter that plugs +a third-party serialization/validation library (`pydantic +`_) into the SDK's pluggable serialization seam. + +How the seam works +------------------ +Both the worker and the client accept a ``data_converter`` argument. The SDK +routes **every** payload boundary -- orchestrator/activity/entity inputs and +outputs, external events, and custom status -- through that single converter, +so one object controls how Python values become JSON on the wire and how they +are reconstructed on the way back. + +A converter implements three methods: + +* ``serialize(value)`` -- Python value -> JSON string (or ``None``). +* ``deserialize(data, t)`` -- JSON string -> Python value, optionally + reconstructed as ``t`` (the type the SDK learned from a function's + annotation, a ``return_type=`` argument, or a typed client accessor). +* ``coerce(value, t)`` -- already-parsed value -> reconstructed as ``t`` + (used where the SDK already holds a parsed value, e.g. entity state). + +It may also override one hook: + +* ``is_reconstructable(t)`` -- tells the SDK's inbound type-discovery that an + *input* annotated with type ``t`` should be handed to ``deserialize`` / + ``coerce`` (rather than passed through as raw JSON). The default recognizes + dataclasses and ``from_json()``-capable types; override it to add your own + (here, pydantic models). + +This converter recognizes :class:`pydantic.BaseModel` subclasses and uses +pydantic for them (gaining validation, aliasing, custom field types, etc.). +For everything else -- ``str``, ``int``, ``list``, dataclasses, ... -- it +delegates to the SDK's default :class:`JsonDataConverter`, so plugging it in +costs nothing for non-pydantic payloads. This "handle my types, delegate the +rest" shape is the recommended pattern for a real custom converter. +""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel + +from durabletask.serialization import DataConverter, JsonDataConverter + + +def _is_model_type(target_type: Any) -> bool: + """Return True when ``target_type`` is a pydantic model class.""" + return isinstance(target_type, type) and issubclass(target_type, BaseModel) + + +class PydanticDataConverter(DataConverter): + """A :class:`DataConverter` that serializes pydantic models with pydantic. + + Pydantic models are serialized via ``model_dump_json()`` and reconstructed + (with full validation) via ``model_validate_json()`` / ``model_validate()`` + whenever the SDK supplies the model type. Every other value falls through to + the default :class:`JsonDataConverter`. + """ + + def __init__(self) -> None: + # Delegate non-pydantic payloads to the SDK's built-in JSON codec. + self._fallback = JsonDataConverter() + + def serialize(self, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, BaseModel): + # ``model_dump_json`` honors pydantic field serializers, aliases, + # and custom types (e.g. ``datetime`` -> ISO 8601) automatically. + return value.model_dump_json() + return self._fallback.serialize(value) + + def deserialize(self, data: str | None, target_type: type | None = None) -> Any: + if data is None or data == "": + return None + if _is_model_type(target_type): + # ``model_validate_json`` parses *and validates* the payload, + # raising ``pydantic.ValidationError`` on malformed data. + return target_type.model_validate_json(data) # type: ignore[union-attr] + return self._fallback.deserialize(data, target_type) + + def coerce(self, value: Any, target_type: type | None = None) -> Any: + if value is None: + return None + if _is_model_type(target_type): + return target_type.model_validate(value) # type: ignore[union-attr] + return self._fallback.coerce(value, target_type) + + def is_reconstructable(self, target_type: Any) -> bool: + # Teach the SDK's inbound type-discovery that pydantic models are + # reconstructable, so an orchestrator/activity input annotated with a + # model type is rebuilt (and validated) by this converter instead of + # arriving as a plain ``dict``. Delegating to ``super()`` keeps the + # default behavior (dataclasses, ``from_json`` types, ``Optional`` / + # ``list`` wrappers, builtins excluded) for everything else; because the + # base recurses through ``self.is_reconstructable``, ``list[OrderItem]`` + # and ``Optional[Order]`` are recognized too. + if _is_model_type(target_type): + return True + return super().is_reconstructable(target_type) diff --git a/examples/custom_data_converter/src/workflows.py b/examples/custom_data_converter/src/workflows.py new file mode 100644 index 0000000..fc4cd7b --- /dev/null +++ b/examples/custom_data_converter/src/workflows.py @@ -0,0 +1,119 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Pydantic models and the workflow that exercises them. + +These models are plain :class:`pydantic.BaseModel` subclasses -- *not* +dataclasses -- which is what makes them a meaningful proof: without the custom +:class:`~src.converter.PydanticDataConverter`, the SDK's default codec would not +know how to validate or reconstruct them. The orchestrator and activities +annotate their inputs/outputs with these model types, so the SDK hands those +types to the converter at every boundary. +""" + +from __future__ import annotations + +from collections.abc import Generator +from datetime import datetime, timezone +from typing import Any + +from pydantic import BaseModel, Field, field_validator + +from durabletask import task + + +# --------------------------------------------------------------------------- +# Pydantic models (validated at every serialization boundary) +# --------------------------------------------------------------------------- +# These are plain ``pydantic.BaseModel`` subclasses -- no special hooks. The +# custom ``PydanticDataConverter`` both serializes them and (because it +# overrides ``is_reconstructable``) teaches the SDK to reconstruct them for +# inbound orchestrator/activity inputs. + + +class OrderItem(BaseModel): + """A single line item. Pydantic validates the constraints below.""" + + name: str = Field(min_length=1) + quantity: int = Field(gt=0) + unit_price: float = Field(ge=0) + + +class Order(BaseModel): + """An order placed by a customer at a point in time.""" + + customer: str = Field(min_length=1) + placed_at: datetime + items: list[OrderItem] + + @field_validator("items") + @classmethod + def _must_have_items(cls, items: list[OrderItem]) -> list[OrderItem]: + if not items: + raise ValueError("an order must contain at least one item") + return items + + +class Receipt(BaseModel): + """The orchestration's typed result.""" + + customer: str + total: float + item_count: int + confirmation_id: str + + +# --------------------------------------------------------------------------- +# Activities +# --------------------------------------------------------------------------- +# Each activity annotates its input parameter with a pydantic model type. The +# worker passes that type to ``PydanticDataConverter.deserialize``, so the +# activity body receives a fully validated model instance (attribute access), +# never a raw dict. + + +def calculate_total(ctx: task.ActivityContext, order: Order) -> float: + """Return the order's total cost from the validated model.""" + return round(sum(item.quantity * item.unit_price for item in order.items), 2) + + +def charge_payment(ctx: task.ActivityContext, amount: float) -> str: + """Pretend to charge a payment processor and return a confirmation ID.""" + if amount <= 0: + raise ValueError("payment amount must be positive") + return f"PAY-{int(amount * 100)}" + + +# --------------------------------------------------------------------------- +# Orchestrator +# --------------------------------------------------------------------------- + + +def process_order(ctx: task.OrchestrationContext, order: Order) -> Generator[task.Task[Any], Any, Receipt]: + """Validate, total, and charge an order, returning a typed :class:`Receipt`. + + The orchestrator's ``order`` parameter is reconstructed as a validated + ``Order`` by the custom converter. ``call_activity(..., return_type=Order)`` + and the ``Receipt`` return value likewise round-trip through pydantic. + """ + total: float = yield ctx.call_activity(calculate_total, input=order) + confirmation_id: str = yield ctx.call_activity(charge_payment, input=total) + + return Receipt( + customer=order.customer, + total=total, + item_count=sum(item.quantity for item in order.items), + confirmation_id=confirmation_id, + ) + + +def sample_order() -> Order: + """A valid sample order used by both the runnable app and the tests.""" + return Order( + customer="Contoso", + placed_at=datetime.now(timezone.utc), + items=[ + OrderItem(name="Widget", quantity=3, unit_price=25.0), + OrderItem(name="Gadget", quantity=1, unit_price=99.99), + ], + ) diff --git a/examples/custom_data_converter/test/__init__.py b/examples/custom_data_converter/test/__init__.py new file mode 100644 index 0000000..59e481e --- /dev/null +++ b/examples/custom_data_converter/test/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/examples/custom_data_converter/test/test_custom_converter.py b/examples/custom_data_converter/test/test_custom_converter.py new file mode 100644 index 0000000..ed96d3f --- /dev/null +++ b/examples/custom_data_converter/test/test_custom_converter.py @@ -0,0 +1,148 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""End-to-end proof that the pydantic ``DataConverter`` plugs into the SDK. + +These tests run entirely in-process against the in-memory backend -- no +sidecar, emulator, or Azure resources. They prove that: + +1. A pydantic model passed as orchestration input round-trips through the wire + and arrives at the orchestrator/activity as a *validated model instance*. +2. The orchestration's pydantic result is reconstructed, typed, on the client + via ``state.get_output(Receipt)``. +3. The wire payload is real pydantic JSON (``model_dump_json``), confirming the + custom converter -- not the default codec -- handled the model. +4. The same converter must be supplied to both worker and client. + +Run from the example root (custom_data_converter/): + pytest test/ +""" + +from __future__ import annotations + +import json + +import pytest + +from durabletask import client, worker +from durabletask.testing import create_test_backend + +from src.converter import PydanticDataConverter +from src.workflows import ( + Order, + OrderItem, + Receipt, + calculate_total, + charge_payment, + process_order, + sample_order, +) + +HOST = "localhost:50071" +PORT = 50071 + + +@pytest.fixture(autouse=True) +def backend(): + """Start and stop the in-memory backend for each test.""" + b = create_test_backend(port=PORT) + yield b + b.stop() + b.reset() + + +def _worker(converter: PydanticDataConverter) -> worker.TaskHubGrpcWorker: + w = worker.TaskHubGrpcWorker(host_address=HOST, data_converter=converter) + w.add_orchestrator(process_order) + w.add_activity(calculate_total) + w.add_activity(charge_payment) + return w + + +def _run(order: Order): + """Run ``process_order`` with the pydantic converter wired into both sides.""" + converter = PydanticDataConverter() + with _worker(converter) as w: + w.start() + with client.TaskHubGrpcClient(host_address=HOST, data_converter=converter) as c: + instance_id = c.schedule_new_orchestration(process_order, input=order) + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + return state + + +def test_pydantic_model_round_trips_and_output_is_typed(): + """A valid order completes and the client reads back a typed Receipt.""" + order = sample_order() + state = _run(order) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + + # The output is reconstructed as a validated pydantic model, not a dict. + receipt = state.get_output(Receipt) + assert isinstance(receipt, Receipt) + assert receipt.customer == "Contoso" + assert receipt.item_count == 4 # 3 widgets + 1 gadget + assert receipt.total == pytest.approx(174.99) + assert receipt.confirmation_id == "PAY-17499" + + +def test_input_is_validated_pydantic_model_on_the_wire(): + """The serialized input is genuine pydantic JSON, proving the converter ran.""" + converter = PydanticDataConverter() + order = sample_order() + + serialized = converter.serialize(order) + assert serialized is not None + # ``model_dump_json`` emits the model fields; ``placed_at`` is an ISO string. + payload = json.loads(serialized) + assert payload["customer"] == "Contoso" + assert isinstance(payload["placed_at"], str) + assert payload["items"][0]["name"] == "Widget" + + # And the SDK reconstructs it back into a validated model given the type. + restored = converter.deserialize(serialized, Order) + assert isinstance(restored, Order) + assert isinstance(restored.items[0], OrderItem) + assert restored.items[0].quantity == 3 + + +def test_invalid_order_input_fails_validation_in_orchestration(): + """An order that violates a pydantic constraint surfaces as a failure. + + The activity/orchestrator receives the input via the converter, so the + pydantic ``ValidationError`` raised while reconstructing the model fails the + orchestration rather than silently passing bad data through. + """ + # ``quantity`` must be > 0; construct the bad payload past pydantic's own + # constructor guard via ``model_construct`` so the failure happens on the + # wire (during reconstruction) inside the orchestration. + bad_order = Order.model_construct( + customer="Contoso", + placed_at=sample_order().placed_at, + items=[OrderItem.model_construct(name="Widget", quantity=-1, unit_price=25.0)], + ) + + state = _run(bad_order) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.FAILED + assert state.failure_details is not None + assert "validation" in (state.failure_details.message or "").lower() \ + or "quantity" in (state.failure_details.message or "").lower() + + +def test_default_converter_cannot_serialize_pydantic_model(): + """Contrast: the default codec has no idea how to serialize a pydantic model. + + This is what motivates the custom converter. Pydantic models are not + dataclasses and expose no ``to_json()`` hook, so the default + ``JsonDataConverter`` raises ``TypeError`` when asked to serialize one. The + ``PydanticDataConverter`` is what makes serialization work (via + ``model_dump_json``). + """ + from durabletask.serialization import JsonDataConverter + + default = JsonDataConverter() + with pytest.raises(TypeError): + default.serialize(sample_order()) diff --git a/tests/durabletask/test_type_discovery.py b/tests/durabletask/test_type_discovery.py index 0f7a6ec..77b14e8 100644 --- a/tests/durabletask/test_type_discovery.py +++ b/tests/durabletask/test_type_discovery.py @@ -11,6 +11,7 @@ from durabletask import entities, task, worker from durabletask.internal import type_discovery from durabletask.internal.entity_state_shim import StateShim +from durabletask.serialization import JsonDataConverter TEST_LOGGER = logging.getLogger("tests") @@ -36,29 +37,78 @@ def from_json(cls, data: dict[str, Any]) -> "Money": return cls(data["amount"]) -# ----- type_discovery helper ----- +# ----- DataConverter.is_reconstructable ----- class TestIsReconstructable: + """Reconstructability is now a DataConverter responsibility; the default + JsonDataConverter recognizes dataclasses and from_json-capable types.""" + + @property + def conv(self) -> JsonDataConverter: + return JsonDataConverter() + def test_dataclass_is_reconstructable(self): - assert type_discovery.is_reconstructable(Order) is True + assert self.conv.is_reconstructable(Order) is True def test_from_json_type_is_reconstructable(self): - assert type_discovery.is_reconstructable(Money) is True + assert self.conv.is_reconstructable(Money) is True def test_builtins_are_not_reconstructable(self): - assert type_discovery.is_reconstructable(int) is False - assert type_discovery.is_reconstructable(str) is False - assert type_discovery.is_reconstructable(dict) is False + assert self.conv.is_reconstructable(int) is False + assert self.conv.is_reconstructable(str) is False + assert self.conv.is_reconstructable(dict) is False def test_optional_dataclass_is_reconstructable(self): - assert type_discovery.is_reconstructable(Optional[Order]) is True + assert self.conv.is_reconstructable(Optional[Order]) is True def test_list_of_dataclass_is_reconstructable(self): - assert type_discovery.is_reconstructable(list[Order]) is True + assert self.conv.is_reconstructable(list[Order]) is True def test_list_of_builtin_is_not_reconstructable(self): - assert type_discovery.is_reconstructable(list[int]) is False + assert self.conv.is_reconstructable(list[int]) is False + + +class TestCustomConverterReconstructable: + """A custom converter can extend reconstructability to its own types, and the + default's Optional/list recursion consults the override for element types.""" + + def test_override_recognizes_custom_type_and_recurses(self): + class Widget: + pass + + class WidgetConverter(JsonDataConverter): + def is_reconstructable(self, target_type: Any) -> bool: + if isinstance(target_type, type) and issubclass(target_type, Widget): + return True + return super().is_reconstructable(target_type) + + conv = WidgetConverter() + assert conv.is_reconstructable(Widget) is True + # The base Optional/list recursion goes through ``self``, so the override + # is consulted for the element type. + assert conv.is_reconstructable(list[Widget]) is True + assert conv.is_reconstructable(Optional[Widget]) is True + # Builtins remain excluded. + assert conv.is_reconstructable(int) is False + + def test_discovery_uses_supplied_converter(self): + class Widget: + pass + + class WidgetConverter(JsonDataConverter): + def is_reconstructable(self, target_type: Any) -> bool: + if isinstance(target_type, type) and issubclass(target_type, Widget): + return True + return super().is_reconstructable(target_type) + + def act(ctx, w: Widget): + ... + + # The default converter does not recognize Widget... + assert type_discovery.activity_input_type(act) is None + # ...but the custom converter does, so discovery surfaces the type. + assert type_discovery.activity_input_type(act, WidgetConverter()) is Widget class TestInputTypeDiscovery: From f0ea4039ed3fdef5cf830d880b5fdacb08b0d199 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Jun 2026 16:02:09 -0600 Subject: [PATCH 5/9] More fixes: - Rename is_reconstructible to can_reconstruct - Correct ownership of _can_reconstruct - Required DataConverter for internal classes --- CHANGELOG.md | 2 +- durabletask/internal/type_discovery.py | 4 +- durabletask/serialization.py | 45 +++--- durabletask/worker.py | 16 +- examples/custom_data_converter/README.md | 17 +- .../custom_data_converter/src/converter.py | 16 +- .../custom_data_converter/src/workflows.py | 2 +- tests/durabletask/test_activity_executor.py | 3 +- tests/durabletask/test_entity_executor.py | 3 +- .../test_orchestration_executor.py | 149 +++++++++--------- tests/durabletask/test_type_discovery.py | 40 ++--- 11 files changed, 151 insertions(+), 146 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index be82789..acc2959 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,7 +46,7 @@ ADDED `DataConverter` as a second parameter (`from_json(cls, value, converter)`), letting it reconstruct nested typed values via `converter.coerce(...)` / `converter.deserialize(...)`. The single-argument form remains supported. -- `DataConverter` now exposes an overridable `is_reconstructable(target_type)` +- `DataConverter` now exposes an overridable `can_reconstruct(target_type)` method that controls which annotated input/return types the SDK reconstructs on the inbound path. A custom converter can override it to recognize its own types (for example `pydantic.BaseModel` subclasses), so that orchestrator / diff --git a/durabletask/internal/type_discovery.py b/durabletask/internal/type_discovery.py index cc92404..6f87ea2 100644 --- a/durabletask/internal/type_discovery.py +++ b/durabletask/internal/type_discovery.py @@ -88,7 +88,7 @@ def _input_annotation(fn: Callable[..., Any], position: int, if annotation is inspect.Parameter.empty or annotation is Any: return None - return annotation if _resolve_converter(converter).is_reconstructable(annotation) else None + return annotation if _resolve_converter(converter).can_reconstruct(annotation) else None def orchestrator_input_type(fn: Callable[..., Any], @@ -129,7 +129,7 @@ def activity_output_type(fn: Any, converter: DataConverter | None = None) -> Any if annotation is inspect.Signature.empty or annotation is Any or annotation is None: return None - return annotation if _resolve_converter(converter).is_reconstructable(annotation) else None + return annotation if _resolve_converter(converter).can_reconstruct(annotation) else None def entity_input_type(fn: Any, operation: str, diff --git a/durabletask/serialization.py b/durabletask/serialization.py index 19cfb6d..8317abf 100644 --- a/durabletask/serialization.py +++ b/durabletask/serialization.py @@ -101,29 +101,26 @@ def coerce(self, value: Any, target_type: type | None = None) -> Any: """ ... - def is_reconstructable(self, target_type: Any) -> bool: + def can_reconstruct(self, target_type: Any) -> bool: """Return True if this converter can rebuild ``target_type`` from a payload. Inbound type-discovery calls this to decide whether a function's annotated *input* type (or an activity's *return* annotation) should be passed to :meth:`deserialize` / :meth:`coerce`. When it returns ``False`` the SDK passes the raw deserialized payload through unchanged -- this - gate is what stops the SDK from invoking an arbitrary constructor on a - builtin or otherwise unrecognized annotation. - - The default recognizes the types the built-in codec can rebuild -- - dataclasses and ``from_json()``-capable types, plus ``Optional`` / - ``list`` / ``Sequence`` hints wrapping them -- and excludes builtins - (``int``, ``str``, ``dict``, ...) and unknown annotations. - - Override this to teach the SDK about a custom converter's own types (for + gate is what stops the SDK from invoking reconstruction on a type the + converter does not actually handle. + + The base implementation is conservative and returns ``False``: a + converter makes no reconstruction claims unless it opts in. + :class:`JsonDataConverter` overrides this to recognize the types its + codec can rebuild (dataclasses and ``from_json()``-capable types, plus + ``Optional`` / ``list`` / ``Sequence`` hints wrapping them). Override + this in a custom converter to teach the SDK about its own types (for example ``pydantic.BaseModel`` subclasses) so that inputs annotated with - them are reconstructed instead of arriving as raw JSON. The default - implementation recurses through ``self.is_reconstructable``, so an - override is also consulted for the element types of ``Optional`` / - ``list`` hints (e.g. ``list[MyModel]``). + them are reconstructed instead of arriving as raw JSON. """ - return _is_reconstructable(self, target_type) + return False class JsonDataConverter(DataConverter): @@ -187,6 +184,9 @@ def coerce(self, value: Any, target_type: type | None = None) -> Any: self._log_coercion_fallback(target_type, e) return value + def can_reconstruct(self, target_type: Any) -> bool: + return _can_reconstruct(self, target_type) + @staticmethod def _log_coercion_fallback(target_type: type, error: Exception) -> None: logger.debug( @@ -211,24 +211,25 @@ def _log_coercion_fallback(target_type: type, error: Exception) -> None: # --------------------------------------------------------------------------- -def _is_reconstructable(converter: DataConverter, target_type: Any) -> bool: - """Default :meth:`DataConverter.is_reconstructable` policy. +def _can_reconstruct(converter: DataConverter, target_type: Any) -> bool: + """:class:`JsonDataConverter`'s reconstruction policy. Recognizes dataclasses and ``from_json()``-capable types, plus ``Optional`` / ``list`` / ``Sequence`` hints wrapping them; builtins and unknown - annotations are excluded. Recurses through ``converter.is_reconstructable`` - (not itself) so a subclass override participates in the element-type checks - of ``Optional`` / ``list`` hints. + annotations are excluded. Recurses through ``converter.can_reconstruct`` + (not itself) so a :class:`JsonDataConverter` subclass that overrides + ``can_reconstruct`` still participates in the element-type checks of + ``Optional`` / ``list`` hints. """ origin = typing.get_origin(target_type) if origin is not None: args = typing.get_args(target_type) if origin is typing.Union or origin is types.UnionType: return any( - converter.is_reconstructable(a) for a in args if a is not type(None) + converter.can_reconstruct(a) for a in args if a is not type(None) ) if origin in (list, Sequence): - return any(converter.is_reconstructable(a) for a in args) + return any(converter.can_reconstruct(a) for a in args) return False if not isinstance(target_type, type): return False diff --git a/durabletask/worker.py b/durabletask/worker.py index de19329..66381c9 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -1420,8 +1420,8 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext): def __init__(self, instance_id: str, registry: _Registry, + data_converter: DataConverter, maximum_timer_interval: timedelta | None = DEFAULT_MAXIMUM_TIMER_INTERVAL, - data_converter: DataConverter | None = None, ): self._generator = None self._is_replaying = True @@ -1450,7 +1450,7 @@ def __init__(self, self._parent_trace_context: pb.TraceContext | None = None self._orchestration_trace_context: pb.TraceContext | None = None self._maximum_timer_interval = maximum_timer_interval - self._data_converter = data_converter if data_converter is not None else JsonDataConverter() + self._data_converter = data_converter def run(self, generator: Generator[task.Task[Any], Any, Any]) -> None: self._generator = generator @@ -2050,13 +2050,13 @@ def __init__( self, registry: _Registry, logger: logging.Logger, + data_converter: DataConverter, persisted_orch_span_id: str | None = None, maximum_timer_interval: timedelta | None = DEFAULT_MAXIMUM_TIMER_INTERVAL, - data_converter: DataConverter | None = None, ): self._registry = registry self._logger = logger - self._data_converter = data_converter if data_converter is not None else JsonDataConverter() + self._data_converter = data_converter self._maximum_timer_interval = maximum_timer_interval self._is_suspended = False self._suspended_events: list[pb.HistoryEvent] = [] @@ -2834,10 +2834,10 @@ def compare_versions(self, source_version: str | None, default_version: str | No class _ActivityExecutor: def __init__(self, registry: _Registry, logger: logging.Logger, - data_converter: DataConverter | None = None): + data_converter: DataConverter): self._registry = registry self._logger = logger - self._data_converter = data_converter if data_converter is not None else JsonDataConverter() + self._data_converter = data_converter def execute( self, @@ -2873,10 +2873,10 @@ def execute( class _EntityExecutor: def __init__(self, registry: _Registry, logger: logging.Logger, - data_converter: DataConverter | None = None): + data_converter: DataConverter): self._registry = registry self._logger = logger - self._data_converter = data_converter if data_converter is not None else JsonDataConverter() + self._data_converter = data_converter self._entity_method_cache: dict[tuple[type, str], bool] = {} def execute( diff --git a/examples/custom_data_converter/README.md b/examples/custom_data_converter/README.md index 4888d9f..749b567 100644 --- a/examples/custom_data_converter/README.md +++ b/examples/custom_data_converter/README.md @@ -51,28 +51,29 @@ everything else** to the default `JsonDataConverter`. This "handle my types, delegate the rest" shape is the recommended pattern for a real converter — it costs nothing for non-pydantic payloads. -## Inbound inputs: `is_reconstructable` +## Inbound inputs: `can_reconstruct` There is one extra detail for reconstructing **inbound** orchestrator/activity *inputs*. Before the SDK hands an input to your converter, it asks the converter whether the function's annotated input type is something it can rebuild, via -`DataConverter.is_reconstructable(target_type)`. The default implementation +`DataConverter.can_reconstruct(target_type)`. The default implementation recognizes dataclasses and `from_json()`-capable types (and `Optional` / `list` wrappers) — it does **not** know about pydantic models, so without an override an input annotated `order: Order` would arrive as a plain `dict`. -The converter overrides `is_reconstructable` to also recognize -`pydantic.BaseModel` subclasses: +The converter overrides `can_reconstruct` to also recognize +`pydantic.BaseModel` subclasses, deferring everything else to the same +`JsonDataConverter` fallback it uses for serialization: ```python -def is_reconstructable(self, target_type): +def can_reconstruct(self, target_type): if _is_model_type(target_type): return True - return super().is_reconstructable(target_type) # keep the defaults + return self._fallback.can_reconstruct(target_type) # dataclasses, from_json, ... ``` -Because the base implementation recurses through `self.is_reconstructable`, -`list[OrderItem]` and `Optional[Order]` are recognized too. Outbound values, +The base `DataConverter.can_reconstruct` is conservative — it returns `False`, +so a converter only claims the types it actually rebuilds. Outbound values, `return_type=` arguments, and typed client accessors (`state.get_output(Receipt)`) don't depend on this hook — they pass the type to the converter directly. diff --git a/examples/custom_data_converter/src/converter.py b/examples/custom_data_converter/src/converter.py index ab142ba..29696f4 100644 --- a/examples/custom_data_converter/src/converter.py +++ b/examples/custom_data_converter/src/converter.py @@ -26,7 +26,7 @@ It may also override one hook: -* ``is_reconstructable(t)`` -- tells the SDK's inbound type-discovery that an +* ``can_reconstruct(t)`` -- tells the SDK's inbound type-discovery that an *input* annotated with type ``t`` should be handed to ``deserialize`` / ``coerce`` (rather than passed through as raw JSON). The default recognizes dataclasses and ``from_json()``-capable types; override it to add your own @@ -92,15 +92,15 @@ def coerce(self, value: Any, target_type: type | None = None) -> Any: return target_type.model_validate(value) # type: ignore[union-attr] return self._fallback.coerce(value, target_type) - def is_reconstructable(self, target_type: Any) -> bool: + def can_reconstruct(self, target_type: Any) -> bool: # Teach the SDK's inbound type-discovery that pydantic models are # reconstructable, so an orchestrator/activity input annotated with a # model type is rebuilt (and validated) by this converter instead of - # arriving as a plain ``dict``. Delegating to ``super()`` keeps the - # default behavior (dataclasses, ``from_json`` types, ``Optional`` / - # ``list`` wrappers, builtins excluded) for everything else; because the - # base recurses through ``self.is_reconstructable``, ``list[OrderItem]`` - # and ``Optional[Order]`` are recognized too. + # arriving as a plain ``dict``. For everything else, defer to the same + # ``JsonDataConverter`` fallback this converter uses for serialization, + # so its reconstruction claims match what it actually handles + # (dataclasses, ``from_json`` types, ``Optional`` / ``list`` wrappers; + # builtins excluded). if _is_model_type(target_type): return True - return super().is_reconstructable(target_type) + return self._fallback.can_reconstruct(target_type) diff --git a/examples/custom_data_converter/src/workflows.py b/examples/custom_data_converter/src/workflows.py index fc4cd7b..61e8bb0 100644 --- a/examples/custom_data_converter/src/workflows.py +++ b/examples/custom_data_converter/src/workflows.py @@ -27,7 +27,7 @@ # --------------------------------------------------------------------------- # These are plain ``pydantic.BaseModel`` subclasses -- no special hooks. The # custom ``PydanticDataConverter`` both serializes them and (because it -# overrides ``is_reconstructable``) teaches the SDK to reconstruct them for +# overrides ``can_reconstruct``) teaches the SDK to reconstruct them for # inbound orchestrator/activity inputs. diff --git a/tests/durabletask/test_activity_executor.py b/tests/durabletask/test_activity_executor.py index 408815b..ff775ad 100644 --- a/tests/durabletask/test_activity_executor.py +++ b/tests/durabletask/test_activity_executor.py @@ -6,6 +6,7 @@ from typing import Any from durabletask import task, worker +from durabletask.serialization import JsonDataConverter logging.basicConfig( format='%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s', @@ -53,5 +54,5 @@ def test_activity(ctx: task.ActivityContext, _): def _get_activity_executor(fn: task.Activity) -> tuple[worker._ActivityExecutor, str]: registry = worker._Registry() name = registry.add_activity(fn) - executor = worker._ActivityExecutor(registry, TEST_LOGGER) + executor = worker._ActivityExecutor(registry, TEST_LOGGER, JsonDataConverter()) return executor, name diff --git a/tests/durabletask/test_entity_executor.py b/tests/durabletask/test_entity_executor.py index 08d2827..34656c5 100644 --- a/tests/durabletask/test_entity_executor.py +++ b/tests/durabletask/test_entity_executor.py @@ -7,6 +7,7 @@ from durabletask import entities from durabletask.internal.entity_state_shim import StateShim +from durabletask.serialization import JsonDataConverter from durabletask.worker import _EntityExecutor, _Registry @@ -15,7 +16,7 @@ def _make_executor(*entity_args) -> _EntityExecutor: registry = _Registry() for entity in entity_args: registry.add_entity(entity) - return _EntityExecutor(registry, logging.getLogger("test")) + return _EntityExecutor(registry, logging.getLogger("test"), JsonDataConverter()) def _execute(executor, entity_name, operation, encoded_input=None): diff --git a/tests/durabletask/test_orchestration_executor.py b/tests/durabletask/test_orchestration_executor.py index 790f8e2..e95c461 100644 --- a/tests/durabletask/test_orchestration_executor.py +++ b/tests/durabletask/test_orchestration_executor.py @@ -10,6 +10,7 @@ import durabletask.internal.helpers as helpers import durabletask.internal.orchestrator_service_pb2 as pb from durabletask import task, worker, entities +from durabletask.serialization import JsonDataConverter logging.basicConfig( format='%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s', @@ -36,7 +37,7 @@ def orchestrator(ctx: task.OrchestrationContext, my_input: int): helpers.new_orchestrator_started_event(start_time), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=json.dumps(test_input)), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result.actions @@ -58,7 +59,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): name = registry.add_orchestrator(empty_orchestrator) new_events = [helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result.actions @@ -73,7 +74,7 @@ def test_orchestrator_not_registered(): registry = worker._Registry() name = "Bogus" new_events = [helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result.actions @@ -100,7 +101,7 @@ def delay_orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(start_time), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result.actions @@ -133,7 +134,7 @@ def delay_orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_timer_fired_event(1, expected_fire_at)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -160,7 +161,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(start_time), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result.actions @@ -180,7 +181,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): registry = worker._Registry() name = registry.add_orchestrator(orchestrator) - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) start_time = datetime(2020, 1, 1, 12, 0, 0) t1 = start_time + timedelta(days=3) @@ -277,7 +278,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): registry = worker._Registry() name = registry.add_orchestrator(orchestrator) - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) start_time = datetime(2020, 1, 1, 12, 0, 0) first_chunk_fire_at = start_time + timedelta(days=3) @@ -325,7 +326,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): registry = worker._Registry() name = registry.add_orchestrator(orchestrator) - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) start_time = datetime(2020, 1, 1, 12, 0, 0) timeout_fire_at = start_time + timedelta(hours=1) @@ -363,7 +364,7 @@ def test_only_cancellable_tasks_expose_cancel(): def dummy_activity(ctx, _): pass - ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, worker._Registry()) + ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, worker._Registry(), JsonDataConverter()) timer_task = ctx.create_timer(timedelta(minutes=5)) external_event_task = ctx.wait_for_external_event("approval") @@ -408,7 +409,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result.actions @@ -441,7 +442,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): encoded_output = json.dumps("done!") new_events = [helpers.new_task_completed_event(1, encoded_output)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -479,7 +480,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [helpers.new_task_completed_event(1, json.dumps({"message": "hi"}))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) complete_action = get_and_validate_complete_orchestration_action_list(1, result.actions) @@ -517,7 +518,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [helpers.new_task_completed_event(1, json.dumps({"message": "hi"}))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) assert captured["type"] == "Result" @@ -556,7 +557,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [helpers.new_task_completed_event(1, json.dumps({"value": "x"}))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) assert captured["type"] == "Override" @@ -587,7 +588,7 @@ def orchestrator(ctx: task.OrchestrationContext, args: StartArgs): helpers.new_orchestrator_started_event(), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=encoded_input)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, [], new_events) assert captured["type"] == "StartArgs" @@ -615,7 +616,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): ex = Exception("Kah-BOOOOM!!!") new_events = [helpers.new_task_failed_event(1, ex)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -661,7 +662,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(timestamp=current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 1 @@ -675,7 +676,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(2, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 2 @@ -688,7 +689,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 3 @@ -702,7 +703,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(3, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 3 @@ -715,7 +716,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 4 @@ -729,7 +730,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(4, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 4 @@ -742,7 +743,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 5 @@ -756,7 +757,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(5, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 5 @@ -770,7 +771,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 6 @@ -784,7 +785,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(6, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 6 @@ -796,7 +797,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 7 @@ -840,7 +841,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(timestamp=current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 1 @@ -854,7 +855,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(2, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 2 @@ -866,7 +867,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 3 @@ -880,7 +881,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(3, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 3 @@ -890,7 +891,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 4 @@ -929,7 +930,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(timestamp=current_timestamp), helpers.new_task_failed_event(1, ValueError("Fail!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 1 @@ -942,7 +943,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(2, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) expected_fire_at = current_timestamp + timedelta(seconds=5) @@ -950,7 +951,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Fail!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 3 @@ -978,7 +979,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): registry = worker._Registry() name = registry.add_orchestrator(orchestrator) - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) start = datetime.utcnow() t1 = start + timedelta(days=3) @@ -1077,7 +1078,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_timer_created_event(1, fire_at)] new_events = [helpers.new_timer_fired_event(timer_id=1, fire_at=fire_at)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1105,7 +1106,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [helpers.new_task_completed_event(1)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1135,7 +1136,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [helpers.new_task_completed_event(1)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1166,7 +1167,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [helpers.new_task_completed_event(1)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1200,7 +1201,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_sub_orchestration_completed_event(1, encoded_output="42")] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1230,7 +1231,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): ex = Exception("Kah-BOOOOM!!!") new_events = [helpers.new_sub_orchestration_failed_event(1, ex)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1261,7 +1262,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_sub_orchestration_completed_event(1, encoded_output="42")] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1291,7 +1292,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_sub_orchestration_completed_event(1, encoded_output="42")] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1318,7 +1319,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): # Execute the orchestration until it is waiting for an external event. The result # should be an empty list of actions because the orchestration didn't schedule any work. - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 0 @@ -1327,7 +1328,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): # the orchestration should complete. old_events = new_events new_events = [helpers.new_event_raised_event("my_event", encoded_input="42")] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions complete_action = get_and_validate_complete_orchestration_action_list(1, actions) @@ -1352,7 +1353,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_event_raised_event("my_event", encoded_input="42")] # Execute the orchestration. It should be in a running state waiting for the timer to fire - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 1 @@ -1363,7 +1364,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): timer_due_time = datetime.utcnow() + timedelta(days=1) old_events = new_events + [helpers.new_timer_created_event(1, timer_due_time)] new_events = [helpers.new_timer_fired_event(1, timer_due_time)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions complete_action = get_and_validate_complete_orchestration_action_list(1, actions) @@ -1390,7 +1391,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): # Execute the orchestration. It should remain in a running state because it was suspended prior # to processing the event raised event. - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 0 @@ -1398,7 +1399,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): # Resume the orchestration. It should complete successfully. old_events = old_events + new_events new_events = [helpers.new_resume_event()] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions complete_action = get_and_validate_complete_orchestration_action_list(1, actions) @@ -1424,7 +1425,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_event_raised_event("my_event", encoded_input="42")] # Execute the orchestration. It should be in a running state waiting for an external event - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions complete_action = get_and_validate_complete_orchestration_action_list(1, actions) @@ -1453,7 +1454,7 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): helpers.new_timer_fired_event(1, datetime.utcnow() + timedelta(days=1))] # Execute the orchestration. It should be in a running state waiting for the timer to fire - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions complete_action = get_and_validate_complete_orchestration_action_list(1, actions) @@ -1489,7 +1490,7 @@ def orchestrator(ctx: task.OrchestrationContext, count: int): helpers.new_orchestrator_started_event(), helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input="10")] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1531,13 +1532,13 @@ def orchestrator(ctx: task.OrchestrationContext, _): # First, test with only the first 5 events. We expect the orchestration to be running # but return zero actions since its still waiting for the other 5 tasks to complete. - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events[:5]) actions = result.actions assert len(actions) == 0 # Now test with the full set of new events. We expect the orchestration to complete. - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1579,7 +1580,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events.append(helpers.new_task_failed_event(6, ex)) # Now test with the full set of new events. We expect the orchestration to complete. - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1611,7 +1612,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): # to return two actions: one to schedule the "Tokyo" task and one to schedule the "Seattle" task. old_events = [] new_events = [helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 2 @@ -1628,7 +1629,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): # Test 2: Complete the "Tokyo" task. We expect the orchestration to complete with output "Hello, Tokyo!" encoded_output = json.dumps(hello(None, "Tokyo")) new_events = [helpers.new_task_completed_event(1, encoded_output)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions complete_action = get_and_validate_complete_orchestration_action_list(1, actions) @@ -1638,7 +1639,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): # Test 3: Complete the "Seattle" task. We expect the orchestration to complete with output "Hello, Seattle!" encoded_output = json.dumps(hello(None, "Seattle")) new_events = [helpers.new_task_completed_event(2, encoded_output)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions complete_action = get_and_validate_complete_orchestration_action_list(1, actions) @@ -1684,7 +1685,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_orchestrator_started_event(datetime.now()), helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, [], new_events) assert result.actions # should have scheduled the activity @@ -1700,7 +1701,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): ] encoded_output = json.dumps(say_hello(None, "World")) new_events = [helpers.new_task_completed_event(1, encoded_output)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) complete_action = get_and_validate_complete_orchestration_action_list(1, result.actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED @@ -1743,7 +1744,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_orchestrator_started_event(datetime.now()), helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, [], new_events) complete_action = get_and_validate_complete_orchestration_action_list(1, result.actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED @@ -1877,7 +1878,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(timestamp=current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 1 @@ -1891,7 +1892,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(3, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 2 @@ -1904,7 +1905,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 3 @@ -1915,7 +1916,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): # Complete the "Seattle" task. We expect the orchestration to complete with output "Hello, Seattle!" encoded_output = json.dumps(dummy_activity(None, "Seattle")) new_events = [helpers.new_task_completed_event(2, encoded_output)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions complete_action = get_and_validate_complete_orchestration_action_list(3, actions) @@ -1959,7 +1960,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(timestamp=current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 1 @@ -1973,7 +1974,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_timer_fired_event(3, current_timestamp)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 2 @@ -1986,7 +1987,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 3 @@ -2000,7 +2001,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): old_events = old_events + new_events new_events = [helpers.new_task_completed_event(2, encoded_output), helpers.new_timer_fired_event(4, expected_fire_at)] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 3 @@ -2014,7 +2015,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions complete_action = get_and_validate_complete_orchestration_action_list(4, actions) @@ -2038,7 +2039,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): helpers.new_orchestrator_started_event(), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input), helpers.new_orchestrator_completed_event()] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result.actions @@ -2066,7 +2067,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_execution_started_event(name, TEST_INSTANCE_ID, None), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result1 = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result1.actions assert len(actions) == 1 diff --git a/tests/durabletask/test_type_discovery.py b/tests/durabletask/test_type_discovery.py index 77b14e8..31befb2 100644 --- a/tests/durabletask/test_type_discovery.py +++ b/tests/durabletask/test_type_discovery.py @@ -37,7 +37,7 @@ def from_json(cls, data: dict[str, Any]) -> "Money": return cls(data["amount"]) -# ----- DataConverter.is_reconstructable ----- +# ----- DataConverter.can_reconstruct ----- class TestIsReconstructable: @@ -49,24 +49,24 @@ def conv(self) -> JsonDataConverter: return JsonDataConverter() def test_dataclass_is_reconstructable(self): - assert self.conv.is_reconstructable(Order) is True + assert self.conv.can_reconstruct(Order) is True def test_from_json_type_is_reconstructable(self): - assert self.conv.is_reconstructable(Money) is True + assert self.conv.can_reconstruct(Money) is True def test_builtins_are_not_reconstructable(self): - assert self.conv.is_reconstructable(int) is False - assert self.conv.is_reconstructable(str) is False - assert self.conv.is_reconstructable(dict) is False + assert self.conv.can_reconstruct(int) is False + assert self.conv.can_reconstruct(str) is False + assert self.conv.can_reconstruct(dict) is False def test_optional_dataclass_is_reconstructable(self): - assert self.conv.is_reconstructable(Optional[Order]) is True + assert self.conv.can_reconstruct(Optional[Order]) is True def test_list_of_dataclass_is_reconstructable(self): - assert self.conv.is_reconstructable(list[Order]) is True + assert self.conv.can_reconstruct(list[Order]) is True def test_list_of_builtin_is_not_reconstructable(self): - assert self.conv.is_reconstructable(list[int]) is False + assert self.conv.can_reconstruct(list[int]) is False class TestCustomConverterReconstructable: @@ -78,29 +78,29 @@ class Widget: pass class WidgetConverter(JsonDataConverter): - def is_reconstructable(self, target_type: Any) -> bool: + def can_reconstruct(self, target_type: Any) -> bool: if isinstance(target_type, type) and issubclass(target_type, Widget): return True - return super().is_reconstructable(target_type) + return super().can_reconstruct(target_type) conv = WidgetConverter() - assert conv.is_reconstructable(Widget) is True + assert conv.can_reconstruct(Widget) is True # The base Optional/list recursion goes through ``self``, so the override # is consulted for the element type. - assert conv.is_reconstructable(list[Widget]) is True - assert conv.is_reconstructable(Optional[Widget]) is True + assert conv.can_reconstruct(list[Widget]) is True + assert conv.can_reconstruct(Optional[Widget]) is True # Builtins remain excluded. - assert conv.is_reconstructable(int) is False + assert conv.can_reconstruct(int) is False def test_discovery_uses_supplied_converter(self): class Widget: pass class WidgetConverter(JsonDataConverter): - def is_reconstructable(self, target_type: Any) -> bool: + def can_reconstruct(self, target_type: Any) -> bool: if isinstance(target_type, type) and issubclass(target_type, Widget): return True - return super().is_reconstructable(target_type) + return super().can_reconstruct(target_type) def act(ctx, w: Widget): ... @@ -205,7 +205,7 @@ def test_string_name_returns_none(self): def _activity_executor(fn) -> tuple[worker._ActivityExecutor, str]: registry = worker._Registry() name = registry.add_activity(fn) - return worker._ActivityExecutor(registry, TEST_LOGGER), name + return worker._ActivityExecutor(registry, TEST_LOGGER, JsonDataConverter()), name def test_activity_input_coerced_to_dataclass(): @@ -277,7 +277,7 @@ def store(ctx: entities.EntityContext, order: Order): registry = worker._Registry() registry.add_entity(store) - executor = worker._EntityExecutor(registry, TEST_LOGGER) + executor = worker._EntityExecutor(registry, TEST_LOGGER, JsonDataConverter()) entity_id = entities.EntityInstanceId("store", "k1") state = StateShim(None) executor.execute("orch1", entity_id, "save", state, json.dumps({"item": "book", "quantity": 2})) @@ -295,7 +295,7 @@ def save(self, order: Order): registry = worker._Registry() registry.add_entity(Store) - executor = worker._EntityExecutor(registry, TEST_LOGGER) + executor = worker._EntityExecutor(registry, TEST_LOGGER, JsonDataConverter()) entity_id = entities.EntityInstanceId("store", "k1") state = StateShim(None) executor.execute("orch1", entity_id, "save", state, json.dumps({"item": "book", "quantity": 2})) From 8db4510118d18387921d6f6c694425d2dbb0089e Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Jun 2026 16:09:22 -0600 Subject: [PATCH 6/9] CHANGELOG summarization --- CHANGELOG.md | 107 +++++++++++++-------------------------------------- 1 file changed, 26 insertions(+), 81 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index acc2959..d45cadd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,98 +12,46 @@ ADDED - Added a pluggable `DataConverter` (`durabletask.serialization`) accepted by `TaskHubGrpcWorker`, `TaskHubGrpcClient`, and `AsyncTaskHubGrpcClient` via a `data_converter` argument. Every payload boundary (inputs, outputs, events, - custom status, entity state) routes through it. The default + custom status, entity state) routes through it, so one converter controls how + Python values become JSON and how they are reconstructed. The default `JsonDataConverter` preserves existing behavior, so a custom converter (for - example one backed by pydantic) is opt-in. Custom objects can opt in via a - `to_json()` hook and a `from_json(value)` classmethod. -- `OrchestrationContext.call_activity`, `call_sub_orchestrator`, and - `call_entity` accept an optional `return_type`, and `wait_for_external_event` - accepts an optional `data_type`. When provided, the result/event payload is - reconstructed as that type (dataclasses — including nested dataclass, - `Optional`, `list`, `dict`/`Mapping`, and `tuple` fields — and - `from_json()`-capable types) and the returned task is typed accordingly (e.g. - `call_activity(..., return_type=Foo)` yields `CompletableTask[Foo]`). When - omitted, the raw deserialized JSON is returned as before. -- Inbound payloads are reconstructed from function type annotations. When an - orchestrator, activity, or entity operation annotates its input parameter (or - an activity its return value) with a dataclass or `from_json()`-capable type, - the payload is reconstructed as that type. Builtins and unannotated/unknown - types are passed through unchanged. An explicit `return_type` takes precedence - over a discovered annotation. -- Added typed accessors to `client.OrchestrationState`: `get_input()`, - `get_output()`, and `get_custom_status()` each accept an optional - `expected_type` and deserialize the corresponding payload, reconstructing - dataclasses and `from_json()`-capable types. The raw `serialized_*` fields are - retained. -- Objects exposing a `to_json()` method are now JSON-serializable when passed as - activity/orchestrator inputs or outputs. -- `enum.Enum` values now serialize (to their underlying `.value`) and, when a - target type is supplied, deserialize back to the enum member. This covers - string-valued and other non-`int` enums as activity/orchestrator/entity inputs - and outputs, including as dataclass fields and inside `list` / `dict` / - `tuple` containers. (`IntEnum` / `IntFlag` already serialized as integers.) -- A `from_json()` classmethod may now optionally accept the active - `DataConverter` as a second parameter (`from_json(cls, value, converter)`), - letting it reconstruct nested typed values via `converter.coerce(...)` / - `converter.deserialize(...)`. The single-argument form remains supported. -- `DataConverter` now exposes an overridable `can_reconstruct(target_type)` - method that controls which annotated input/return types the SDK reconstructs - on the inbound path. A custom converter can override it to recognize its own - types (for example `pydantic.BaseModel` subclasses), so that orchestrator / - activity / entity inputs annotated with those types are reconstructed by the - converter instead of arriving as raw JSON. The default behavior is unchanged - (dataclasses and `from_json()`-capable types, plus `Optional` / `list` - wrappers, are reconstructable; builtins are not). -- Added `EntityMetadata.get_typed_state(intended_type=...)`, which deserializes - the entity's persisted state and reconstructs dataclasses and - `from_json()`-capable types. The existing `get_state()` is unchanged: with no - argument it returns the raw serialized JSON payload, and `get_state(some_type)` - applies constructor-based coercion (`some_type(raw)`). -- Entity runtime state retrieval (`EntityContext.get_state(intended_type=...)` / - `DurableEntity.get_state(...)`) now also reconstructs dataclasses and - `from_json()`-capable types, in addition to the existing constructor-based - coercion. + example one backed by pydantic) is fully opt-in. +- Custom objects can participate in serialization by exposing a `to_json()` + method and a `from_json(value)` classmethod. Both are honored recursively, so + nested custom objects round-trip through their own hooks. +- Payloads are reconstructed into a caller-supplied type — dataclasses + (including nested fields), `from_json()`-capable types, and `enum.Enum` + members, recursing through `list`, `dict`, `tuple`, and `Optional`/`Union` + hints. The type comes from a function's annotations, from an explicit + `return_type` on `call_activity` / `call_sub_orchestrator` / `call_entity` + (or `data_type` on `wait_for_external_event`), or from the typed accessors + `get_input()` / `get_output()` / `get_custom_status()` on + `client.OrchestrationState` and `EntityMetadata.get_typed_state(...)`. It is + never inferred from the payload. Which annotated types are eligible is decided + by the converter via the overridable `DataConverter.can_reconstruct(...)`; a + custom converter can override it to recognize its own types (for example + `pydantic.BaseModel` subclasses). CHANGED - Custom objects (dataclasses, `SimpleNamespace`, namedtuples) are now serialized as plain JSON. Decoding such a payload *without* a type hint now yields a plain `dict` (previously a `SimpleNamespace`; a namedtuple now - round-trips as a JSON array). To get the original type back, pass the new - `return_type` / `data_type` arguments, annotate the consuming function's - parameter or return type, or use the typed client accessors. Payloads produced - by older SDK versions still deserialize — including into a `SimpleNamespace` - when no type is supplied — so in-flight orchestrations continue to replay - across an upgrade. + round-trips as a JSON array). To get the original type back, supply a type via + one of the mechanisms above. Payloads produced by older SDK versions still + deserialize — including into a `SimpleNamespace` when no type is supplied — so + in-flight orchestrations continue to replay across an upgrade. - JSON serialization failures now raise a `TypeError` that chains the original error (`__cause__`) and names the offending type. - `EntityContext.get_state()` / `DurableEntity.get_state()` now return a freshly reconstructed value on every call rather than a reference to a single cached - object. As a result, mutating a value returned by `get_state()` in place no - longer affects the persisted entity state — write the change back with - `set_state()` to persist it. The entity's state is also serialized eagerly at - `set_state()` time, so a value that cannot be serialized surfaces the error - inside the failing operation (which rolls back) instead of after the batch has - run. + object, so mutating the returned value in place no longer affects persisted + state — write it back with `set_state()`. State is also serialized eagerly at + `set_state()` time, so a non-serializable value fails inside the operation + (which rolls back) instead of after the batch has run. FIXED -- A dataclass or `SimpleNamespace` that defines a `to_json()` hook now uses it - when serialized. Previously the built-in dataclass / `SimpleNamespace` - handling ran first, so the hook was ignored — and a dataclass with a field - that was not JSON-serializable on its own would fail to serialize even when it - provided a `to_json()` hook to handle that field. The serialize side now - prefers `to_json()`, mirroring the deserialize side, which already prefers - `from_json()`. -- Nested `to_json()` hooks are now honored when an object is serialized inside a - dataclass. Custom objects (including nested dataclasses with their own - `to_json()`) are now encoded recursively instead of being flattened to their - raw fields, so values that reshape themselves via `to_json()` round-trip - correctly. -- Type-directed deserialization now recurses into `dict`/`Mapping` values and - `tuple` elements, in addition to the existing `list`, `Optional`/`Union`, and - dataclass-field recursion. A `dict[str, Foo]` or `tuple[Foo, ...]` hint now - reconstructs the contained `Foo` values. - Falsy entity states (`0`, `""`, `[]`, `{}`) are no longer dropped when an entity batch is persisted. Previously a falsy current state was treated as "no state" and written as `None`, effectively deleting it; only an actual @@ -127,9 +75,6 @@ subclassing the public abstract types — may need to update their code: and `wait_for_external_event` gained new keyword-only parameters (`return_type` / `data_type`). Subclasses overriding these methods should add the parameter to match the base signature. -- `client.OrchestrationState` gained a non-public `_data_converter` field - (excluded from equality and `repr`). Code constructing `OrchestrationState` - positionally should pass it via the new field or rely on its default. ## v1.6.0 From a98ccc010e866255e8d45b6b8f1a33bb26cac75c Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Jun 2026 16:25:49 -0600 Subject: [PATCH 7/9] No more silent fallbacks to JsonDataConverter --- CHANGELOG.md | 15 +++++-- .../durabletask/azuremanaged/client.py | 9 ++++- .../durabletask/azuremanaged/worker.py | 5 ++- durabletask/entities/entity_context.py | 5 +-- durabletask/entities/entity_metadata.py | 9 ++--- durabletask/internal/entity_state_shim.py | 5 +-- examples/custom_data_converter/src/app.py | 1 + tests/durabletask/test_entity_executor.py | 40 +++++++++---------- tests/durabletask/test_tracing.py | 34 ++++++++-------- tests/durabletask/test_type_discovery.py | 4 +- 10 files changed, 68 insertions(+), 59 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d45cadd..a3515eb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,16 +65,23 @@ DEPRECATED `JsonDataConverter`) instead. The functions continue to work for backwards compatibility. -BREAKING CHANGES (type-level only — no runtime impact for typical users) +BREAKING CHANGES (no runtime impact for typical users) -These changes do not alter runtime behavior, but because the package ships -`py.typed`, consumers running strict type checkers (pyright/mypy) — or -subclassing the public abstract types — may need to update their code: +Most of these are type-level only: because the package ships `py.typed`, +consumers running strict type checkers (pyright/mypy) — or subclassing the +public abstract types — may need to update their code. The constructor change +below also affects callers who *directly* construct the named classes, which is +uncommon since they are normally handed to you by the SDK. - `OrchestrationContext.call_activity`, `call_sub_orchestrator`, `call_entity`, and `wait_for_external_event` gained new keyword-only parameters (`return_type` / `data_type`). Subclasses overriding these methods should add the parameter to match the base signature. +- `EntityContext` and `EntityMetadata` (and its `from_entity_metadata` / + `from_entity_response` factories) now require a `data_converter` argument. + These objects are normally constructed by the SDK — you receive an + `EntityContext` in an entity function and an `EntityMetadata` from the client — + so this only affects code that constructs them directly. ## v1.6.0 diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/client.py b/durabletask-azuremanaged/durabletask/azuremanaged/client.py index 308c341..fa3875f 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/client.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/client.py @@ -20,6 +20,7 @@ ) import durabletask.internal.shared as shared from durabletask.payload.store import PayloadStore +from durabletask.serialization import DataConverter # Client class used for Durable Task Scheduler (DTS) @@ -35,6 +36,7 @@ def __init__(self, *, resiliency_options: GrpcClientResiliencyOptions | None = None, default_version: str | None = None, payload_store: PayloadStore | None = None, + data_converter: DataConverter | None = None, log_handler: logging.Handler | None = None, log_formatter: logging.Formatter | None = None): @@ -59,7 +61,8 @@ def __init__(self, *, channel_options=channel_options, resiliency_options=resiliency_options, default_version=default_version, - payload_store=payload_store) + payload_store=payload_store, + data_converter=data_converter) # Async client class used for Durable Task Scheduler (DTS) @@ -113,6 +116,7 @@ def __init__(self, *, resiliency_options: GrpcClientResiliencyOptions | None = None, default_version: str | None = None, payload_store: PayloadStore | None = None, + data_converter: DataConverter | None = None, log_handler: logging.Handler | None = None, log_formatter: logging.Formatter | None = None): @@ -137,4 +141,5 @@ def __init__(self, *, channel_options=channel_options, resiliency_options=resiliency_options, default_version=default_version, - payload_store=payload_store) + payload_store=payload_store, + data_converter=data_converter) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py index 8acbad5..c27ebc1 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py @@ -18,6 +18,7 @@ ) import durabletask.internal.shared as shared from durabletask.payload.store import PayloadStore +from durabletask.serialization import DataConverter from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker @@ -81,6 +82,7 @@ def __init__(self, *, resiliency_options: GrpcWorkerResiliencyOptions | None = None, concurrency_options: ConcurrencyOptions | None = None, payload_store: PayloadStore | None = None, + data_converter: DataConverter | None = None, log_handler: logging.Handler | None = None, log_formatter: logging.Formatter | None = None): @@ -110,5 +112,6 @@ def __init__(self, *, concurrency_options=concurrency_options, # DTS natively supports long timers so chunking is unnecessary maximum_timer_interval=None, - payload_store=payload_store + payload_store=payload_store, + data_converter=data_converter ) diff --git a/durabletask/entities/entity_context.py b/durabletask/entities/entity_context.py index c5435a2..7a44c22 100644 --- a/durabletask/entities/entity_context.py +++ b/durabletask/entities/entity_context.py @@ -16,14 +16,11 @@ class EntityContext: def __init__(self, orchestration_id: str, operation: str, state: StateShim, - entity_id: EntityInstanceId, data_converter: "DataConverter | None" = None): + entity_id: EntityInstanceId, data_converter: "DataConverter"): self._orchestration_id = orchestration_id self._operation = operation self._state = state self._entity_id = entity_id - if data_converter is None: - from durabletask.serialization import JsonDataConverter - data_converter = JsonDataConverter() self._data_converter = data_converter @property diff --git a/durabletask/entities/entity_metadata.py b/durabletask/entities/entity_metadata.py index 37c437e..91e329e 100644 --- a/durabletask/entities/entity_metadata.py +++ b/durabletask/entities/entity_metadata.py @@ -36,7 +36,7 @@ def __init__(self, locked_by: str, includes_state: bool, state: Any | None, - data_converter: "DataConverter | None" = None): + data_converter: "DataConverter"): """Initializes a new instance of the EntityMetadata class. Args: @@ -48,20 +48,17 @@ def __init__(self, self._locked_by = locked_by self.includes_state = includes_state self._state = state - if data_converter is None: - from durabletask.serialization import JsonDataConverter - data_converter = JsonDataConverter() self._data_converter = data_converter @staticmethod def from_entity_response(entity_response: pb.GetEntityResponse, includes_state: bool, - data_converter: "DataConverter | None" = None): + data_converter: "DataConverter"): return EntityMetadata.from_entity_metadata( entity_response.entity, includes_state, data_converter) @staticmethod def from_entity_metadata(entity: pb.EntityMetadata, includes_state: bool, - data_converter: "DataConverter | None" = None): + data_converter: "DataConverter"): try: entity_id = EntityInstanceId.parse(entity.instanceId) except ValueError: diff --git a/durabletask/internal/entity_state_shim.py b/durabletask/internal/entity_state_shim.py index 9993876..b18d3d6 100644 --- a/durabletask/internal/entity_state_shim.py +++ b/durabletask/internal/entity_state_shim.py @@ -36,11 +36,8 @@ class StateShim: is written back with :meth:`set_state`. """ - def __init__(self, start_state: Any, data_converter: "DataConverter | None" = None, + def __init__(self, start_state: Any, data_converter: "DataConverter", *, is_serialized: bool = False): - if data_converter is None: - from durabletask.serialization import JsonDataConverter - data_converter = JsonDataConverter() self._data_converter = data_converter # The state is normalized to its serialized string form. ``is_serialized`` # marks ``start_state`` as a raw payload already off the wire (stored diff --git a/examples/custom_data_converter/src/app.py b/examples/custom_data_converter/src/app.py index 681efad..caa9eb1 100644 --- a/examples/custom_data_converter/src/app.py +++ b/examples/custom_data_converter/src/app.py @@ -75,6 +75,7 @@ def main() -> None: if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: # ``get_output(Receipt)`` reconstructs the typed, validated result. receipt = state.get_output(Receipt) + assert receipt is not None print("Orchestration completed. Typed receipt:") print(f" customer = {receipt.customer}") print(f" total = {receipt.total}") diff --git a/tests/durabletask/test_entity_executor.py b/tests/durabletask/test_entity_executor.py index 34656c5..4277405 100644 --- a/tests/durabletask/test_entity_executor.py +++ b/tests/durabletask/test_entity_executor.py @@ -22,7 +22,7 @@ def _make_executor(*entity_args) -> _EntityExecutor: def _execute(executor, entity_name, operation, encoded_input=None): """Helper to execute an entity operation.""" entity_id = entities.EntityInstanceId(entity_name, "test-key") - state = StateShim(None) + state = StateShim(None, JsonDataConverter()) return executor.execute("test-orchestration", entity_id, operation, state, encoded_input) @@ -75,7 +75,7 @@ def get(self): entity_id = entities.EntityInstanceId("Counter", "test-key") # set requires input - state = StateShim(None) + state = StateShim(None, JsonDataConverter()) executor.execute("test-orch", entity_id, "set", state, "10") state.commit() @@ -127,7 +127,7 @@ def counter(ctx: entities.EntityContext, input): executor = _make_executor(counter) entity_id = entities.EntityInstanceId("counter", "test-key") - state = StateShim(None) + state = StateShim(None, JsonDataConverter()) executor.execute("test-orch", entity_id, "set", state, "42") state.commit() @@ -140,19 +140,19 @@ class TestStateShimCoercion: """Tests for StateShim.get_state type coercion via the data converter.""" def test_get_state_none_returns_default(self): - state = StateShim(None) + state = StateShim(None, JsonDataConverter()) assert state.get_state(int, 0) == 0 def test_get_state_none_without_default_returns_none(self): - state = StateShim(None) + state = StateShim(None, JsonDataConverter()) assert state.get_state(int) is None def test_get_state_passes_through_matching_type(self): - state = StateShim(5) + state = StateShim(5, JsonDataConverter()) assert state.get_state(int) == 5 def test_get_state_constructor_coercion(self): - state = StateShim("5") + state = StateShim("5", JsonDataConverter()) assert state.get_state(int) == 5 def test_get_state_coerces_dataclass(self): @@ -163,7 +163,7 @@ class Counter: value: int # State is stored as a plain dict (as it would be after from_json). - state = StateShim({"value": 7}) + state = StateShim({"value": 7}, JsonDataConverter()) result = state.get_state(Counter) assert isinstance(result, Counter) assert result.value == 7 @@ -177,7 +177,7 @@ def __init__(self, n: int): def from_json(cls, data): return cls(data["n"]) - state = StateShim({"n": 3}) + state = StateShim({"n": 3}, JsonDataConverter()) result = state.get_state(Wrapped) assert isinstance(result, Wrapped) assert result.n == 3 @@ -187,7 +187,7 @@ def test_get_state_invalid_coercion_raises(self): # restoring the pre-existing strict contract for entity state access. import pytest - state = StateShim("not-an-int") + state = StateShim("not-an-int", JsonDataConverter()) with pytest.raises(TypeError): state.get_state(int) @@ -197,7 +197,7 @@ class TestStateShimDeferredDeserialization: def test_constructor_does_not_deserialize_serialized_state(self): # A serialized payload is held verbatim until read, not eagerly parsed. - state = StateShim('{"value": 7}', is_serialized=True) + state = StateShim('{"value": 7}', JsonDataConverter(), is_serialized=True) assert state._current_state == '{"value": 7}' def test_get_state_defers_deserialization_with_type(self): @@ -207,13 +207,13 @@ def test_get_state_defers_deserialization_with_type(self): class Counter: value: int - state = StateShim('{"value": 7}', is_serialized=True) + state = StateShim('{"value": 7}', JsonDataConverter(), is_serialized=True) result = state.get_state(Counter) assert isinstance(result, Counter) assert result.value == 7 def test_get_state_no_type_returns_parsed_value(self): - state = StateShim('{"value": 7}', is_serialized=True) + state = StateShim('{"value": 7}', JsonDataConverter(), is_serialized=True) assert state.get_state() == {"value": 7} def test_deferred_deserialization_passes_raw_string_to_converter(self): @@ -247,11 +247,11 @@ def coerce(self, value, target_type=None): def test_encode_state_passes_through_unmodified_payload(self): # An unread/unmodified serialized payload is returned verbatim, never # re-serialized (which would double-encode the JSON string). - state = StateShim('{"value": 7}', is_serialized=True) + state = StateShim('{"value": 7}', JsonDataConverter(), is_serialized=True) assert state.encode_state() == '{"value": 7}' def test_reading_does_not_trigger_double_encoding(self): - state = StateShim('{"value": 7}', is_serialized=True) + state = StateShim('{"value": 7}', JsonDataConverter(), is_serialized=True) # Reading (even with a type) must not turn the payload into a live value # that would be re-serialized into a JSON-encoded string. state.get_state() @@ -261,24 +261,24 @@ def test_reading_does_not_trigger_double_encoding(self): assert json.loads(encoded) == {"value": 7} def test_encode_state_serializes_live_value_after_set_state(self): - state = StateShim('{"value": 7}', is_serialized=True) + state = StateShim('{"value": 7}', JsonDataConverter(), is_serialized=True) state.set_state({"value": 8}) encoded = state.encode_state() assert json.loads(encoded) == {"value": 8} def test_encode_state_none_when_state_is_none(self): - state = StateShim(None, is_serialized=True) + state = StateShim(None, JsonDataConverter(), is_serialized=True) assert state.encode_state() is None def test_commit_preserves_unmodified_payload(self): - state = StateShim('{"value": 7}', is_serialized=True) + state = StateShim('{"value": 7}', JsonDataConverter(), is_serialized=True) state.commit() # After commit, the (unmodified) state still round-trips without # double-encoding. assert state.encode_state() == '{"value": 7}' def test_rollback_restores_unmodified_payload(self): - state = StateShim('{"value": 7}', is_serialized=True) + state = StateShim('{"value": 7}', JsonDataConverter(), is_serialized=True) state.commit() state.set_state({"value": 99}) state.rollback() @@ -286,6 +286,6 @@ def test_rollback_restores_unmodified_payload(self): def test_falsy_serialized_state_is_not_dropped(self): # A serialized falsy value (e.g. 0) is preserved, not treated as cleared. - state = StateShim("0", is_serialized=True) + state = StateShim("0", JsonDataConverter(), is_serialized=True) assert state.get_state(int) == 0 assert state.encode_state() == "0" diff --git a/tests/durabletask/test_tracing.py b/tests/durabletask/test_tracing.py index 9969afa..8115919 100644 --- a/tests/durabletask/test_tracing.py +++ b/tests/durabletask/test_tracing.py @@ -17,6 +17,7 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import StatusCode +from durabletask.serialization import JsonDataConverter import durabletask.internal.helpers as helpers import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.tracing as tracing @@ -263,10 +264,10 @@ def simple_orchestrator(ctx: task.OrchestrationContext, _): registry = worker._Registry() registry.add_orchestrator(simple_orchestrator) - ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry) + ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry, JsonDataConverter()) assert ctx._parent_trace_context is None - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) # Create an executionStarted event with parentTraceContext event = pb.HistoryEvent( @@ -290,8 +291,8 @@ def simple_orchestrator(ctx: task.OrchestrationContext, _): registry = worker._Registry() registry.add_orchestrator(simple_orchestrator) - ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry) - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry, JsonDataConverter()) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) event = pb.HistoryEvent( eventId=-1, @@ -320,10 +321,10 @@ def simple_orchestrator(ctx: task.OrchestrationContext, _): registry = worker._Registry() registry.add_orchestrator(simple_orchestrator) - ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry) + ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry, JsonDataConverter()) assert ctx._orchestration_trace_context is None - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) event = pb.HistoryEvent( eventId=-1, executionStarted=pb.ExecutionStartedEvent( @@ -359,9 +360,10 @@ def simple_orchestrator(ctx: task.OrchestrationContext, _): registry = worker._Registry() registry.add_orchestrator(simple_orchestrator) - ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry) + ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, registry, JsonDataConverter()) executor = worker._OrchestrationExecutor( registry, TEST_LOGGER, + data_converter=JsonDataConverter(), persisted_orch_span_id=persisted_span_id, ) @@ -867,7 +869,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_task_completed_event(2, json.dumps(20)), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) client_spans = self._get_client_spans(otel_setup) @@ -904,7 +906,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_task_completed_event(2, json.dumps("ok")), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) client_spans = self._get_client_spans(otel_setup) @@ -942,7 +944,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_timer_fired_event(2, fire_at_2), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) client_spans = self._get_client_spans(otel_setup) @@ -976,7 +978,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_sub_orchestration_completed_event(2, encoded_output=json.dumps("r2")), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) client_spans = self._get_client_spans(otel_setup) @@ -1013,7 +1015,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_sub_orchestration_completed_event(2, encoded_output=json.dumps("ok")), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) client_spans = self._get_client_spans(otel_setup) @@ -1732,7 +1734,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): self._make_traced_task_completed_event(1, json.dumps("result"), timestamp_seconds=200), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) client_spans = [ @@ -1783,7 +1785,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): self._make_traced_task_failed_event(1, "boom", timestamp_seconds=250), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) client_spans = [ @@ -1848,7 +1850,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): ), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) client_spans = [ @@ -1895,7 +1897,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): helpers.new_task_completed_event(2, json.dumps(20)), ] - executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER, JsonDataConverter()) executor.execute(TEST_INSTANCE_ID, old_events, new_events) client_spans = [ diff --git a/tests/durabletask/test_type_discovery.py b/tests/durabletask/test_type_discovery.py index 31befb2..065e964 100644 --- a/tests/durabletask/test_type_discovery.py +++ b/tests/durabletask/test_type_discovery.py @@ -279,7 +279,7 @@ def store(ctx: entities.EntityContext, order: Order): registry.add_entity(store) executor = worker._EntityExecutor(registry, TEST_LOGGER, JsonDataConverter()) entity_id = entities.EntityInstanceId("store", "k1") - state = StateShim(None) + state = StateShim(None, JsonDataConverter()) executor.execute("orch1", entity_id, "save", state, json.dumps({"item": "book", "quantity": 2})) assert seen["type"] == "Order" assert seen["item"] == "book" @@ -297,7 +297,7 @@ def save(self, order: Order): registry.add_entity(Store) executor = worker._EntityExecutor(registry, TEST_LOGGER, JsonDataConverter()) entity_id = entities.EntityInstanceId("store", "k1") - state = StateShim(None) + state = StateShim(None, JsonDataConverter()) executor.execute("orch1", entity_id, "save", state, json.dumps({"item": "book", "quantity": 2})) assert seen["type"] == "Order" assert seen["item"] == "book" From c71e8c9061cb8b507e015c4fdef320cc5446086d Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Jun 2026 16:39:59 -0600 Subject: [PATCH 8/9] PR feedback --- durabletask/internal/type_discovery.py | 2 +- durabletask/serialization.py | 30 +++++++++++-------- examples/custom_data_converter/README.md | 12 +++++++- .../custom_data_converter/src/converter.py | 12 ++++++++ tests/durabletask/test_serialization.py | 21 +++++++++++++ 5 files changed, 63 insertions(+), 14 deletions(-) diff --git a/durabletask/internal/type_discovery.py b/durabletask/internal/type_discovery.py index 6f87ea2..942ac8b 100644 --- a/durabletask/internal/type_discovery.py +++ b/durabletask/internal/type_discovery.py @@ -10,7 +10,7 @@ Discovery is intentionally conservative: it only returns an annotation when the active :class:`~durabletask.serialization.DataConverter` reports it as -*reconstructable* via :meth:`DataConverter.is_reconstructable`. The default +*reconstructable* via :meth:`DataConverter.can_reconstruct`. The default converter recognizes dataclasses, ``from_json()``-capable types, and ``Optional`` / ``list`` hints wrapping them; a custom converter can recognize its own types (e.g. ``pydantic.BaseModel``). Primitive and unknown annotations resolve to diff --git a/durabletask/serialization.py b/durabletask/serialization.py index 8317abf..5d1dad8 100644 --- a/durabletask/serialization.py +++ b/durabletask/serialization.py @@ -403,7 +403,12 @@ def _invoke_from_json(hook: Any, value: Any, converter: DataConverter | None) -> > ``from_json`` must be a ``@classmethod`` or ``@staticmethod`` -- the hook > is resolved off the *type* (no instance exists yet during reconstruction). > A plain instance method would have ``self`` consume the value and is - > unsupported regardless of the converter-arity detection below. + > unsupported regardless of the converter detection below. + > + > The converter is passed positionally as the second argument, so a hook + > opts in only by naming that parameter exactly ``converter``. This reserved + > name avoids misreading an unrelated second parameter (e.g. + > ``from_json(cls, value, strict=False)``) as converter-aware. """ if converter is not None and _hook_accepts_converter(hook): return hook(value, converter) @@ -412,26 +417,27 @@ def _invoke_from_json(hook: Any, value: Any, converter: DataConverter | None) -> @functools.lru_cache(maxsize=2048) def _hook_accepts_converter(hook: Any) -> bool: - """Return True if a bound ``from_json`` hook can accept a second argument. + """Return True if a bound ``from_json`` hook opts in to receiving the converter. The hook is inspected as accessed off the type (``cls``/``self`` already bound), so a classmethod ``from_json(cls, value, converter)`` presents as - ``(value, converter)``. Results are cached because reconstruction runs on - hot paths; bound classmethods hash equal across attribute accesses, so the - cache stays effective and bounded. + ``(value, converter)``. A hook is treated as converter-aware only when its + second positional parameter is named exactly ``converter`` -- the reserved + name for this argument -- so an unrelated second parameter such as + ``strict=False`` is not mistaken for it. Results are cached because + reconstruction runs on hot paths; bound classmethods hash equal across + attribute accesses, so the cache stays effective and bounded. """ try: sig = inspect.signature(hook) except (TypeError, ValueError): return False - positional = 0 - for param in sig.parameters.values(): + positional = [ + param for param in sig.parameters.values() if param.kind in (inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD): - positional += 1 - elif param.kind is inspect.Parameter.VAR_POSITIONAL: - return True - return positional >= 2 + inspect.Parameter.POSITIONAL_OR_KEYWORD) + ] + return len(positional) >= 2 and positional[1].name == "converter" def _coerce_generic(value: Any, expected_type: Any, origin: Any, diff --git a/examples/custom_data_converter/README.md b/examples/custom_data_converter/README.md index 749b567..e375689 100644 --- a/examples/custom_data_converter/README.md +++ b/examples/custom_data_converter/README.md @@ -48,9 +48,19 @@ A `DataConverter` implements three methods: The converter in [src/converter.py](src/converter.py) recognizes `pydantic.BaseModel` subclasses and uses pydantic for them, **delegating everything else** to the default `JsonDataConverter`. This "handle my types, -delegate the rest" shape is the recommended pattern for a real converter — it +delegate the rest" shape is a good starting point for a real converter — it costs nothing for non-pydantic payloads. +> [!NOTE] +> To stay focused on the seam, this example only intercepts when a model is the +> *top-level* type. A model nested inside another model (like `Order.items`) +> round-trips because pydantic recurses on its own, but a model nested in a +> *top-level generic* the SDK rebuilds directly — e.g. `return_type=list[OrderItem]` +> or an input annotated `dict[str, OrderItem]` — is not intercepted and falls to +> the default codec, which leaves the elements as raw dicts. A production +> converter would recurse into such generics (for example via +> `pydantic.TypeAdapter`). + ## Inbound inputs: `can_reconstruct` There is one extra detail for reconstructing **inbound** orchestrator/activity diff --git a/examples/custom_data_converter/src/converter.py b/examples/custom_data_converter/src/converter.py index 29696f4..360fee6 100644 --- a/examples/custom_data_converter/src/converter.py +++ b/examples/custom_data_converter/src/converter.py @@ -61,6 +61,18 @@ class PydanticDataConverter(DataConverter): (with full validation) via ``model_validate_json()`` / ``model_validate()`` whenever the SDK supplies the model type. Every other value falls through to the default :class:`JsonDataConverter`. + + > [!NOTE] + > For simplicity this converter only intercepts when a pydantic model is the + > *top-level* ``target_type``. A model nested **inside another model** (like + > ``Order.items: list[OrderItem]`` below) round-trips fine because pydantic + > handles that recursion itself. But a model nested in a *top-level generic* + > the SDK is asked to rebuild directly -- e.g. ``return_type=list[OrderItem]`` + > or an input annotated ``dict[str, OrderItem]`` -- is **not** intercepted + > here; it falls to the default codec, which cannot construct a pydantic + > model from a positional ``dict`` and so leaves the elements as raw dicts. A + > production converter would recurse into such generics (for example via + > ``pydantic.TypeAdapter``); this example keeps the surface minimal. """ def __init__(self) -> None: diff --git a/tests/durabletask/test_serialization.py b/tests/durabletask/test_serialization.py index d5a9ecf..ae797c2 100644 --- a/tests/durabletask/test_serialization.py +++ b/tests/durabletask/test_serialization.py @@ -488,6 +488,27 @@ def test_coerce_to_type_without_converter_calls_single_arg_hook(): assert result == Widget("gear", 5) +def test_from_json_hook_unrelated_second_param_is_not_treated_as_converter(): + # A ``from_json`` whose second parameter is *not* named ``converter`` (here a + # ``strict`` flag with a default) must not be mistaken for a converter-aware + # hook -- otherwise the DataConverter would be bound to ``strict``. + from durabletask.serialization import JsonDataConverter + + @dataclass + class Flagged: + label: str + + @classmethod + def from_json(cls, value, strict=False): + # If the converter were passed here, ``strict`` would be truthy. + assert strict is False + return cls(value["label"]) + + conv = JsonDataConverter() + result = conv.deserialize('{"label": "ok"}', Flagged) + assert result == Flagged("ok") + + # ----- forward-reference resolution (Python 3.10 get_type_hints parity) ----- # # On Python 3.10, ``typing.get_type_hints`` does not deep-resolve forward From 25a404be1e486142a54f47cb78df0823ac7ca943 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Jun 2026 16:51:14 -0600 Subject: [PATCH 9/9] Final CHANGELOG tuneups --- CHANGELOG.md | 9 +++++---- durabletask-azuremanaged/CHANGELOG.md | 7 ++++++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a3515eb..a85930d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,10 +45,11 @@ CHANGED error (`__cause__`) and names the offending type. - `EntityContext.get_state()` / `DurableEntity.get_state()` now return a freshly reconstructed value on every call rather than a reference to a single cached - object, so mutating the returned value in place no longer affects persisted - state — write it back with `set_state()`. State is also serialized eagerly at - `set_state()` time, so a non-serializable value fails inside the operation - (which rolls back) instead of after the batch has run. + object. This changes v1.6.0 behavior: mutating the returned value in place no + longer affects persisted state — write it back with `set_state()`. State is + also serialized eagerly at `set_state()` time, so a non-serializable value + fails inside the operation (which rolls back) instead of after the batch has + run. FIXED diff --git a/durabletask-azuremanaged/CHANGELOG.md b/durabletask-azuremanaged/CHANGELOG.md index 005cb66..e9eb2b0 100644 --- a/durabletask-azuremanaged/CHANGELOG.md +++ b/durabletask-azuremanaged/CHANGELOG.md @@ -7,7 +7,12 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## Unreleased -N/A +ADDED + +- `DurableTaskSchedulerWorker`, `DurableTaskSchedulerClient`, and the async + client now accept a `data_converter` argument and forward it to the base + worker/client, so a custom `durabletask.serialization.DataConverter` (for + example a pydantic-backed one) can be used with the Durable Task Scheduler. ## v1.6.0