Skip to content

Commit 7e527bc

Browse files
feat(internal/types): support eagerly validating pydantic iterators
1 parent 625827c commit 7e527bc

2 files changed

Lines changed: 137 additions & 3 deletions

File tree

src/openai/_models.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
Protocol,
2828
Required,
2929
Sequence,
30+
Annotated,
3031
ParamSpec,
32+
TypeAlias,
3133
TypedDict,
3234
TypeGuard,
3335
final,
@@ -81,7 +83,15 @@
8183
from ._constants import RAW_RESPONSE_HEADER
8284

8385
if TYPE_CHECKING:
86+
from pydantic import GetCoreSchemaHandler, ValidatorFunctionWrapHandler
87+
from pydantic_core import CoreSchema, core_schema
8488
from pydantic_core.core_schema import ModelField, ModelSchema, LiteralSchema, ModelFieldsSchema
89+
else:
90+
try:
91+
from pydantic_core import CoreSchema, core_schema
92+
except ImportError:
93+
CoreSchema = None
94+
core_schema = None
8595

8696
__all__ = ["BaseModel", "GenericModel"]
8797

@@ -422,6 +432,76 @@ def model_dump_json(
422432
)
423433

424434

435+
class _EagerIterable(list[_T], Generic[_T]):
436+
"""
437+
Accepts any Iterable[T] input (including generators), consumes it
438+
eagerly, and validates all items upfront.
439+
440+
Validation preserves the original container type where possible
441+
(e.g. a set[T] stays a set[T]). Serialization (model_dump / JSON)
442+
always emits a list — round-tripping through model_dump() will not
443+
restore the original container type.
444+
"""
445+
446+
@classmethod
447+
def __get_pydantic_core_schema__(
448+
cls,
449+
source_type: Any,
450+
handler: GetCoreSchemaHandler,
451+
) -> CoreSchema:
452+
(item_type,) = get_args(source_type) or (Any,)
453+
item_schema: CoreSchema = handler.generate_schema(item_type)
454+
list_of_items_schema: CoreSchema = core_schema.list_schema(item_schema)
455+
456+
return core_schema.no_info_wrap_validator_function(
457+
cls._validate,
458+
list_of_items_schema,
459+
serialization=core_schema.plain_serializer_function_ser_schema(
460+
cls._serialize,
461+
info_arg=False,
462+
),
463+
)
464+
465+
@staticmethod
466+
def _validate(v: Iterable[_T], handler: "ValidatorFunctionWrapHandler") -> Any:
467+
original_type: type[Any] = type(v)
468+
469+
# Normalize to list so list_schema can validate each item
470+
if isinstance(v, list):
471+
items: list[_T] = v
472+
else:
473+
try:
474+
items = list(v)
475+
except TypeError as e:
476+
raise TypeError("Value is not iterable") from e
477+
478+
# Validate items against the inner schema
479+
validated: list[_T] = handler(items)
480+
481+
# Reconstruct original container type
482+
if original_type is list:
483+
return validated
484+
# str(list) produces the list's repr, not a string built from items,
485+
# so skip reconstruction for str and its subclasses.
486+
if issubclass(original_type, str):
487+
return validated
488+
try:
489+
return original_type(validated)
490+
except (TypeError, ValueError):
491+
# If the type cannot be reconstructed, just return the validated list
492+
return validated
493+
494+
@staticmethod
495+
def _serialize(v: Iterable[_T]) -> list[_T]:
496+
"""Always serialize as a list so Pydantic's JSON encoder is happy."""
497+
if isinstance(v, list):
498+
return v
499+
return list(v)
500+
501+
502+
EagerIterable: TypeAlias = Annotated[Iterable[_T], _EagerIterable]
503+
504+
425505
def _construct_field(value: object, field: FieldInfo, key: str) -> object:
426506
if value is None:
427507
return field_get_default(field)

tests/test_models.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import json
2-
from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional, cast
2+
from typing import TYPE_CHECKING, Any, Dict, List, Union, Iterable, Optional, cast
33
from datetime import datetime, timezone
4-
from typing_extensions import Literal, Annotated, TypeAliasType
4+
from collections import deque
5+
from typing_extensions import Literal, Annotated, TypedDict, TypeAliasType
56

67
import pytest
78
import pydantic
89
from pydantic import Field
910

1011
from openai._utils import PropertyInfo
1112
from openai._compat import PYDANTIC_V1, parse_obj, model_dump, model_json
12-
from openai._models import DISCRIMINATOR_CACHE, BaseModel, construct_type
13+
from openai._models import DISCRIMINATOR_CACHE, BaseModel, EagerIterable, construct_type
1314

1415

1516
class BasicModel(BaseModel):
@@ -961,3 +962,56 @@ def __getattr__(self, attr: str) -> Item: ...
961962
assert model.a.prop == 1
962963
assert isinstance(model.a, Item)
963964
assert model.other == "foo"
965+
966+
967+
# NOTE: Workaround for Pydantic Iterable behavior.
968+
# Iterable fields are replaced with a ValidatorIterator and may be consumed
969+
# during serialization, which can cause subsequent dumps to return empty data.
970+
# See: https://github.com/pydantic/pydantic/issues/9541
971+
@pytest.mark.parametrize(
972+
"data, expected_validated",
973+
[
974+
([1, 2, 3], [1, 2, 3]),
975+
((1, 2, 3), (1, 2, 3)),
976+
(set([1, 2, 3]), set([1, 2, 3])),
977+
(iter([1, 2, 3]), [1, 2, 3]),
978+
([], []),
979+
((x for x in [1, 2, 3]), [1, 2, 3]),
980+
(map(lambda x: x, [1, 2, 3]), [1, 2, 3]),
981+
(frozenset([1, 2, 3]), frozenset([1, 2, 3])),
982+
(deque([1, 2, 3]), deque([1, 2, 3])),
983+
],
984+
ids=["list", "tuple", "set", "iterator", "empty", "generator", "map", "frozenset", "deque"],
985+
)
986+
@pytest.mark.skipif(PYDANTIC_V1, reason="this is only supported in pydantic v2")
987+
def test_iterable_construction(data: Iterable[int], expected_validated: Iterable[int]) -> None:
988+
class TypeWithIterable(TypedDict):
989+
items: EagerIterable[int]
990+
991+
class Model(BaseModel):
992+
data: TypeWithIterable
993+
994+
m = Model.model_validate({"data": {"items": data}})
995+
assert m.data["items"] == expected_validated
996+
997+
# Verify repeated dumps don't lose data (the original bug)
998+
assert m.model_dump()["data"]["items"] == list(expected_validated)
999+
assert m.model_dump()["data"]["items"] == list(expected_validated)
1000+
1001+
1002+
@pytest.mark.skipif(PYDANTIC_V1, reason="this is only supported in pydantic v2")
1003+
def test_iterable_construction_str_falls_back_to_list() -> None:
1004+
# str is iterable (over chars), but str(list_of_chars) produces the list's repr
1005+
# rather than reconstructing a string from items. We special-case str to fall
1006+
# back to list instead of attempting reconstruction.
1007+
class TypeWithIterable(TypedDict):
1008+
items: EagerIterable[str]
1009+
1010+
class Model(BaseModel):
1011+
data: TypeWithIterable
1012+
1013+
m = Model.model_validate({"data": {"items": "hello"}})
1014+
1015+
# falls back to list of chars rather than calling str(["h", "e", "l", "l", "o"])
1016+
assert m.data["items"] == ["h", "e", "l", "l", "o"]
1017+
assert m.model_dump()["data"]["items"] == ["h", "e", "l", "l", "o"]

0 commit comments

Comments
 (0)