Skip to content

Commit 963366a

Browse files
wip: keeping anthropic wrapper same as openai wrapper.
1 parent 91c3342 commit 963366a

2 files changed

Lines changed: 62 additions & 100 deletions

File tree

  • instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic

instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/patch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def messages_create(
5858
Union[
5959
"AnthropicMessage",
6060
"AnthropicStream[RawMessageStreamEvent]",
61-
MessagesStreamWrapper[RawMessageStreamEvent],
61+
MessagesStreamWrapper[None],
6262
],
6363
]:
6464
"""Wrap the `create` method of the `Messages` class to trace it."""
@@ -78,7 +78,7 @@ def traced_method(
7878
) -> Union[
7979
"AnthropicMessage",
8080
"AnthropicStream[RawMessageStreamEvent]",
81-
MessagesStreamWrapper[RawMessageStreamEvent],
81+
MessagesStreamWrapper[None],
8282
]:
8383
params = extract_params(*args, **kwargs)
8484
attributes = get_llm_request_attributes(params, instance)
@@ -123,6 +123,6 @@ def traced_method(
123123
raise
124124

125125
return cast(
126-
'Callable[..., Union["AnthropicMessage", "AnthropicStream[RawMessageStreamEvent]", MessagesStreamWrapper[RawMessageStreamEvent]]]',
126+
'Callable[..., Union["AnthropicMessage", "AnthropicStream[RawMessageStreamEvent]", MessagesStreamWrapper[None]]]',
127127
traced_method,
128128
)

instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py

Lines changed: 59 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,11 @@
1919
from types import TracebackType
2020
from typing import (
2121
TYPE_CHECKING,
22-
AsyncIterator,
22+
Any,
2323
Callable,
2424
Generator,
2525
Generic,
2626
Iterator,
27-
Protocol,
2827
TypeVar,
2928
cast,
3029
)
@@ -45,64 +44,29 @@
4544
_sdk_accumulate_event = None
4645

4746
if TYPE_CHECKING:
47+
from anthropic._streaming import AsyncStream, Stream
4848
from anthropic.lib.streaming._messages import ( # pylint: disable=no-name-in-module
49+
AsyncMessageStream,
4950
AsyncMessageStreamManager,
51+
MessageStream,
5052
MessageStreamManager,
5153
)
5254
from anthropic.lib.streaming._types import ( # pylint: disable=no-name-in-module
53-
MessageStreamEvent,
55+
ParsedMessageStreamEvent,
5456
)
5557
from anthropic.types import (
5658
Message,
5759
RawMessageStreamEvent,
5860
)
61+
from anthropic.types.parsed_message import ParsedMessage
5962

6063

6164
_logger = logging.getLogger(__name__)
62-
SyncResponseT = TypeVar("SyncResponseT", bound="_SupportsClose")
63-
AsyncResponseT = TypeVar("AsyncResponseT", bound="_SupportsAclose")
64-
StreamEventT = TypeVar(
65-
"StreamEventT", "RawMessageStreamEvent", "MessageStreamEvent"
66-
)
67-
StreamEventT_co = TypeVar(
68-
"StreamEventT_co",
69-
"RawMessageStreamEvent",
70-
"MessageStreamEvent",
71-
covariant=True,
72-
)
65+
ResponseT = TypeVar("ResponseT")
66+
ResponseFormatT = TypeVar("ResponseFormatT")
7367
accumulate_event = cast("Callable[..., Message] | None", _sdk_accumulate_event)
7468

7569

76-
class _SupportsClose(Protocol):
77-
def close(self) -> None: ...
78-
79-
80-
class _SupportsAclose(_SupportsClose, Protocol):
81-
async def aclose(self) -> None: ...
82-
83-
84-
class _SyncStream(Protocol[StreamEventT_co]):
85-
@property
86-
def response(self) -> _SupportsClose: ...
87-
88-
def __iter__(self) -> Iterator[StreamEventT_co]: ...
89-
90-
def __next__(self) -> StreamEventT_co: ...
91-
92-
def close(self) -> None: ...
93-
94-
95-
class _AsyncStream(Protocol[StreamEventT_co]):
96-
@property
97-
def response(self) -> _SupportsAclose: ...
98-
99-
def __aiter__(self) -> AsyncIterator[StreamEventT_co]: ...
100-
101-
async def __anext__(self) -> StreamEventT_co: ...
102-
103-
async def close(self) -> None: ...
104-
105-
10670
def _set_response_attributes(
10771
invocation: LLMInvocation,
10872
result: "Message | None",
@@ -111,9 +75,9 @@ def _set_response_attributes(
11175
set_invocation_response_attributes(invocation, result, capture_content)
11276

11377

114-
class _ResponseProxy(Generic[SyncResponseT]):
115-
def __init__(self, response: SyncResponseT, finalize: Callable[[], None]):
116-
self._response = response
78+
class _ResponseProxy(Generic[ResponseT]):
79+
def __init__(self, response: ResponseT, finalize: Callable[[], None]):
80+
self._response: Any = response
11781
self._finalize = finalize
11882

11983
def close(self) -> None:
@@ -126,17 +90,11 @@ def __getattr__(self, name: str):
12690
return getattr(self._response, name)
12791

12892

129-
class _AsyncResponseProxy(Generic[AsyncResponseT]):
130-
def __init__(self, response: AsyncResponseT, finalize: Callable[[], None]):
131-
self._response = response
93+
class _AsyncResponseProxy(Generic[ResponseT]):
94+
def __init__(self, response: ResponseT, finalize: Callable[[], None]):
95+
self._response: Any = response
13296
self._finalize = finalize
13397

134-
def close(self) -> None:
135-
try:
136-
self._response.close()
137-
finally:
138-
self._finalize()
139-
14098
async def aclose(self) -> None:
14199
try:
142100
await self._response.aclose()
@@ -166,24 +124,29 @@ def message(self) -> Message:
166124
return self._message
167125

168126

169-
class MessagesStreamWrapper(Generic[StreamEventT], Iterator[StreamEventT]):
127+
class MessagesStreamWrapper(
128+
Generic[ResponseFormatT],
129+
Iterator[
130+
"RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]"
131+
],
132+
):
170133
"""Wrapper for Anthropic Stream that handles telemetry."""
171134

172135
def __init__(
173136
self,
174-
stream: _SyncStream[StreamEventT],
137+
stream: "Stream[RawMessageStreamEvent] | MessageStream[ResponseFormatT]",
175138
handler: TelemetryHandler,
176139
invocation: LLMInvocation,
177140
capture_content: bool,
178141
):
179142
self.stream = stream
180143
self.handler = handler
181144
self.invocation = invocation
182-
self._message: "Message | None" = None
145+
self._message: "Message | ParsedMessage[ResponseFormatT] | None" = None
183146
self._capture_content = capture_content
184147
self._finalized = False
185148

186-
def __enter__(self) -> "MessagesStreamWrapper[StreamEventT]":
149+
def __enter__(self) -> "MessagesStreamWrapper[ResponseFormatT]":
187150
return self
188151

189152
def __exit__(
@@ -207,10 +170,12 @@ def close(self) -> None:
207170
finally:
208171
self._stop()
209172

210-
def __iter__(self) -> "MessagesStreamWrapper[StreamEventT]":
173+
def __iter__(self) -> "MessagesStreamWrapper[ResponseFormatT]":
211174
return self
212175

213-
def __next__(self) -> StreamEventT:
176+
def __next__(
177+
self,
178+
) -> "RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]":
214179
try:
215180
chunk = next(self.stream)
216181
except StopIteration:
@@ -227,9 +192,7 @@ def __getattr__(self, name: str) -> object:
227192
return getattr(self.stream, name)
228193

229194
@property
230-
def response(
231-
self,
232-
) -> "_ResponseProxy[_SupportsClose] | _AsyncResponseProxy[_SupportsAclose] | None":
195+
def response(self):
233196
return _ResponseProxy(self.stream.response, self._stop)
234197

235198
def _stop(self) -> None:
@@ -266,10 +229,13 @@ def _safe_instrumentation(
266229
exc_info=True,
267230
)
268231

269-
def _process_chunk(self, chunk: StreamEventT) -> None:
232+
def _process_chunk(
233+
self,
234+
chunk: "RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]",
235+
) -> None:
270236
"""Accumulate a final message snapshot from a streaming chunk."""
271237
snapshot = cast(
272-
"Message | None",
238+
"ParsedMessage[ResponseFormatT] | None",
273239
getattr(self.stream, "current_message_snapshot", None),
274240
)
275241
if snapshot is not None:
@@ -279,30 +245,32 @@ def _process_chunk(self, chunk: StreamEventT) -> None:
279245
return
280246
self._message = accumulate_event(
281247
event=cast("RawMessageStreamEvent", chunk),
282-
current_snapshot=self._message,
248+
current_snapshot=cast(
249+
"ParsedMessage[ResponseFormatT] | None", self._message
250+
),
283251
)
284252

285253

286-
class AsyncMessagesStreamWrapper(MessagesStreamWrapper[StreamEventT]):
254+
class AsyncMessagesStreamWrapper(MessagesStreamWrapper[ResponseFormatT]):
287255
"""Wrapper for async Anthropic Stream that handles telemetry."""
288256

289-
stream: _AsyncStream[StreamEventT]
290-
291257
def __init__(
292258
self,
293-
stream: _AsyncStream[StreamEventT],
259+
stream: "AsyncStream[RawMessageStreamEvent] | AsyncMessageStream[ResponseFormatT]",
294260
handler: TelemetryHandler,
295261
invocation: LLMInvocation,
296262
capture_content: bool,
297263
):
298264
self.stream = stream
299265
self.handler = handler
300266
self.invocation = invocation
301-
self._message: "Message | None" = None
267+
self._message: "Message | ParsedMessage[ResponseFormatT] | None" = None
302268
self._capture_content = capture_content
303269
self._finalized = False
304270

305-
async def __aenter__(self) -> "AsyncMessagesStreamWrapper[StreamEventT]":
271+
async def __aenter__(
272+
self,
273+
) -> "AsyncMessagesStreamWrapper[ResponseFormatT]":
306274
return self
307275

308276
async def __aexit__(
@@ -326,16 +294,16 @@ async def close(self) -> None: # type: ignore[override]
326294
finally:
327295
self._stop()
328296

329-
def __aiter__(self) -> "AsyncMessagesStreamWrapper[StreamEventT]":
297+
def __aiter__(self) -> "AsyncMessagesStreamWrapper[ResponseFormatT]":
330298
return self
331299

332300
@property
333-
def response(
334-
self,
335-
) -> "_ResponseProxy[_SupportsClose] | _AsyncResponseProxy[_SupportsAclose] | None":
301+
def response(self) -> Any:
336302
return _AsyncResponseProxy(self.stream.response, self._stop)
337303

338-
async def __anext__(self) -> StreamEventT:
304+
async def __anext__(
305+
self,
306+
) -> "RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]":
339307
try:
340308
chunk = await self.stream.__anext__()
341309
except StopAsyncIteration:
@@ -349,12 +317,12 @@ async def __anext__(self) -> StreamEventT:
349317
return chunk
350318

351319

352-
class MessagesStreamManagerWrapper:
320+
class MessagesStreamManagerWrapper(Generic[ResponseFormatT]):
353321
"""Wrapper for sync Anthropic stream managers."""
354322

355323
def __init__(
356324
self,
357-
manager: "MessageStreamManager",
325+
manager: "MessageStreamManager[ResponseFormatT]",
358326
handler: TelemetryHandler,
359327
invocation: LLMInvocation,
360328
capture_content: bool,
@@ -363,15 +331,12 @@ def __init__(
363331
self._handler = handler
364332
self._invocation = invocation
365333
self._capture_content = capture_content
366-
self._stream_wrapper: (
367-
MessagesStreamWrapper[MessageStreamEvent] | None
368-
) = None
369-
370-
def __enter__(self) -> MessagesStreamWrapper[MessageStreamEvent]:
371-
stream = cast(
372-
"_SyncStream[MessageStreamEvent]",
373-
self._manager.__enter__(),
334+
self._stream_wrapper: MessagesStreamWrapper[ResponseFormatT] | None = (
335+
None
374336
)
337+
338+
def __enter__(self) -> MessagesStreamWrapper[ResponseFormatT]:
339+
stream = self._manager.__enter__()
375340
self._stream_wrapper = MessagesStreamWrapper(
376341
stream,
377342
self._handler,
@@ -406,7 +371,7 @@ def __getattr__(self, name: str) -> object:
406371
return getattr(self._manager, name)
407372

408373

409-
class AsyncMessagesStreamManagerWrapper:
374+
class AsyncMessagesStreamManagerWrapper(Generic[ResponseFormatT]):
410375
"""Wrapper for AsyncMessageStreamManager that handles telemetry.
411376
412377
Wraps AsyncMessageStreamManager from the Anthropic SDK:
@@ -415,7 +380,7 @@ class AsyncMessagesStreamManagerWrapper:
415380

416381
def __init__(
417382
self,
418-
manager: "AsyncMessageStreamManager",
383+
manager: "AsyncMessageStreamManager[ResponseFormatT]",
419384
handler: TelemetryHandler,
420385
invocation: LLMInvocation,
421386
capture_content: bool,
@@ -425,16 +390,13 @@ def __init__(
425390
self._invocation = invocation
426391
self._capture_content = capture_content
427392
self._stream_wrapper: (
428-
AsyncMessagesStreamWrapper[MessageStreamEvent] | None
393+
AsyncMessagesStreamWrapper[ResponseFormatT] | None
429394
) = None
430395

431396
async def __aenter__(
432397
self,
433-
) -> AsyncMessagesStreamWrapper[MessageStreamEvent]:
434-
msg_stream = cast(
435-
"_AsyncStream[MessageStreamEvent]",
436-
await self._manager.__aenter__(),
437-
)
398+
) -> AsyncMessagesStreamWrapper[ResponseFormatT]:
399+
msg_stream = await self._manager.__aenter__()
438400
self._stream_wrapper = AsyncMessagesStreamWrapper(
439401
msg_stream,
440402
self._handler,

0 commit comments

Comments
 (0)