Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion examples/ai/chat/pydantic-ai-chat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
# /// script
# requires-python = ">=3.13"
# dependencies = [
# "httpx==0.28.1",
# "marimo>=0.21.1",
# "pydantic==2.12.5",
# ]
# ///
import marimo

__generated_with = "0.19.2"
__generated_with = "0.21.1"
app = marimo.App(width="medium")

with app.setup(hide_code=True):
Expand Down
4 changes: 2 additions & 2 deletions frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"./unstable_internal/*": "./src/*"
},
"dependencies": {
"@ai-sdk/react": "^2.0.125",
"@ai-sdk/react": "^3.0.131",
"@anywidget/types": "^0.2.0",
Comment on lines 23 to 25
Copy link

Copilot AI Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR checklist/description states that tests were added, but this PR only shows dependency/version updates and type fixes (no test file changes). If no tests were actually added, please update the checklist/description; otherwise, point to the new/updated tests.

Copilot uses AI. Check for mistakes.
"@codemirror/autocomplete": "^6.20.1",
"@codemirror/commands": "^6.10.2",
Expand Down Expand Up @@ -114,7 +114,7 @@
"@xterm/addon-web-links": "^0.12.0",
"@xterm/xterm": "^5.5.0",
"@zed-industries/agent-client-protocol": "^0.4.5",
"ai": "^5.0.123",
"ai": "^6.0.129",
"ansi_up": "^6.0.6",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
Expand Down
15 changes: 7 additions & 8 deletions frontend/src/components/chat/chat-utils.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
/* Copyright 2026 Marimo. All rights reserved. */

import type { components } from "@marimo-team/marimo-api";
import type { FileUIPart, ToolUIPart, UIMessage } from "ai";
import type {
ChatAddToolOutputFunction,
FileUIPart,
ToolUIPart,
UIMessage,
} from "ai";
import { useState } from "react";
import useEvent from "react-use-event-hook";
import type { ProviderId } from "@/core/ai/ids/ids";
Expand Down Expand Up @@ -111,20 +116,14 @@ export async function buildCompletionRequestBody(
};
}

interface AddToolOutput {
tool: string;
toolCallId: string;
output: unknown;
}

