Skip to content

Commit b5ffb62

Browse files
fix: make Image pydantic schema Union-friendly
When UserMessage.content contains mixed Image and str items, Pydantic tries Image validation before the string branch. The previous any_schema validator raised TypeError for strings, preventing Union[str, Image] fallback.\n\nUse a narrow Image schema that accepts Image instances or serialized dicts, and add a regression to ensure mixed text/image UserMessage content round-trips through JSON in both orders.\n\nFixes #7170.
1 parent 13e144e commit b5ffb62

2 files changed

Lines changed: 36 additions & 15 deletions

File tree

python/packages/autogen-core/src/autogen_core/_image.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Any, Dict, cast
88

99
from PIL import Image as PILImage
10-
from pydantic import GetCoreSchemaHandler, ValidationInfo
10+
from pydantic import GetCoreSchemaHandler
1111
from pydantic_core import core_schema
1212
from typing_extensions import Literal
1313

@@ -84,25 +84,27 @@ def to_openai_format(self, detail: Literal["auto", "low", "high"] = "auto") -> D
8484

8585
@classmethod
8686
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
87-
# Custom validation
88-
def validate(value: Any, validation_info: ValidationInfo) -> Image:
89-
if isinstance(value, dict):
90-
base_64 = cast(str | None, value.get("data")) # type: ignore
91-
if base_64 is None:
92-
raise ValueError("Expected 'data' key in the dictionary")
93-
return cls.from_base64(base_64)
94-
elif isinstance(value, cls):
95-
return value
96-
else:
97-
raise TypeError(f"Expected dict or {cls.__name__} instance, got {type(value)}")
87+
# Custom validation for dict input (from JSON deserialization)
88+
def validate_from_dict(value: dict[str, Any]) -> Image:
89+
base_64 = cast(str | None, value.get("data"))
90+
if base_64 is None:
91+
raise ValueError("Expected 'data' key in the dictionary")
92+
return cls.from_base64(base_64)
9893

9994
# Custom serialization
10095
def serialize(value: Image) -> dict[str, Any]:
10196
return {"data": value.to_base64()}
10297

103-
return core_schema.with_info_after_validator_function(
104-
validate,
105-
core_schema.any_schema(), # Accept any type; adjust if needed
98+
# Keep Image validation narrow so Union[str, Image] can fall through to
99+
# the string branch instead of failing inside the Image validator.
100+
return core_schema.union_schema(
101+
[
102+
core_schema.is_instance_schema(cls),
103+
core_schema.no_info_after_validator_function(
104+
validate_from_dict,
105+
core_schema.dict_schema(),
106+
),
107+
],
106108
serialization=core_schema.plain_serializer_function_ser_schema(serialize),
107109
)
108110

python/packages/autogen-core/tests/test_serialization.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
SerializationRegistry,
1313
try_get_known_serializers_for_type,
1414
)
15+
from autogen_core.models import UserMessage
1516
from PIL import Image as PILImage
1617
from protos.serialization_test_pb2 import NestingProtoMessage, ProtoMessage
1718
from pydantic import BaseModel
@@ -188,6 +189,24 @@ class PydanticImageMessage(BaseModel):
188189
assert deserialized.image.image == image.image
189190

190191

192+
@pytest.mark.parametrize(
193+
"content",
194+
[
195+
[Image(PILImage.new("RGB", (1, 1), color="red")), "Please describe this image"],
196+
["What is in this image?", Image(PILImage.new("RGB", (1, 1), color="blue"))],
197+
],
198+
)
199+
def test_user_message_round_trips_mixed_text_and_image_content(content: list[str | Image]) -> None:
200+
message = UserMessage(content=content, source="user")
201+
202+
deserialized = UserMessage.model_validate_json(message.model_dump_json())
203+
204+
assert isinstance(deserialized.content, list)
205+
assert len(deserialized.content) == 2
206+
assert any(isinstance(item, Image) for item in deserialized.content)
207+
assert any(isinstance(item, str) for item in deserialized.content)
208+
209+
191210
def test_type_name_for_protos() -> None:
192211
type_name = SerializationRegistry().type_name(ProtoMessage())
193212
assert type_name == "agents.ProtoMessage"

0 commit comments

Comments
 (0)