Skip to content

Commit 6c51ddf

Browse files
authored
Merge branch 'main' into dependabot/uv/langsmith-0.8.0
2 parents 0569753 + 28243b9 commit 6c51ddf

7 files changed

Lines changed: 101 additions & 32 deletions

File tree

.github/workflows/omes.yml

Lines changed: 0 additions & 22 deletions
This file was deleted.

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ class IPv4AddressJSONEncoder(AdvancedJSONEncoder):
432432
class IPv4AddressJSONTypeConverter(JSONTypeConverter):
433433
def to_typed_value(
434434
self, hint: Type, value: Any
435-
) -> Union[Optional[Any], _JSONTypeConverterUnhandled]:
435+
) -> Union[Optional[Any], JSONTypeConverterUnhandled]:
436436
if issubclass(hint, ipaddress.IPv4Address):
437437
return ipaddress.IPv4Address(value)
438438
return JSONTypeConverter.Unhandled

temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,24 @@
11
"""Initialize Temporal OpenAI Agents overrides."""
22

33
import dataclasses
4+
import json
45
import typing
56
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
67
from contextlib import asynccontextmanager, contextmanager
78
from datetime import timedelta
89

10+
import pydantic
911
from agents import ModelProvider, Trace, set_trace_provider
1012
from agents.run import get_default_agent_runner, set_default_agent_runner
1113
from agents.tracing import get_trace_provider
1214
from agents.tracing.provider import DefaultTraceProvider
1315

