Skip to content

Commit 5c3f4eb

Browse files
fix(pydantic-ai): Stop double reporting model requests
1 parent 6ef56fd commit 5c3f4eb

File tree

4 files changed

+65
-133
lines changed

4 files changed

+65
-133
lines changed

sentry_sdk/integrations/pydantic_ai/__init__.py

Lines changed: 43 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,17 +1,57 @@
1+
from functools import wraps
12
from sentry_sdk.integrations import DidNotEnable, Integration
23

3-
44
try:
55
import pydantic_ai # type: ignore # noqa: F401
6+
from pydantic_ai.capabilities.combined import CombinedCapability # type: ignore
67
except ImportError:
78
raise DidNotEnable("pydantic-ai not installed")
89

910

1011
from .patches import (
1112
_patch_agent_run,
12-
_patch_graph_nodes,
1313
_patch_tool_execution,
1414
)
15+
from .spans import (
16+
ai_client_span,
17+
update_ai_client_span,
18+
)
19+
20+
from typing import TYPE_CHECKING
21+
22+
if TYPE_CHECKING:
23+
from typing import Any, Awaitable, Callable
24+
25+
from pydantic_ai._run_context import RunContext
26+
from pydantic_ai.models import ModelRequestContext
27+
from pydantic_ai.messages import ModelResponse
28+
29+
30+
def _patch_wrap_model_request():
31+
original_wrap_model_request = CombinedCapability.wrap_model_request
32+
33+
@wraps(original_wrap_model_request)
34+
async def wrapped_wrap_model_request(
35+
self,
36+
ctx: "RunContext[Any]",
37+
*,
38+
request_context: "ModelRequestContext",
39+
handler: "Callable[[ModelRequestContext], Awaitable[ModelResponse]]",
40+
) -> "Any":
41+
with ai_client_span(
42+
request_context.messages,
43+
None,
44+
request_context.model,
45+
request_context.model_settings,
46+
) as span:
47+
result = await original_wrap_model_request(
48+
self, ctx, request_context=request_context, handler=handler
49+
)
50+
51+
update_ai_client_span(span, result)
52+
return result
53+
54+
CombinedCapability.wrap_model_request = wrapped_wrap_model_request
1555

1656

