|
2 | 2 | from __future__ import annotations |
3 | 3 |
|
4 | 4 | import dataclasses |
| 5 | +import json |
5 | 6 | import os |
6 | 7 | import re |
7 | 8 | from typing import TYPE_CHECKING, Any, Callable, Optional, cast |
8 | 9 |
|
9 | 10 | from marimo import _loggers |
10 | 11 | from marimo._ai._pydantic_ai_utils import generate_id |
| 12 | +from marimo._plugins.ui._impl.chat.chat import AI_SDK_VERSION, DONE_CHUNK |
11 | 13 | from marimo._plugins.utils import remove_none_values |
12 | 14 |
|
13 | 15 | if TYPE_CHECKING: |
@@ -789,21 +791,37 @@ def _serialize_vercel_ai_chunk( |
789 | 791 | ) -> dict[str, Any] | None: |
790 | 792 | """ |
791 | 793 | Serialize vercel ai chunk to a dictionary. Skip "done" chunks - not part of Vercel AI SDK schema. |
792 | | -
|
793 | | - by_alias=True: Use camelCase keys expected by Vercel AI SDK. |
794 | | - exclude_none=True: Remove null values which cause validation errors. |
| 794 | + We use encode as it uses Pydantic-AI's method of serializing dataclasses to JSON. |
795 | 795 | """ |
796 | 796 | try: |
797 | | - serialized = chunk.model_dump( |
798 | | - mode="json", by_alias=True, exclude_none=True |
799 | | - ) |
| 797 | + encoded = chunk.encode(sdk_version=AI_SDK_VERSION) |
| 798 | + if encoded == DONE_CHUNK: |
| 799 | + return None |
| 800 | + result = json.loads(encoded) |
| 801 | + if not isinstance(result, dict): |
| 802 | + LOGGER.debug( |
| 803 | + "Serialized vercel ai chunk is not a dictionary: %s", |
| 804 | + result, |
| 805 | + ) |
| 806 | + return result # type: ignore[no-any-return] |
| 807 | + except TypeError: |
| 808 | + # Fallback for pydantic-ai < 1.52.0 which doesn't have sdk_version param |
| 809 | + try: |
| 810 | + # by_alias=True: Use camelCase keys expected by Vercel AI SDK. |
| 811 | + # exclude_none=True: Remove null values which cause validation errors. |
| 812 | + serialized = chunk.model_dump( |
| 813 | + mode="json", by_alias=True, exclude_none=True |
| 814 | + ) |
| 815 | + except Exception as e: |
| 816 | + LOGGER.error("Error serializing vercel ai chunk: %s", e) |
| 817 | + return None |
| 818 | + else: |
| 819 | + if serialized.get("type") == "done": |
| 820 | + return None |
| 821 | + return serialized |
800 | 822 | except Exception as e: |
801 | 823 | LOGGER.error("Error serializing vercel ai chunk: %s", e) |
802 | 824 | return None |
803 | | - else: |
804 | | - if serialized.get("type") == "done": |
805 | | - return None |
806 | | - return serialized |
807 | 825 |
|
808 | 826 | async def _stream_response( |
809 | 827 | self, messages: list[ChatMessage], config: ChatModelConfig |
|
0 commit comments