export async function handleToolCall({
invokeAiTool,
addToolOutput, // Important that we don't await addToolOutput to prevent potential deadlocks
toolCall,
toolContext,
}: {
invokeAiTool: (request: InvokeAiToolRequest) => Promise<InvokeAiToolResponse>;
addToolOutput: (output: AddToolOutput) => Promise<void>;
addToolOutput: ChatAddToolOutputFunction<UIMessage>;
toolCall: {
toolName: string;
toolCallId: string;
Expand Down
6 changes: 6 additions & 0 deletions frontend/src/core/ai/staged-cells.ts
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,12 @@ export function useStagedCells(store: JotaiStore) {
case "tool-output-error":
Logger.error("Error", chunk.type, { chunk });
break;
case "tool-approval-request":
Logger.log("Tool approval request", { chunk });
break;
case "tool-output-denied":
Logger.error("Tool output denied", { chunk });
break;
// These logs are not useful for debugging
case "start":
case "start-step":
Expand Down
27 changes: 14 additions & 13 deletions marimo/_plugins/ui/_impl/chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,17 @@
)

# The version of the Vercel AI SDK we use
AI_SDK_VERSION: Final[Literal[5, 6]] = 5
AI_SDK_VERSION: Final[Literal[5, 6]] = 6
DONE_CHUNK: Final[str] = "[DONE]"


def require_vercel_ai_sdk_support() -> None:
"""Only Pydantic AI >=1.52.0 supports AI SDK v6. So, we require it."""
DependencyManager.pydantic_ai.require_at_version(
why="for Vercel AI SDK support", min_version="1.52.0"
)


@dataclass
class SendMessageRequest:
messages: list[ChatMessage]
Expand Down Expand Up @@ -413,9 +420,8 @@ def _convert_value(self, value: dict[str, Any]) -> list[ChatMessage]:

part_validator_class = None
if DependencyManager.pydantic_ai.imported():
from pydantic_ai.ui.vercel_ai.request_types import (
UIMessagePart,
)
require_vercel_ai_sdk_support()
from pydantic_ai.ui.vercel_ai.request_types import UIMessagePart

# The frontend sends messages as ChatMessage parts so we use pydantic-ai to cast them
# as Vercel UIMessagePart
Expand Down Expand Up @@ -466,20 +472,15 @@ def handle_chunk(self, chunk: Any) -> None:

# Handle Pydantic AI's Vercel AI SDK chunks
if DependencyManager.pydantic_ai.imported():
require_vercel_ai_sdk_support()
from pydantic_ai.ui.vercel_ai.response_types import (
BaseChunk,
)

if isinstance(chunk, BaseChunk):
try:
serialized = json.loads(
chunk.encode(sdk_version=AI_SDK_VERSION)
)
except TypeError:
# Fallback for pydantic-ai < 1.52.0 which doesn't have sdk_version param
serialized = chunk.model_dump(
mode="json", by_alias=True, exclude_none=True
)
serialized = json.loads(
chunk.encode(sdk_version=AI_SDK_VERSION)
)
self.on_send_chunk(serialized)
return

Expand Down
105 changes: 14 additions & 91 deletions marimo/_server/ai/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
generate_id,
)
from marimo._dependencies.dependencies import Dependency, DependencyManager
from marimo._plugins.ui._impl.chat.chat import AI_SDK_VERSION
from marimo._plugins.ui._impl.chat.chat import (
AI_SDK_VERSION,
require_vercel_ai_sdk_support,
)
from marimo._server.ai.config import AnyProviderConfig
from marimo._server.ai.ids import AiModelId
from marimo._server.ai.tools.tool_manager import get_tool_manager
Expand All @@ -40,7 +43,6 @@
from openai import AsyncOpenAI
from openai.types.shared.reasoning_effort import ReasoningEffort
from pydantic_ai import Agent, DeferredToolRequests, FunctionToolset
from pydantic_ai.messages import ThinkingPart
from pydantic_ai.models import Model
from pydantic_ai.models.bedrock import BedrockConverseModel
from pydantic_ai.models.google import GoogleModel
Expand All @@ -54,9 +56,7 @@
)
from pydantic_ai.providers.google import GoogleProvider as PydanticGoogle
from pydantic_ai.providers.openai import OpenAIProvider as PydanticOpenAI
from pydantic_ai.ui.vercel_ai import VercelAIAdapter
from pydantic_ai.ui.vercel_ai.request_types import UIMessage, UIMessagePart
from pydantic_ai.ui.vercel_ai.response_types import BaseChunk
from starlette.responses import StreamingResponse


Expand Down Expand Up @@ -101,6 +101,7 @@ def __init__(
*(deps or []),
source="server",
)
require_vercel_ai_sdk_support()

self.model = model
self.config = config
Expand Down Expand Up @@ -132,12 +133,6 @@ def create_agent(
output_type=output_type,
)

def get_vercel_adapter(self) -> type[VercelAIAdapter[Any, Any]]:
"""Return the Vercel AI adapter for the given provider."""
from pydantic_ai.ui.vercel_ai import VercelAIAdapter

return VercelAIAdapter

def convert_messages(
self, messages: list[ServerUIMessage]
) -> list[UIMessage]:
Expand All @@ -153,6 +148,7 @@ async def stream_completion(
stream_options: Optional[StreamOptions] = None,
) -> StreamingResponse:
"""Return a streaming response from the given messages. The response are AI SDK events."""
from pydantic_ai.ui.vercel_ai import VercelAIAdapter
from pydantic_ai.ui.vercel_ai.request_types import SubmitMessage

tools = (self.config.tools or []) + additional_tools
Expand All @@ -169,18 +165,12 @@ async def stream_completion(
# TODO: Text only and format stream are not supported yet
stream_options = stream_options or StreamOptions()

vercel_adapter = self.get_vercel_adapter()
if DependencyManager.pydantic_ai.has_at_version(min_version="1.52.0"):
adapter = vercel_adapter(
agent=agent,
run_input=run_input,
accept=stream_options.accept,
sdk_version=AI_SDK_VERSION,
)
else:
adapter = vercel_adapter(
agent=agent, run_input=run_input, accept=stream_options.accept
)
adapter = VercelAIAdapter(
agent=agent,
run_input=run_input,
accept=stream_options.accept,
sdk_version=AI_SDK_VERSION,
)
event_stream = adapter.run_stream()
return adapter.streaming_response(event_stream)

Expand All @@ -193,16 +183,16 @@ async def stream_text(
additional_tools: list[ToolDefinition],
) -> AsyncGenerator[str]:
"""Return a stream of text from the given messages."""
from pydantic_ai.ui.vercel_ai import VercelAIAdapter

tools = (self.config.tools or []) + additional_tools
agent = self.create_agent(
max_tokens=max_tokens, tools=tools, system_prompt=system_prompt
)
vercel_adapter = self.get_vercel_adapter()

async with agent.run_stream(
user_prompt=user_prompt,
message_history=vercel_adapter.load_messages(
message_history=VercelAIAdapter.load_messages(
self.convert_messages(messages)
),
) as result:
Expand Down Expand Up @@ -800,73 +790,6 @@ def process_part(self, part: UIMessagePart) -> UIMessagePart:
)
return part

def get_vercel_adapter(
self,
) -> type[VercelAIAdapter[None, DeferredToolRequests | str]]:
"""
Return a custom adapter that includes thinking signatures in ReasoningEndChunk.

pydantic_ai's VercelAIEventStream.handle_thinking_end doesn't pass the signature
from ThinkingPart to ReasoningEndChunk, which breaks Anthropic's extended thinking
on follow-up messages (Anthropic requires signatures on thinking blocks).

This is a patch for pydantic-ai <1.47.0, which doesn't include the signature in the ReasoningEndChunk.
"""
if DependencyManager.pydantic_ai.has_at_version(min_version="1.47.0"):
return super().get_vercel_adapter()

from pydantic_ai import DeferredToolRequests
from pydantic_ai.ui.vercel_ai import VercelAIAdapter
from pydantic_ai.ui.vercel_ai._event_stream import VercelAIEventStream
from pydantic_ai.ui.vercel_ai.response_types import ReasoningEndChunk

AnthropicOutputType = DeferredToolRequests | str

# Custom event stream that includes signature in ReasoningEndChunk
class AnthropicVercelAIEventStream(
VercelAIEventStream[None, AnthropicOutputType]
):
async def handle_thinking_end(
self, part: ThinkingPart, followed_by_thinking: bool = False
) -> AsyncIterator[BaseChunk]:
"""Override to include signature in provider_metadata."""
try:
provider_metadata = None
if part.signature:
pydantic_ai_meta: dict[str, Any] = {
"signature": part.signature
}
if part.provider_name:
pydantic_ai_meta["provider_name"] = (
part.provider_name
)
if part.id:
pydantic_ai_meta["id"] = part.id
provider_metadata = {"pydantic_ai": pydantic_ai_meta}

yield ReasoningEndChunk(
id=self.message_id, provider_metadata=provider_metadata
)
except Exception as e:
LOGGER.warning(
f"Error in AnthropicVercelAIEventStream.handle_thinking_end: {e}"
)
async for chunk in super().handle_thinking_end(
part, followed_by_thinking
):
yield chunk

# Custom adapter that uses the custom event stream
class AnthropicVercelAIAdapter(
VercelAIAdapter[None, AnthropicOutputType]
):
def build_event_stream(self) -> AnthropicVercelAIEventStream:
return AnthropicVercelAIEventStream(
self.run_input, accept=self.accept
)

return AnthropicVercelAIAdapter


class BedrockProvider(PydanticProvider["PydanticBedrock"]):
def setup_credentials(self, config: AnyProviderConfig) -> None:
Expand Down
Loading
Loading