1757
class PydanticAIIntegration(Integration):
@@ -44,5 +84,5 @@ def setup_once() -> None:
4484
- Tool executions
4585
"""
4686
_patch_agent_run()
47-
_patch_graph_nodes()
87+
_patch_wrap_model_request()
4888
_patch_tool_execution()
Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,2 @@
11
from .agent_run import _patch_agent_run # noqa: F401
2-
from .graph_nodes import _patch_graph_nodes # noqa: F401
32
from .tools import _patch_tool_execution # noqa: F401

sentry_sdk/integrations/pydantic_ai/patches/graph_nodes.py

Lines changed: 0 additions & 106 deletions
This file was deleted.

tests/integrations/pydantic_ai/test_pydantic_ai.py

Lines changed: 22 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,7 @@ async def test_agent_run_async(sentry_init, capture_events, test_agent):
7575

7676
# Find child span types (invoke_agent is the transaction, not a child span)
7777
chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
78-
assert len(chat_spans) >= 1
78+
assert len(chat_spans) == 1
7979

8080
# Check chat span
8181
chat_span = chat_spans[0]
@@ -158,7 +158,7 @@ def test_agent_run_sync(sentry_init, capture_events, test_agent):
158158

159159
# Find span types
160160
chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
161-
assert len(chat_spans) >= 1
161+
assert len(chat_spans) == 1
162162

163163
# Verify streaming flag is False for sync
164164
for chat_span in chat_spans:
@@ -192,7 +192,7 @@ async def test_agent_run_stream(sentry_init, capture_events, test_agent):
192192

193193
# Find chat spans
194194
chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
195-
assert len(chat_spans) >= 1
195+
assert len(chat_spans) == 1
196196

197197
# Verify streaming flag is True for streaming
198198
for chat_span in chat_spans:
@@ -231,9 +231,8 @@ async def test_agent_run_stream_events(sentry_init, capture_events, test_agent):
231231
# Find chat spans
232232
spans = transaction["spans"]
233233
chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
234-
assert len(chat_spans) >= 1
234+
assert len(chat_spans) == 1
235235

236-
# run_stream_events uses run() internally, so streaming should be False
237236
for chat_span in chat_spans:
238237
assert chat_span["data"]["gen_ai.response.streaming"] is False
239238

@@ -269,7 +268,7 @@ def add_numbers(a: int, b: int) -> int:
269268
tool_spans = [s for s in spans if s["op"] == "gen_ai.execute_tool"]
270269

271270
# Should have tool spans
272-
assert len(tool_spans) >= 1
271+
assert len(tool_spans) == 1
273272

274273
# Check tool span
275274
tool_span = tool_spans[0]
@@ -342,7 +341,7 @@ def add_numbers(a: int, b: int) -> float:
342341
tool_spans = [s for s in spans if s["op"] == "gen_ai.execute_tool"]
343342

344343
# Should have tool spans
345-
assert len(tool_spans) >= 1
344+
assert len(tool_spans) == 2
346345

347346
# Check tool spans
348347
model_retry_tool_span = tool_spans[0]
@@ -421,7 +420,7 @@ def add_numbers(a: Annotated[int, Field(gt=0, lt=0)], b: int) -> int:
421420
tool_spans = [s for s in spans if s["op"] == "gen_ai.execute_tool"]
422421

423422
# Should have tool spans
424-
assert len(tool_spans) >= 1
423+
assert len(tool_spans) == 1
425424

426425
# Check tool spans
427426
model_retry_tool_span = tool_spans[0]
@@ -470,7 +469,7 @@ def multiply(a: int, b: int) -> int:
470469
tool_spans = [s for s in spans if s["op"] == "gen_ai.execute_tool"]
471470

472471
# Should have tool spans
473-
assert len(tool_spans) >= 1
472+
assert len(tool_spans) == 1
474473

475474
# Verify streaming flag is True
476475
for chat_span in chat_spans:
@@ -502,7 +501,7 @@ async def test_model_settings(sentry_init, capture_events, test_agent_with_setti
502501

503502
# Find chat span
504503
chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
505-
assert len(chat_spans) >= 1
504+
assert len(chat_spans) == 1
506505

507506
chat_span = chat_spans[0]
508507
# Check that model settings are captured
@@ -548,7 +547,7 @@ async def test_system_prompt_attribute(
548547

549548
# The transaction IS the invoke_agent span, check for messages in chat spans instead
550549
chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
551-
assert len(chat_spans) >= 1
550+
assert len(chat_spans) == 1
552551

553552
chat_span = chat_spans[0]
554553

@@ -587,7 +586,7 @@ async def test_error_handling(sentry_init, capture_events):
587586
await agent.run("Hello")
588587

589588
# At minimum, we should have a transaction
590-
assert len(events) >= 1
589+
assert len(events) == 1
591590
transaction = [e for e in events if e.get("type") == "transaction"][0]
592591
assert transaction["transaction"] == "invoke_agent test_error"
593592
# Transaction should complete successfully (status key may not exist if no error)
@@ -681,7 +680,7 @@ async def run_agent(input_text):
681680
assert transaction["type"] == "transaction"
682681
assert transaction["transaction"] == "invoke_agent test_agent"
683682
# Each should have its own spans
684-
assert len(transaction["spans"]) >= 1
683+
assert len(transaction["spans"]) == 1
685684

686685

687686
@pytest.mark.asyncio
@@ -721,7 +720,7 @@ async def test_message_history(sentry_init, capture_events):
721720
await agent.run("What is my name?", message_history=history)
722721

723722
# We should have 2 transactions
724-
assert len(events) >= 2
723+
assert len(events) == 2
725724

726725
# Check the second transaction has the full history
727726
second_transaction = events[1]
@@ -755,7 +754,7 @@ async def test_gen_ai_system(sentry_init, capture_events, test_agent):
755754

756755
# Find chat span
757756
chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
758-
assert len(chat_spans) >= 1
757+
assert len(chat_spans) == 1
759758

760759
chat_span = chat_spans[0]
761760
# gen_ai.system should be set from the model (TestModel -> 'test')
@@ -812,7 +811,7 @@ async def test_include_prompts_true(sentry_init, capture_events, test_agent):
812811
chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
813812

814813
# Verify that messages are captured in chat spans
815-
assert len(chat_spans) >= 1
814+
assert len(chat_spans) == 1
816815
for chat_span in chat_spans:
817816
assert "gen_ai.request.messages" in chat_span["data"]
818817

@@ -1242,7 +1241,7 @@ async def test_invoke_agent_with_instructions(
12421241

12431242
# The transaction IS the invoke_agent span, check for messages in chat spans instead
12441243
chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
1245-
assert len(chat_spans) >= 1
1244+
assert len(chat_spans) == 1
12461245

12471246
chat_span = chat_spans[0]
12481247

@@ -1366,7 +1365,7 @@ async def test_usage_data_partial(sentry_init, capture_events):
13661365
spans = transaction["spans"]
13671366

13681367
chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
1369-
assert len(chat_spans) >= 1
1368+
assert len(chat_spans) == 1
13701369

13711370
# Check that usage data fields exist (they may or may not be set depending on TestModel)
13721371
chat_span = chat_spans[0]
@@ -1461,7 +1460,7 @@ def calc_tool(value: int) -> int:
14611460
chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
14621461

14631462
# At least one chat span should exist
1464-
assert len(chat_spans) >= 1
1463+
assert len(chat_spans) == 2
14651464

14661465
# Check if tool calls are captured in response
14671466
for chat_span in chat_spans:
@@ -1509,7 +1508,7 @@ async def test_message_formatting_with_different_parts(sentry_init, capture_even
15091508
chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
15101509

15111510
# Should have chat spans
1512-
assert len(chat_spans) >= 1
1511+
assert len(chat_spans) == 1
15131512

15141513
# Check that messages are captured
15151514
chat_span = chat_spans[0]
@@ -1781,7 +1780,7 @@ def test_tool(x: int) -> int:
17811780
chat_spans = [s for s in spans if s["op"] == "gen_ai.chat"]
17821781

17831782
# Should have chat spans
1784-
assert len(chat_spans) >= 1
1783+
assert len(chat_spans) == 2
17851784

17861785

17871786
@pytest.mark.asyncio
@@ -2762,7 +2761,7 @@ async def test_binary_content_in_agent_run(sentry_init, capture_events):
27622761

27632762
(transaction,) = events
27642763
chat_spans = [s for s in transaction["spans"] if s["op"] == "gen_ai.chat"]
2765-
assert len(chat_spans) >= 1
2764+
assert len(chat_spans) == 1
27662765

27672766
chat_span = chat_spans[0]
27682767
if "gen_ai.request.messages" in chat_span["data"]:
@@ -2906,7 +2905,7 @@ def multiply_numbers(a: int, b: int) -> int:
29062905
spans = transaction["spans"]
29072906

29082907
tool_spans = [s for s in spans if s["op"] == "gen_ai.execute_tool"]
2909-
assert len(tool_spans) >= 1
2908+
assert len(tool_spans) == 1
29102909

29112910
tool_span = tool_spans[0]
29122911
assert tool_span["data"]["gen_ai.tool.name"] == "multiply_numbers"

0 commit comments

Comments (0)