Skip to content

Commit 9eb5239

Browse files
merge
2 parents 051812d + ba7b2d4 commit 9eb5239

File tree

3 files changed

+124
-92
lines changed

3 files changed

+124
-92
lines changed
Lines changed: 2 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,103 +1,16 @@
1-
from functools import wraps
2-
from contextlib import asynccontextmanager
3-
from contextvars import ContextVar
41
from sentry_sdk.integrations import DidNotEnable, Integration
5-
from sentry_sdk.consts import SPANDATA
62

73
try:
84
import pydantic_ai # type: ignore # noqa: F401
9-
from pydantic_ai.capabilities.combined import CombinedCapability # type: ignore
10-
from pydantic_ai._agent_graph import ModelRequestNode
115
except ImportError:
126
raise DidNotEnable("pydantic-ai not installed")
137

148

159
from .patches import (
1610
_patch_agent_run,
11+
_patch_graph_nodes,
1712
_patch_tool_execution,
1813
)
19-
from .spans import (
20-
ai_client_span,
21-
update_ai_client_span,
22-
)
23-
24-
from typing import TYPE_CHECKING
25-
26-
if TYPE_CHECKING:
27-
from typing import Any, Awaitable, Callable
28-
29-
from pydantic_ai._run_context import RunContext
30-
from pydantic_ai.models import ModelRequestContext
31-
from pydantic_ai.messages import ModelResponse
32-
33-
34-
_is_streaming: ContextVar[bool] = ContextVar(
35-
"sentry_pydantic_ai_is_streaming", default=False
36-
)
37-
38-
39-
def _patch_wrap_model_request():
40-
original_wrap_model_request = CombinedCapability.wrap_model_request
41-
42-
@wraps(original_wrap_model_request)
43-
async def wrapped_wrap_model_request(
44-
self,
45-
ctx: "RunContext[Any]",
46-
*,
47-
request_context: "ModelRequestContext",
48-
handler: "Callable[[ModelRequestContext], Awaitable[ModelResponse]]",
49-
) -> "Any":
50-
with ai_client_span(
51-
request_context.messages,
52-
None,
53-
request_context.model,
54-
request_context.model_settings,
55-
) as span:
56-
span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, _is_streaming.get())
57-
58-
result = await original_wrap_model_request(
59-
self, ctx, request_context=request_context, handler=handler
60-
)
61-
62-
update_ai_client_span(span, result)
63-
return result
64-
65-
CombinedCapability.wrap_model_request = wrapped_wrap_model_request
66-
67-
68-
def _patch_model_request_node_run():
69-
original_model_request_run = ModelRequestNode.run
70-
71-
@wraps(original_model_request_run)
72-
async def wrapped_model_request_run(self: "Any", ctx: "Any") -> "Any":
73-
token = _is_streaming.set(False)
74-
try:
75-
return await original_model_request_run(self, ctx)
76-
finally:
77-
_is_streaming.reset(token)
78-
79-
ModelRequestNode.run = wrapped_model_request_run
80-
81-
82-
def _patch_model_request_node_stream():
83-
original_model_request_stream = ModelRequestNode.stream
84-
85-
def create_wrapped_stream(
86-
original_stream_method: "Callable[..., Any]",
87-
) -> "Callable[..., Any]":
88-
@asynccontextmanager
89-
@wraps(original_stream_method)
90-
async def wrapped_model_request_stream(self: "Any", ctx: "Any") -> "Any":
91-
token = _is_streaming.set(True)
92-
try:
93-
async with original_stream_method(self, ctx) as stream:
94-
yield stream
95-
finally:
96-
_is_streaming.reset(token)
97-
98-
return wrapped_model_request_stream
99-
100-
ModelRequestNode.stream = create_wrapped_stream(original_model_request_stream)
10114

10215

