Skip to content

Commit 8290186

Browse files
.
1 parent cf39d82 commit 8290186

File tree

3 files changed

+120
-43
lines changed

3 files changed

+120
-43
lines changed
Lines changed: 3 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,17 @@
1-
from functools import wraps
21
from sentry_sdk.integrations import DidNotEnable, Integration
32

3+
44
try:
55
import pydantic_ai # type: ignore # noqa: F401
6-
from pydantic_ai.capabilities.combined import CombinedCapability # type: ignore
76
except ImportError:
87
raise DidNotEnable("pydantic-ai not installed")
98

109

1110
from .patches import (
1211
_patch_agent_run,
12+
_patch_graph_nodes,
1313
_patch_tool_execution,
1414
)
15-
from .spans import (
16-
ai_client_span,
17-
update_ai_client_span,
18-
)
19-
20-
from typing import TYPE_CHECKING
21-
22-
if TYPE_CHECKING:
23-
from typing import Any, Awaitable, Callable
24-
25-
from pydantic_ai._run_context import RunContext # type: ignore
26-
from pydantic_ai.models import ModelRequestContext # type: ignore
27-
from pydantic_ai.messages import ModelResponse # type: ignore
28-
29-
30-
def _patch_wrap_model_request() -> None:
31-
original_wrap_model_request = CombinedCapability.wrap_model_request
32-
33-
@wraps(original_wrap_model_request)
34-
async def wrapped_wrap_model_request(
35-
self: "CombinedCapability",
36-
ctx: "RunContext[Any]",
37-
*,
38-
request_context: "ModelRequestContext",
39-
handler: "Callable[[ModelRequestContext], Awaitable[ModelResponse]]",
40-
) -> "Any":
41-
with ai_client_span(
42-
request_context.messages,
43-
None,
44-
request_context.model,
45-
request_context.model_settings,
46-
) as span:
47-
result = await original_wrap_model_request(
48-
self, ctx, request_context=request_context, handler=handler
49-
)
50-
51-
update_ai_client_span(span, result)
52-
return result
53-
54-
CombinedCapability.wrap_model_request = wrapped_wrap_model_request
5515

5616

5717
class PydanticAIIntegration(Integration):
@@ -84,5 +44,5 @@ def setup_once() -> None:
8444
- Tool executions
8545
"""
8646
_patch_agent_run()
87-
_patch_wrap_model_request()
47+
_patch_graph_nodes()
8848
_patch_tool_execution()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .agent_run import _patch_agent_run # noqa: F401
2+
from .graph_nodes import _patch_graph_nodes # noqa: F401
23
from .tools import _patch_tool_execution # noqa: F401
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from contextlib import asynccontextmanager
2+
from functools import wraps
3+
4+
from sentry_sdk.integrations import DidNotEnable
5+
6+
from ..spans import (
7+
ai_client_span,
8+
update_ai_client_span,
9+
)
10+
11+
try:
12+
from pydantic_ai._agent_graph import ModelRequestNode # type: ignore
13+
except ImportError:
14+
raise DidNotEnable("pydantic-ai not installed")
15+
16+
from typing import TYPE_CHECKING
17+
18+
if TYPE_CHECKING:
19+
from typing import Any, Callable
20+
21+
22+
def _extract_span_data(node: "Any", ctx: "Any") -> "tuple[list[Any], Any, Any]":
23+
"""Extract common data needed for creating chat spans.
24+
25+
Returns:
26+
Tuple of (messages, model, model_settings)
27+
"""
28+
# Extract model and settings from context
29+
model = None
30+
model_settings = None
31+
if hasattr(ctx, "deps"):
32+
model = getattr(ctx.deps, "model", None)
33+
model_settings = getattr(ctx.deps, "model_settings", None)
34+
35+
# Build full message list: history + current request
36+
messages = []
37+
if hasattr(ctx, "state") and hasattr(ctx.state, "message_history"):
38+
messages.extend(ctx.state.message_history)
39+
40+
current_request = getattr(node, "request", None)
41+
if current_request:
42+
messages.append(current_request)
43+
44+
return messages, model, model_settings
45+
46+
47+
def _patch_graph_nodes() -> None:
    """
    Patches the graph node execution to create appropriate spans.

    ModelRequestNode -> Creates ai_client span for model requests
    CallToolsNode -> Handles tool calls (spans created in tool patching)
    """

    # Patch ModelRequestNode._make_request to create ai_client spans.
    original_make_request = ModelRequestNode._make_request

    @wraps(original_make_request)
    async def wrapped_make_request(self: "Any", ctx: "Any") -> "Any":
        # If this node already streamed, or already holds a cached result,
        # the request is (or was) instrumented elsewhere -- delegate without
        # opening a duplicate span.
        did_stream = getattr(self, "_did_stream", None)
        cached_result = getattr(self, "_result", None)
        if did_stream or cached_result is not None:
            return await original_make_request(self, ctx)

        messages, model, model_settings = _extract_span_data(self, ctx)

        with ai_client_span(messages, None, model, model_settings) as span:
            result = await original_make_request(self, ctx)

            # Extract the model response from the result if available.
            model_response = getattr(result, "model_response", None)
            update_ai_client_span(span, model_response)
            return result

    # BUG FIX: assign the wrapper back to the attribute that was captured.
    # The previous code captured ``_make_request`` but assigned the wrapper
    # to ``run``, which left ``_make_request`` unpatched and replaced
    # ``run``'s own logic with a bare ``_make_request`` call.
    ModelRequestNode._make_request = wrapped_make_request

    # Patch ModelRequestNode.stream for streaming requests.
    original_model_request_stream = ModelRequestNode.stream

    def create_wrapped_stream(
        original_stream_method: "Callable[..., Any]",
    ) -> "Callable[..., Any]":
        """Create a wrapper for ModelRequestNode.stream that creates chat spans."""

        @asynccontextmanager
        @wraps(original_stream_method)
        async def wrapped_model_request_stream(self: "Any", ctx: "Any") -> "Any":
            did_stream = getattr(self, "_did_stream", None)
            if did_stream:
                # Already streamed once -- just delegate, no new span.
                async with original_stream_method(self, ctx) as stream:
                    yield stream
                # BUG FIX: without this return the generator fell through,
                # opened a second span and yielded a second time, making
                # @asynccontextmanager raise "generator didn't stop".
                return

            messages, model, model_settings = _extract_span_data(self, ctx)

            # Create chat span for the streaming request.
            with ai_client_span(messages, None, model, model_settings) as span:
                # Call the original stream method.
                async with original_stream_method(self, ctx) as stream:
                    yield stream

                # After streaming completes, update the span with response
                # data: the node caches the final outcome on ``_result``
                # (presumably the next node carrying ``model_response`` --
                # NOTE(review): confirm against pydantic-ai internals).
                result = getattr(self, "_result", None)
                model_response = getattr(result, "model_response", None)
                update_ai_client_span(span, model_response)

        return wrapped_model_request_stream

    ModelRequestNode.stream = create_wrapped_stream(original_model_request_stream)

0 commit comments

Comments
 (0)