Skip to content

Commit 637d965

Browse files
fix schema drift for chat-ui (ai-sdk 6) (marimo-team#8105)
## 📝 Summary <!-- Provide a concise summary of what this pull request is addressing. If this PR closes any issues, list them here by number (e.g., Closes marimo-team#123). --> There is a bug with Pydantic-AI where it started returning AI sdk v6 chunks. This has been fixed here pydantic/pydantic-ai#4166 but we need to do some patching to ensure we have the fix. pyproject.toml is only updated for dev & test dependencies. ## 📋 Checklist - [x] I have read the [contributor guidelines](https://github.com/marimo-team/marimo/blob/main/CONTRIBUTING.md). - [ ] For large changes, or changes that affect the public API: this change was discussed or approved through an issue, on [Discord](https://marimo.io/discord?ref=pr), or the community [discussions](https://github.com/marimo-team/marimo/discussions) (Please provide a link if applicable). - [x] Tests have been added for the changes made. - [ ] Documentation has been updated where applicable, including docstrings for API changes. - [x] Pull request title is a good summary of the changes - it will be used in the [release notes](https://github.com/marimo-team/marimo/releases). --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 43f35eb commit 637d965

6 files changed

Lines changed: 136 additions & 23 deletions

File tree

marimo/_ai/llm/_impl.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
from __future__ import annotations
33

44
import dataclasses
5+
import json
56
import os
67
import re
78
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
89

910
from marimo import _loggers
1011
from marimo._ai._pydantic_ai_utils import generate_id
12+
from marimo._plugins.ui._impl.chat.chat import AI_SDK_VERSION, DONE_CHUNK
1113
from marimo._plugins.utils import remove_none_values
1214

1315
if TYPE_CHECKING:
@@ -789,21 +791,37 @@ def _serialize_vercel_ai_chunk(
789791
) -> dict[str, Any] | None:
790792
"""
791793
Serialize vercel ai chunk to a dictionary. Skip "done" chunks - not part of Vercel AI SDK schema.
792-
793-
by_alias=True: Use camelCase keys expected by Vercel AI SDK.
794-
exclude_none=True: Remove null values which cause validation errors.
794+
We use encode as it uses Pydantic-AI's method of serializing dataclasses to JSON.
795795
"""
796796
try:
797-
serialized = chunk.model_dump(
798-
mode="json", by_alias=True, exclude_none=True
799-
)
797+
encoded = chunk.encode(sdk_version=AI_SDK_VERSION)
798+
if encoded == DONE_CHUNK:
799+
return None
800+
result = json.loads(encoded)
801+
if not isinstance(result, dict):
802+
LOGGER.debug(
803+
"Serialized vercel ai chunk is not a dictionary: %s",
804+
result,
805+
)
806+
return result # type: ignore[no-any-return]
807+
except TypeError:
808+
# Fallback for pydantic-ai < 1.52.0 which doesn't have sdk_version param
809+
try:
810+
# by_alias=True: Use camelCase keys expected by Vercel AI SDK.
811+
# exclude_none=True: Remove null values which cause validation errors.
812+
serialized = chunk.model_dump(
813+
mode="json", by_alias=True, exclude_none=True
814+
)
815+
except Exception as e:
816+
LOGGER.error("Error serializing vercel ai chunk: %s", e)
817+
return None
818+
else:
819+
if serialized.get("type") == "done":
820+
return None
821+
return serialized
800822
except Exception as e:
801823
LOGGER.error("Error serializing vercel ai chunk: %s", e)
802824
return None
803-
else:
804-
if serialized.get("type") == "done":
805-
return None
806-
return serialized
807825

808826
async def _stream_response(
809827
self, messages: list[ChatMessage], config: ChatModelConfig

marimo/_plugins/ui/_impl/chat/chat.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
from __future__ import annotations
33

44
import inspect
5+
import json
56
import uuid
67
from dataclasses import dataclass
7-
from typing import Any, Callable, Final, Optional, Union, cast
8+
from typing import Any, Callable, Final, Literal, Optional, Union, cast
89

910
from marimo import _loggers
1011
from marimo._ai._types import (
@@ -34,6 +35,8 @@
3435
presence_penalty=0,
3536
)
3637

38+
# The version of the Vercel AI SDK we use
39+
AI_SDK_VERSION: Final[Literal[5, 6]] = 5
3740
DONE_CHUNK: Final[str] = "[DONE]"
3841

3942

@@ -464,13 +467,16 @@ def handle_chunk(self, chunk: Any) -> None:
464467
)
465468

466469
if isinstance(chunk, BaseChunk):
467-
# by_alias=True: Use camelCase keys expected by Vercel AI SDK.
468-
# exclude_none=True: Remove null values which cause validation errors.
469-
self.on_send_chunk(
470-
chunk.model_dump(
470+
try:
471+
serialized = json.loads(
472+
chunk.encode(sdk_version=AI_SDK_VERSION)
473+
)
474+
except TypeError:
475+
# Fallback for pydantic-ai < 1.52.0 which doesn't have sdk_version param
476+
serialized = chunk.model_dump(
471477
mode="json", by_alias=True, exclude_none=True
472478
)
473-
)
479+
self.on_send_chunk(serialized)
474480
return
475481

476482
# Handle plain text chunks

marimo/_server/ai/providers.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
generate_id,
2626
)
2727
from marimo._dependencies.dependencies import Dependency, DependencyManager
28+
from marimo._plugins.ui._impl.chat.chat import AI_SDK_VERSION
2829
from marimo._server.ai.config import AnyProviderConfig
2930
from marimo._server.ai.ids import AiModelId
3031
from marimo._server.ai.tools.tool_manager import get_tool_manager
@@ -166,9 +167,17 @@ async def stream_completion(
166167
stream_options = stream_options or StreamOptions()
167168

168169
vercel_adapter = self.get_vercel_adapter()
169-
adapter = vercel_adapter(
170-
agent=agent, run_input=run_input, accept=stream_options.accept
171-
)
170+
if DependencyManager.pydantic_ai.has_at_version(min_version="1.52.0"):
171+
adapter = vercel_adapter(
172+
agent=agent,
173+
run_input=run_input,
174+
accept=stream_options.accept,
175+
sdk_version=AI_SDK_VERSION,
176+
)
177+
else:
178+
adapter = vercel_adapter(
179+
agent=agent, run_input=run_input, accept=stream_options.accept
180+
)
172181
event_stream = adapter.run_stream()
173182
return adapter.streaming_response(event_stream)
174183

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ dev = [
129129
# For linting
130130
"ruff>=0.13.2",
131131
# For AI
132-
"pydantic-ai-slim[openai]>=1.47.0",
132+
"pydantic-ai-slim[openai]>=1.52.0",
133133
]
134134

135135
test = [
@@ -186,7 +186,7 @@ test-optional = [
186186
"anywidget~=0.9.18",
187187
"ipython~=8.12.3",
188188
# testing gen ai
189-
"pydantic-ai-slim[google,anthropic,bedrock,openai]>=1.47.0",
189+
"pydantic-ai-slim[google,anthropic,bedrock,openai]>=1.52.0",
190190
# - google-auth uses cachetools, and cachetools<5.0.0 uses collections.MutableMapping (removed in Python 3.10)
191191
"cachetools>=5.0.0",
192192
"boto3>=1.38.46",
@@ -246,7 +246,7 @@ dependencies = [
246246
"matplotlib>=3.8.0",
247247
"sqlglot[rs]>=26.2.0",
248248
"sqlalchemy>=2.0.40",
249-
"pydantic-ai-slim[google,anthropic,bedrock,openai]>=1.47.0",
249+
"pydantic-ai-slim[google,anthropic,bedrock,openai]>=1.52.0",
250250
"loro>=1.5.0",
251251
"pandas-stubs>=1.5.3.230321",
252252
"pyiceberg>=0.9.0",

tests/_ai/llm/test_impl.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1591,6 +1591,42 @@ def test_pydantic_ai_serialize_vercel_ai_chunk(self) -> None:
15911591
"input": {"query": "test"},
15921592
}
15931593

1594+
def test_pydantic_ai_serialize_vercel_ai_chunk_v5(self) -> None:
1595+
"""Test that tool-input-start chunks exclude providerMetadata for SDK v5.
1596+
1597+
The Vercel AI SDK v5 schema drifts from v6, so we need to use Pydantic's handling.
1598+
1599+
For tool-input-start chunks, providerMetadata must be excluded.
1600+
See: https://github.com/pydantic/pydantic-ai/pull/4166
1601+
"""
1602+
from pydantic_ai.ui.vercel_ai.response_types import ToolInputStartChunk
1603+
1604+
mock_agent = MagicMock()
1605+
model = pydantic_ai(mock_agent)
1606+
1607+
# Create chunk with providerMetadata (like Google Gemini produces)
1608+
chunk = ToolInputStartChunk(
1609+
tool_call_id="tc_1",
1610+
tool_name="my_tool",
1611+
provider_metadata={
1612+
"pydantic_ai": {
1613+
"id": "test_id",
1614+
"provider_name": "google-gla",
1615+
"provider_details": {
1616+
"thought_signature": "encrypted_data"
1617+
},
1618+
}
1619+
},
1620+
)
1621+
result = model._serialize_vercel_ai_chunk(chunk)
1622+
1623+
# providerMetadata should be excluded for SDK v5 compatibility
1624+
assert result == {
1625+
"type": "tool-input-start",
1626+
"toolCallId": "tc_1",
1627+
"toolName": "my_tool",
1628+
}
1629+
15941630
def test_pydantic_ai_serialize_vercel_ai_chunk_done_type(self) -> None:
15951631
"""Test that 'done' type chunks are skipped."""
15961632
from pydantic_ai.ui.vercel_ai.response_types import DoneChunk

tests/_plugins/ui/_impl/chat/test_chat.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -953,7 +953,6 @@ def on_send_chunk(chunk: dict):
953953

954954
serializer = ChunkSerializer(on_send_chunk=on_send_chunk)
955955

956-
# Pydantic BaseChunk should be serialized with model_dump
957956
chunk = TextDeltaChunk(id="text-1", delta="Hello")
958957
serializer.handle_chunk(chunk)
959958

@@ -963,6 +962,51 @@ def on_send_chunk(chunk: dict):
963962
]
964963

965964

965+
@pytest.mark.skipif(
966+
not DependencyManager.pydantic_ai.has(),
967+
reason="Pydantic AI is not installed",
968+
)
969+
def test_serialize_pydantic_v5():
970+
"""Test ChunkSerializer excludes providerMetadata from tool-input-start for SDK v5.
971+
972+
The Vercel AI SDK v5 schema drifts from v6, so we need to use Pydantic's handling.
973+
974+
Since pydantic-ai uses toolCallId, providerMetadata must be excluded.
975+
See: https://github.com/pydantic/pydantic-ai/pull/4166
976+
"""
977+
from pydantic_ai.ui.vercel_ai.response_types import ToolInputStartChunk
978+
979+
sent_chunks: list[dict] = []
980+
981+
def on_send_chunk(chunk: dict):
982+
sent_chunks.append(chunk)
983+
984+
serializer = ChunkSerializer(on_send_chunk=on_send_chunk)
985+
986+
# Create chunk with providerMetadata (like Google Gemini produces)
987+
chunk = ToolInputStartChunk(
988+
tool_call_id="tc_1",
989+
tool_name="my_tool",
990+
provider_metadata={
991+
"pydantic_ai": {
992+
"id": "test_id",
993+
"provider_name": "google-gla",
994+
"provider_details": {"thought_signature": "encrypted_data"},
995+
}
996+
},
997+
)
998+
serializer.handle_chunk(chunk)
999+
1000+
# providerMetadata should be excluded for SDK v5 compatibility
1001+
assert sent_chunks == [
1002+
{
1003+
"type": "tool-input-start",
1004+
"toolCallId": "tc_1",
1005+
"toolName": "my_tool",
1006+
}
1007+
]
1008+
1009+
9661010
@pytest.mark.skipif(
9671011
not DependencyManager.pydantic_ai.has(),
9681012
reason="Pydantic AI is not installed",

0 commit comments

Comments
 (0)