16+
# construct_type is OpenAI's lenient (non-validating) model builder, the same
17+
# one the SDK uses to parse live API responses. It is in a private module but
18+
# has no public alias.
19+
from openai._models import construct_type
20+
21+
import temporalio.api.common.v1
1422
from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
1523
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
1624
from temporalio.contrib.openai_agents._openai_runner import (
@@ -25,12 +33,14 @@
2533
from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError
2634
from temporalio.contrib.opentelemetry._tracer_provider import ReplaySafeTracerProvider
2735
from temporalio.contrib.pydantic import (
28-
PydanticPayloadConverter,
36+
PydanticJSONPlainPayloadConverter,
2937
ToJsonOptions,
3038
)
3139
from temporalio.converter import (
40+
CompositePayloadConverter,
3241
DataConverter,
3342
DefaultPayloadConverter,
43+
JSONPlainPayloadConverter,
3444
)
3545
from temporalio.plugin import SimplePlugin
3646
from temporalio.worker import WorkflowRunner
@@ -64,12 +74,72 @@ def _set_open_ai_agent_temporal_overrides(
6474
set_trace_provider(previous_trace_provider or DefaultTraceProvider())
6575

6676

67-
class OpenAIPayloadConverter(PydanticPayloadConverter):
77+
def _lenient_construct(type_: typing.Any, value: typing.Any) -> typing.Any:
78+
"""Build ``value`` into ``type_`` without enforcing required fields.
79+
80+
OpenAI's ``construct_type`` handles its own response models (and the
81+
unions/lists thereof), but not the ``agents`` dataclasses that wrap them
82+
(e.g. ``ModelResponse``), so the dataclass layer is reconstructed here and
83+
each field delegated to ``construct_type``. ``include_extras`` preserves the
84+
``Annotated`` discriminators the unions rely on.
85+
"""
86+
if (
87+
isinstance(type_, type)
88+
and dataclasses.is_dataclass(type_)
89+
and isinstance(value, dict)
90+
):
91+
hints = typing.get_type_hints(type_, include_extras=True)
92+
return type_(
93+
**{
94+
field.name: _lenient_construct(
95+
hints.get(field.name, object), value[field.name]
96+
)
97+
for field in dataclasses.fields(type_)
98+
if field.name in value
99+
}
100+
)
101+
return construct_type(type_=type_, value=value)
102+
103+
104+
class _OpenAIJSONPlainPayloadConverter(PydanticJSONPlainPayloadConverter):
105+
"""Strict pydantic deserialization with a lenient fallback.
106+
107+
OpenAI's response models can drift from live API payloads (e.g. a
108+
deprecated-but-required field the API has stopped sending). The SDK tolerates
109+
this when parsing responses, but strict ``validate_json`` on the workflow
110+
side does not, so fall back to lenient construction when validation fails.
111+
"""
112+
113+
def from_payload(
114+
self,
115+
payload: temporalio.api.common.v1.Payload,
116+
type_hint: type | None = None,
117+
) -> typing.Any:
118+
"""See base class."""
119+
try:
120+
return super().from_payload(payload, type_hint)
121+
except pydantic.ValidationError:
122+
if type_hint is None:
123+
raise
124+
return _lenient_construct(type_hint, json.loads(payload.data))
125+
126+
127+
class OpenAIPayloadConverter(CompositePayloadConverter):
68128
"""PayloadConverter for OpenAI agents."""
69129

70130
def __init__(self) -> None:
71131
"""Initialize a payload converter."""
72-
super().__init__(ToJsonOptions(exclude_unset=True))
132+
json_payload_converter = _OpenAIJSONPlainPayloadConverter(
133+
ToJsonOptions(exclude_unset=True)
134+
)
135+
super().__init__(
136+
*(
137+
c
138+
if not isinstance(c, JSONPlainPayloadConverter)
139+
else json_payload_converter
140+
for c in DefaultPayloadConverter.default_encoding_payload_converters
141+
)
142+
)
73143

74144

75145
def _data_converter(converter: DataConverter | None) -> DataConverter:

temporalio/converter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
JSONPlainPayloadConverter,
3232
JSONProtoPayloadConverter,
3333
JSONTypeConverter,
34+
JSONTypeConverterUnhandled,
3435
PayloadConverter,
3536
value_to_type,
3637
)
@@ -76,6 +77,7 @@
7677
"JSONPlainPayloadConverter",
7778
"JSONProtoPayloadConverter",
7879
"JSONTypeConverter",
80+
"JSONTypeConverterUnhandled",
7981
"PayloadCodec",
8082
"PayloadConverter",
8183
"PayloadLimitsConfig",

temporalio/converter/_payload_converter.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -548,23 +548,28 @@ def default(self, o: Any) -> Any:
548548
return super().default(o)
549549

550550

551-
_JSONTypeConverterUnhandled = NewType("_JSONTypeConverterUnhandled", object)
551+
JSONTypeConverterUnhandled = NewType("JSONTypeConverterUnhandled", object)
552+
"""Type of :py:attr:`JSONTypeConverter.Unhandled`."""
553+
554+
_JSONTypeConverterUnhandled = JSONTypeConverterUnhandled
552555

553556

554557
class JSONTypeConverter(ABC):
555558
"""Converter for converting an object from Python :py:func:`json.loads`
556559
result (e.g. scalar, list, or dict) to a known type.
557560
"""
558561

559-
Unhandled = _JSONTypeConverterUnhandled(object())
562+
Unhandled: ClassVar[JSONTypeConverterUnhandled] = JSONTypeConverterUnhandled(
563+
object()
564+
)
560565
"""Sentinel value that must be used as the result of
561566
:py:meth:`to_typed_value` to say the given type is not handled by this
562567
converter."""
563568

564569
@abstractmethod
565570
def to_typed_value(
566571
self, hint: type, value: Any
567-
) -> Any | None | _JSONTypeConverterUnhandled:
572+
) -> Any | None | JSONTypeConverterUnhandled:
568573
"""Convert the given value to a type based on the given hint.
569574
570575
Args:

tests/contrib/openai_agents/test_openai.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,9 @@ def research_mock_model():
547547
id="",
548548
status="completed",
549549
type="web_search_call",
550-
action=ActionSearch(query="", type="search"),
550+
action=ActionSearch.model_construct(
551+
type="search", queries=[""]
552+
),
551553
),
552554
ResponseBuilders.response_output_message("Granada"),
553555
],

tests/test_converter.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
Dict, # type:ignore[reportDeprecated]
1717
Literal,
1818
NewType,
19+
get_args,
20+
get_type_hints,
1921
)
2022
from uuid import UUID, uuid4
2123

@@ -40,12 +42,12 @@
4042
DefaultPayloadConverter,
4143
JSONPlainPayloadConverter,
4244
JSONTypeConverter,
45+
JSONTypeConverterUnhandled,
4346
PayloadCodec,
4447
decode_search_attributes,
4548
encode_search_attribute_values,
4649
value_to_type,
4750
)
48-
from temporalio.converter._payload_converter import _JSONTypeConverterUnhandled
4951
from temporalio.exceptions import (
5052
ApplicationError,
5153
FailureError,
@@ -869,12 +871,22 @@ def default(self, o: Any) -> Any:
869871
class IPv4AddressJSONTypeConverter(JSONTypeConverter):
870872
def to_typed_value(
871873
self, hint: type, value: Any
872-
) -> Any | None | _JSONTypeConverterUnhandled:
874+
) -> Any | None | JSONTypeConverterUnhandled:
873875
if inspect.isclass(hint) and issubclass(hint, ipaddress.IPv4Address):
874876
return ipaddress.IPv4Address(value)
875877
return JSONTypeConverter.Unhandled
876878

877879

880+
def test_json_type_converter_unhandled_type_public():
881+
return_type = get_type_hints(JSONTypeConverter.to_typed_value)["return"]
882+
883+
assert JSONTypeConverterUnhandled.__name__ == "JSONTypeConverterUnhandled"
884+
assert JSONTypeConverterUnhandled in get_args(return_type)
885+
assert JSONTypeConverterUnhandled(JSONTypeConverter.Unhandled) is (
886+
JSONTypeConverter.Unhandled
887+
)
888+
889+
878890
async def test_json_type_converter():
879891
addr = ipaddress.IPv4Address("1.2.3.4")
880892
custom_conv = dataclasses.replace(

0 commit comments

Comments
 (0)