Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions python/packages/autogen-core/src/autogen_core/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Dict, cast

from PIL import Image as PILImage
from pydantic import GetCoreSchemaHandler, ValidationInfo
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
from typing_extensions import Literal

Expand Down Expand Up @@ -84,25 +84,27 @@ def to_openai_format(self, detail: Literal["auto", "low", "high"] = "auto") -> D

@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
# Custom validation
def validate(value: Any, validation_info: ValidationInfo) -> Image:
if isinstance(value, dict):
base_64 = cast(str | None, value.get("data")) # type: ignore
if base_64 is None:
raise ValueError("Expected 'data' key in the dictionary")
return cls.from_base64(base_64)
elif isinstance(value, cls):
return value
else:
raise TypeError(f"Expected dict or {cls.__name__} instance, got {type(value)}")
# Custom validation for dict input (from JSON deserialization)
def validate_from_dict(value: dict[str, Any]) -> Image:
base_64 = cast(str | None, value.get("data"))
if base_64 is None:
raise ValueError("Expected 'data' key in the dictionary")
return cls.from_base64(base_64)

# Custom serialization
def serialize(value: Image) -> dict[str, Any]:
return {"data": value.to_base64()}

return core_schema.with_info_after_validator_function(
validate,
core_schema.any_schema(), # Accept any type; adjust if needed
# Keep Image validation narrow so Union[str, Image] can fall through to
# the string branch instead of failing inside the Image validator.
return core_schema.union_schema(
[
core_schema.is_instance_schema(cls),
core_schema.no_info_after_validator_function(
validate_from_dict,
core_schema.dict_schema(),
),
],
serialization=core_schema.plain_serializer_function_ser_schema(serialize),
)

Expand Down
19 changes: 19 additions & 0 deletions python/packages/autogen-core/tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
SerializationRegistry,
try_get_known_serializers_for_type,
)
from autogen_core.models import UserMessage
from PIL import Image as PILImage
from protos.serialization_test_pb2 import NestingProtoMessage, ProtoMessage
from pydantic import BaseModel
Expand Down Expand Up @@ -188,6 +189,24 @@ class PydanticImageMessage(BaseModel):
assert deserialized.image.image == image.image


@pytest.mark.parametrize(
"content",
[
[Image(PILImage.new("RGB", (1, 1), color="red")), "Please describe this image"],
["What is in this image?", Image(PILImage.new("RGB", (1, 1), color="blue"))],
],
)
def test_user_message_round_trips_mixed_text_and_image_content(content: list[str | Image]) -> None:
message = UserMessage(content=content, source="user")

deserialized = UserMessage.model_validate_json(message.model_dump_json())

assert isinstance(deserialized.content, list)
assert len(deserialized.content) == 2
assert any(isinstance(item, Image) for item in deserialized.content)
assert any(isinstance(item, str) for item in deserialized.content)


def test_type_name_for_protos() -> None:
type_name = SerializationRegistry().type_name(ProtoMessage())
assert type_name == "agents.ProtoMessage"
Expand Down