|
1 | 1 | """Initialize Temporal OpenAI Agents overrides.""" |
2 | 2 |
|
3 | 3 | import dataclasses |
| 4 | +import json |
4 | 5 | import typing |
5 | 6 | from collections.abc import AsyncIterator, Callable, Iterator, Sequence |
6 | 7 | from contextlib import asynccontextmanager, contextmanager |
7 | 8 | from datetime import timedelta |
8 | 9 |
|
| 10 | +import pydantic |
9 | 11 | from agents import ModelProvider, Trace, set_trace_provider |
10 | 12 | from agents.run import get_default_agent_runner, set_default_agent_runner |
11 | 13 | from agents.tracing import get_trace_provider |
12 | 14 | from agents.tracing.provider import DefaultTraceProvider |
13 | 15 |
|
| 16 | +# construct_type is OpenAI's lenient (non-validating) model builder, the same |
| 17 | +# one the SDK uses to parse live API responses. It is in a private module but |
| 18 | +# has no public alias. |
| 19 | +from openai._models import construct_type |
| 20 | + |
| 21 | +import temporalio.api.common.v1 |
14 | 22 | from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity |
15 | 23 | from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters |
16 | 24 | from temporalio.contrib.openai_agents._openai_runner import ( |
|
25 | 33 | from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError |
26 | 34 | from temporalio.contrib.opentelemetry._tracer_provider import ReplaySafeTracerProvider |
27 | 35 | from temporalio.contrib.pydantic import ( |
28 | | - PydanticPayloadConverter, |
| 36 | + PydanticJSONPlainPayloadConverter, |
29 | 37 | ToJsonOptions, |
30 | 38 | ) |
31 | 39 | from temporalio.converter import ( |
| 40 | + CompositePayloadConverter, |
32 | 41 | DataConverter, |
33 | 42 | DefaultPayloadConverter, |
| 43 | + JSONPlainPayloadConverter, |
34 | 44 | ) |
35 | 45 | from temporalio.plugin import SimplePlugin |
36 | 46 | from temporalio.worker import WorkflowRunner |
@@ -64,12 +74,72 @@ def _set_open_ai_agent_temporal_overrides( |
64 | 74 | set_trace_provider(previous_trace_provider or DefaultTraceProvider()) |
65 | 75 |
|
66 | 76 |
|
67 | | -class OpenAIPayloadConverter(PydanticPayloadConverter): |
| 77 | +def _lenient_construct(type_: typing.Any, value: typing.Any) -> typing.Any: |
| 78 | + """Build ``value`` into ``type_`` without enforcing required fields. |
| 79 | +
|
| 80 | + OpenAI's ``construct_type`` handles its own response models (and the |
| 81 | + unions/lists thereof), but not the ``agents`` dataclasses that wrap them |
| 82 | + (e.g. ``ModelResponse``), so the dataclass layer is reconstructed here and |
| 83 | + each field delegated to ``construct_type``. ``include_extras`` preserves the |
| 84 | + ``Annotated`` discriminators the unions rely on. |
| 85 | + """ |
| 86 | + if ( |
| 87 | + isinstance(type_, type) |
| 88 | + and dataclasses.is_dataclass(type_) |
| 89 | + and isinstance(value, dict) |
| 90 | + ): |
| 91 | + hints = typing.get_type_hints(type_, include_extras=True) |
| 92 | + return type_( |
| 93 | + **{ |
| 94 | + field.name: _lenient_construct( |
| 95 | + hints.get(field.name, object), value[field.name] |
| 96 | + ) |
| 97 | + for field in dataclasses.fields(type_) |
| 98 | + if field.name in value |
| 99 | + } |
| 100 | + ) |
| 101 | + return construct_type(type_=type_, value=value) |
| 102 | + |
| 103 | + |
| 104 | +class _OpenAIJSONPlainPayloadConverter(PydanticJSONPlainPayloadConverter): |
| 105 | + """Strict pydantic deserialization with a lenient fallback. |
| 106 | +
|
| 107 | + OpenAI's response models can drift from live API payloads (e.g. a |
| 108 | + deprecated-but-required field the API has stopped sending). The SDK tolerates |
| 109 | + this when parsing responses, but strict ``validate_json`` on the workflow |
| 110 | + side does not, so fall back to lenient construction when validation fails. |
| 111 | + """ |
| 112 | + |
| 113 | + def from_payload( |
| 114 | + self, |
| 115 | + payload: temporalio.api.common.v1.Payload, |
| 116 | + type_hint: type | None = None, |
| 117 | + ) -> typing.Any: |
| 118 | + """See base class.""" |
| 119 | + try: |
| 120 | + return super().from_payload(payload, type_hint) |
| 121 | + except pydantic.ValidationError: |
| 122 | + if type_hint is None: |
| 123 | + raise |
| 124 | + return _lenient_construct(type_hint, json.loads(payload.data)) |
| 125 | + |
| 126 | + |
| 127 | +class OpenAIPayloadConverter(CompositePayloadConverter): |
68 | 128 | """PayloadConverter for OpenAI agents.""" |
69 | 129 |
|
70 | 130 | def __init__(self) -> None: |
71 | 131 | """Initialize a payload converter.""" |
72 | | - super().__init__(ToJsonOptions(exclude_unset=True)) |
| 132 | + json_payload_converter = _OpenAIJSONPlainPayloadConverter( |
| 133 | + ToJsonOptions(exclude_unset=True) |
| 134 | + ) |
| 135 | + super().__init__( |
| 136 | + *( |
| 137 | + c |
| 138 | + if not isinstance(c, JSONPlainPayloadConverter) |
| 139 | + else json_payload_converter |
| 140 | + for c in DefaultPayloadConverter.default_encoding_payload_converters |
| 141 | + ) |
| 142 | + ) |
73 | 143 |
|
74 | 144 |
|
75 | 145 | def _data_converter(converter: DataConverter | None) -> DataConverter: |
|
0 commit comments