|
| 1 | +from functools import wraps |
| 2 | +from contextlib import asynccontextmanager |
| 3 | +from contextvars import ContextVar |
1 | 4 | from sentry_sdk.integrations import DidNotEnable, Integration |
2 | | - |
| 5 | +from sentry_sdk.consts import SPANDATA |
3 | 6 |
|
4 | 7 | try: |
5 | 8 | import pydantic_ai # type: ignore # noqa: F401 |
| 9 | + from pydantic_ai.capabilities.combined import CombinedCapability # type: ignore |
| 10 | + from pydantic_ai._agent_graph import ModelRequestNode |
6 | 11 | except ImportError: |
7 | 12 | raise DidNotEnable("pydantic-ai not installed") |
8 | 13 |
|
9 | 14 |
|
10 | 15 | from .patches import ( |
11 | 16 | _patch_agent_run, |
12 | | - _patch_graph_nodes, |
13 | 17 | _patch_tool_execution, |
14 | 18 | ) |
| 19 | +from .spans import ( |
| 20 | + ai_client_span, |
| 21 | + update_ai_client_span, |
| 22 | +) |
| 23 | + |
| 24 | +from typing import TYPE_CHECKING |
| 25 | + |
| 26 | +if TYPE_CHECKING: |
| 27 | + from typing import Any, Awaitable, Callable |
| 28 | + |
| 29 | + from pydantic_ai._run_context import RunContext |
| 30 | + from pydantic_ai.models import ModelRequestContext |
| 31 | + from pydantic_ai.messages import ModelResponse |
| 32 | + |
| 33 | + |
# Tracks whether the model request currently being executed belongs to a
# streaming run. The ModelRequestNode.run/.stream patches in this module
# set/reset it so the ai.client span can record GEN_AI_RESPONSE_STREAMING.
# A ContextVar is used so concurrent agent runs don't clobber each other.
_is_streaming: ContextVar[bool] = ContextVar(
    "sentry_pydantic_ai_is_streaming", default=False
)
| 37 | + |
| 38 | + |
def _patch_wrap_model_request():
    """Monkey-patch ``CombinedCapability.wrap_model_request``.

    Every model request is wrapped in an ai.client span. The span is tagged
    with the current streaming state (taken from the ``_is_streaming``
    context variable) before the request runs, and updated with the model
    response afterwards.
    """
    original = CombinedCapability.wrap_model_request

    @wraps(original)
    async def sentry_patched_wrap_model_request(
        self,
        ctx: "RunContext[Any]",
        *,
        request_context: "ModelRequestContext",
        handler: "Callable[[ModelRequestContext], Awaitable[ModelResponse]]",
    ) -> "Any":
        span_args = (
            request_context.messages,
            None,
            request_context.model,
            request_context.model_settings,
        )
        with ai_client_span(*span_args) as span:
            # Record streaming mode up front; the flag is managed by the
            # ModelRequestNode patches elsewhere in this module.
            span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, _is_streaming.get())
            response = await original(
                self, ctx, request_context=request_context, handler=handler
            )
            update_ai_client_span(span, response)
            return response

    CombinedCapability.wrap_model_request = sentry_patched_wrap_model_request
| 66 | + |
| 67 | + |
def _patch_model_request_node_run():
    """Monkey-patch ``ModelRequestNode.run``.

    Marks the request as non-streaming for the duration of the call by
    setting the ``_is_streaming`` context variable to ``False``, restoring
    the previous value on exit (including on error).
    """
    original_run = ModelRequestNode.run

    @wraps(original_run)
    async def sentry_patched_run(self: "Any", ctx: "Any") -> "Any":
        reset_token = _is_streaming.set(False)
        try:
            return await original_run(self, ctx)
        finally:
            _is_streaming.reset(reset_token)

    ModelRequestNode.run = sentry_patched_run
| 80 | + |
| 81 | + |
def _patch_model_request_node_stream():
    """Monkey-patch ``ModelRequestNode.stream``.

    Flags the request as streaming (via the ``_is_streaming`` context
    variable) while the stream context manager is open, restoring the
    previous value when it closes.
    """
    # The original is captured directly in this closure; no extra factory
    # function is needed to bind it.
    original_stream = ModelRequestNode.stream

    @asynccontextmanager
    @wraps(original_stream)
    async def sentry_patched_stream(self: "Any", ctx: "Any") -> "Any":
        reset_token = _is_streaming.set(True)
        try:
            async with original_stream(self, ctx) as stream:
                yield stream
        finally:
            _is_streaming.reset(reset_token)

    ModelRequestNode.stream = sentry_patched_stream
15 | 101 |
|
16 | 102 |
|
17 | 103 | class PydanticAIIntegration(Integration): |
@@ -44,5 +130,8 @@ def setup_once() -> None: |
44 | 130 | - Tool executions |
45 | 131 | """ |
46 | 132 | _patch_agent_run() |
47 | | - _patch_graph_nodes() |
| 133 | + _patch_wrap_model_request() |
48 | 134 | _patch_tool_execution() |
| 135 | + |
| 136 | + _patch_model_request_node_run() |
| 137 | + _patch_model_request_node_stream() |
0 commit comments