10316
class PydanticAIIntegration(Integration):
@@ -130,8 +43,5 @@ def setup_once() -> None:
13043
- Tool executions
13144
"""
13245
_patch_agent_run()
133-
_patch_wrap_model_request()
46+
_patch_graph_nodes()
13447
_patch_tool_execution()
135-
136-
_patch_model_request_node_run()
137-
_patch_model_request_node_stream()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .agent_run import _patch_agent_run # noqa: F401
2+
from .graph_nodes import _patch_graph_nodes # noqa: F401
23
from .tools import _patch_tool_execution # noqa: F401
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from contextlib import asynccontextmanager
2+
from functools import wraps
3+
4+
from sentry_sdk.integrations import DidNotEnable
5+
from sentry_sdk.consts import SPANDATA
6+
7+
from ..spans import (
8+
ai_client_span,
9+
update_ai_client_span,
10+
)
11+
12+
try:
13+
from pydantic_ai._agent_graph import ModelRequestNode # type: ignore
14+
except ImportError:
15+
raise DidNotEnable("pydantic-ai not installed")
16+
17+
from typing import TYPE_CHECKING
18+
19+
if TYPE_CHECKING:
20+
from typing import Any, Callable
21+
22+
23+
def _extract_span_data(node: "Any", ctx: "Any") -> "tuple[list[Any], Any, Any]":
24+
"""Extract common data needed for creating chat spans.
25+
26+
Returns:
27+
Tuple of (messages, model, model_settings)
28+
"""
29+
# Extract model and settings from context
30+
model = None
31+
model_settings = None
32+
if hasattr(ctx, "deps"):
33+
model = getattr(ctx.deps, "model", None)
34+
model_settings = getattr(ctx.deps, "model_settings", None)
35+
36+
# Build full message list: history + current request
37+
messages = []
38+
if hasattr(ctx, "state") and hasattr(ctx.state, "message_history"):
39+
messages.extend(ctx.state.message_history)
40+
41+
current_request = getattr(node, "request", None)
42+
if current_request:
43+
messages.append(current_request)
44+
45+
return messages, model, model_settings
46+
47+
48+
def _patch_graph_nodes() -> None:
    """
    Patches the graph node execution to create appropriate spans.

    ModelRequestNode -> Creates ai_client span for model requests
    CallToolsNode -> Handles tool calls (spans created in tool patching)
    """

    # Patch ModelRequestNode.run to create ai_client spans for
    # non-streaming model requests.
    original_model_request_run = ModelRequestNode.run

    @wraps(original_model_request_run)
    async def wrapped_model_request_run(self: "Any", ctx: "Any") -> "Any":
        # If the node already streamed (the stream wrapper created the span)
        # or already has a cached result, delegate without a duplicate span.
        did_stream = getattr(self, "_did_stream", None)
        cached_result = getattr(self, "_result", None)
        if did_stream or cached_result is not None:
            return await original_model_request_run(self, ctx)

        messages, model, model_settings = _extract_span_data(self, ctx)

        with ai_client_span(messages, None, model, model_settings) as span:
            span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, False)

            result = await original_model_request_run(self, ctx)

            # Extract response from the returned node if available.
            model_response = getattr(result, "model_response", None)

            update_ai_client_span(span, model_response)
            return result

    ModelRequestNode.run = wrapped_model_request_run

    # Patch ModelRequestNode.stream for streaming requests
    original_model_request_stream = ModelRequestNode.stream

    def create_wrapped_stream(
        original_stream_method: "Callable[..., Any]",
    ) -> "Callable[..., Any]":
        """Create a wrapper for ModelRequestNode.stream that creates chat spans."""

        @asynccontextmanager
        @wraps(original_stream_method)
        async def wrapped_model_request_stream(self: "Any", ctx: "Any") -> "Any":
            did_stream = getattr(self, "_did_stream", None)
            if did_stream:
                # A stream already ran for this node; just delegate.
                async with original_stream_method(self, ctx) as stream:
                    yield stream
                # BUG FIX: without this return, execution fell through after
                # the context exited, opened a second stream, and yielded a
                # second time — @asynccontextmanager allows exactly one yield,
                # so __aexit__ raised RuntimeError("generator didn't stop").
                return

            messages, model, model_settings = _extract_span_data(self, ctx)

            # Create chat span for streaming request
            with ai_client_span(messages, None, model, model_settings) as span:
                span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, True)

                # Call the original stream method
                async with original_stream_method(self, ctx) as stream:
                    yield stream

                # After streaming completes, update span with response data.
                # The ModelRequestNode stores the final response in _result
                # (a node object carrying the model_response).
                node_result = getattr(self, "_result", None)
                model_response = (
                    getattr(node_result, "model_response", None)
                    if node_result is not None
                    else None
                )

                update_ai_client_span(span, model_response)

        return wrapped_model_request_stream

    ModelRequestNode.stream = create_wrapped_stream(original_model_request_stream)

0 commit comments

Comments
 (0)