Skip to content

Commit 07d0e39

Browse files
refactor: responses and messages genai streams integration (#92)
* wip: responses and messages genai streams integration * polish: added changelog, fixed precommit and typechecks. * Address stream wrapper review feedback Assisted-by: ChatGPT 5.5 * wip: Code design refactor. Assisted by: GPT 5.5. --------- Co-authored-by: Leighton Chen <lechen@microsoft.com>
1 parent 7656fc8 commit 07d0e39

7 files changed

Lines changed: 287 additions & 258 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Use shared GenAI stream wrappers for Messages API streams.

instrumentation/opentelemetry-instrumentation-genai-anthropic/src/opentelemetry/instrumentation/genai/anthropic/wrappers.py

Lines changed: 103 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,16 @@
99
Any,
1010
Callable,
1111
Generic,
12-
Iterator,
12+
Protocol,
1313
TypeVar,
1414
cast,
1515
)
1616

17+
from opentelemetry.util.genai.stream import (
18+
AsyncStreamWrapper,
19+
SyncStreamWrapper,
20+
)
21+
1722
from .messages_extractors import set_invocation_response_attributes
1823

1924
try:
@@ -48,6 +53,11 @@
4853
accumulate_event = cast("Callable[..., Message] | None", _sdk_accumulate_event)
4954

5055

56+
class _StreamWrapperWithStream(Protocol):
57+
@property
58+
def stream(self) -> object: ...
59+
60+
5161
def _set_response_attributes(
5262
invocation: InferenceInvocation,
5363
result: Message | None,
@@ -105,174 +115,144 @@ def message(self) -> Message:
105115
return self._message
106116

107117

108-
class MessagesStreamWrapper(
109-
Generic[ResponseFormatT],
110-
Iterator[
111-
"RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]"
112-
],
113-
):
114-
"""Wrapper for Anthropic Stream that handles telemetry."""
115-
116-
def __init__(
117-
self,
118-
stream: Stream[RawMessageStreamEvent] | MessageStream[ResponseFormatT],
119-
invocation: InferenceInvocation,
120-
capture_content: bool,
121-
):
122-
self.stream = stream
123-
self.invocation = invocation
124-
self._message: Message | ParsedMessage[ResponseFormatT] | None = None
125-
self._capture_content = capture_content
126-
self._finalized = False
127-
128-
def __enter__(self) -> MessagesStreamWrapper[ResponseFormatT]:
129-
return self
130-
131-
def __exit__(
132-
self,
133-
exc_type: type[BaseException] | None,
134-
exc_val: BaseException | None,
135-
exc_tb: TracebackType | None,
136-
) -> bool:
137-
try:
138-
if exc_val is not None:
139-
self._fail(exc_val)
140-
finally:
141-
self.close()
142-
return False
143-
144-
def close(self) -> None:
145-
try:
146-
self.stream.close()
147-
except Exception as exc:
148-
self._fail(exc)
149-
raise
150-
self._stop()
151-
152-
def __iter__(self) -> MessagesStreamWrapper[ResponseFormatT]:
153-
return self
154-
155-
def __next__(
156-
self,
157-
) -> RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]:
158-
try:
159-
chunk = next(self.stream)
160-
except StopIteration:
161-
self._stop()
162-
raise
163-
except Exception as exc:
164-
self._fail(exc)
165-
raise
166-
self._process_chunk(chunk)
167-
return chunk
168-
169-
def __getattr__(self, name: str) -> object:
170-
return getattr(self.stream, name)
171-
172-
@property
173-
def response(self):
174-
return _ResponseProxy(self.stream.response, self._stop)
118+
class _MessagesStreamMixin(Generic[ResponseFormatT]):
119+
_self_invocation: InferenceInvocation
120+
_self_message: Message | ParsedMessage[ResponseFormatT] | None
121+
_self_capture_content: bool
122+
_self_message_telemetry_finalized: bool
175123

176124
def _stop(self) -> None:
177-
if self._finalized:
125+
if self._self_message_telemetry_finalized:
178126
return
179127
_set_response_attributes(
180-
self.invocation, self._message, self._capture_content
128+
self._self_invocation,
129+
self._self_message,
130+
self._self_capture_content,
181131
)
182-
self.invocation.stop()
183-
self._finalized = True
132+
self._self_invocation.stop()
133+
self._self_message_telemetry_finalized = True
184134

185135
def _fail(self, exc: BaseException) -> None:
186-
if self._finalized:
136+
if self._self_message_telemetry_finalized:
187137
return
188-
self.invocation.fail(exc)
189-
self._finalized = True
138+
self._self_invocation.fail(exc)
139+
self._self_message_telemetry_finalized = True
140+
141+
def _on_stream_end(self) -> None:
142+
self._stop()
143+
144+
def _on_stream_error(self, error: BaseException) -> None:
145+
self._fail(error)
190146

191147
def _process_chunk(
192148
self,
193149
chunk: RawMessageStreamEvent
194150
| ParsedMessageStreamEvent[ResponseFormatT],
195151
) -> None:
196152
"""Accumulate a final message snapshot from a streaming chunk."""
153+
stream = cast(_StreamWrapperWithStream, self).stream
197154
snapshot = cast(
198155
"ParsedMessage[ResponseFormatT] | None",
199-
getattr(self.stream, "current_message_snapshot", None),
156+
getattr(stream, "current_message_snapshot", None),
200157
)
201158
if snapshot is not None:
202-
self._message = snapshot
159+
self._self_message = snapshot
203160
return
204161
if accumulate_event is None:
205162
return
206-
self._message = accumulate_event(
163+
self._self_message = accumulate_event(
207164
event=cast("RawMessageStreamEvent", chunk),
208165
current_snapshot=cast(
209-
"ParsedMessage[ResponseFormatT] | None", self._message
166+
"ParsedMessage[ResponseFormatT] | None", self._self_message
210167
),
211168
)
212169

213170

214-
class AsyncMessagesStreamWrapper(MessagesStreamWrapper[ResponseFormatT]):
215-
"""Wrapper for async Anthropic Stream that handles telemetry."""
171+
class MessagesStreamWrapper(
172+
_MessagesStreamMixin[ResponseFormatT],
173+
SyncStreamWrapper[
174+
"RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]"
175+
],
176+
Generic[ResponseFormatT],
177+
):
178+
"""Wrapper for Anthropic Stream that handles telemetry."""
216179

217180
def __init__(
218181
self,
219-
stream: AsyncStream[RawMessageStreamEvent]
220-
| AsyncMessageStream[ResponseFormatT],
182+
stream: Stream[RawMessageStreamEvent] | MessageStream[ResponseFormatT],
221183
invocation: InferenceInvocation,
222184
capture_content: bool,
223185
):
224-
self.stream = stream
225-
self.invocation = invocation
226-
self._message: Message | ParsedMessage[ResponseFormatT] | None = None
227-
self._capture_content = capture_content
228-
self._finalized = False
186+
super().__init__(stream)
187+
self._self_invocation = invocation
188+
self._self_message = None
189+
self._self_capture_content = capture_content
190+
self._self_message_telemetry_finalized = False
229191

230-
async def __aenter__(
192+
@property
193+
def response(self) -> _ResponseProxy[object]:
194+
return _ResponseProxy(self.stream.response, self._stop)
195+
196+
@property
197+
def stream(
231198
self,
232-
) -> AsyncMessagesStreamWrapper[ResponseFormatT]:
233-
return self
199+
) -> Stream[RawMessageStreamEvent] | MessageStream[ResponseFormatT]:
200+
return self._self_stream
234201

235-
async def __aexit__(
202+
@stream.setter
203+
def stream(
236204
self,
237-
exc_type: type[BaseException] | None,
238-
exc_val: BaseException | None,
239-
exc_tb: TracebackType | None,
240-
) -> bool:
241-
try:
242-
if exc_val is not None:
243-
self._fail(exc_val)
244-
finally:
245-
await self.close()
246-
return False
205+
stream: Stream[RawMessageStreamEvent] | MessageStream[ResponseFormatT],
206+
) -> None:
207+
self.__wrapped__ = stream
208+
self._self_stream = stream
209+
self._self_iterator = iter(stream)
247210

248-
async def close(self) -> None: # type: ignore[override]
249-
try:
250-
await self.stream.close()
251-
except Exception as exc:
252-
self._fail(exc)
253-
raise
254-
self._stop()
255211

256-
def __aiter__(self) -> AsyncMessagesStreamWrapper[ResponseFormatT]:
257-
return self
212+
class AsyncMessagesStreamWrapper(
213+
_MessagesStreamMixin[ResponseFormatT],
214+
AsyncStreamWrapper[
215+
"RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]"
216+
],
217+
Generic[ResponseFormatT],
218+
):
219+
"""Wrapper for async Anthropic Stream that handles telemetry."""
220+
221+
def __init__(
222+
self,
223+
stream: AsyncStream[RawMessageStreamEvent]
224+
| AsyncMessageStream[ResponseFormatT],
225+
invocation: InferenceInvocation,
226+
capture_content: bool,
227+
):
228+
super().__init__(stream)
229+
self._self_invocation = invocation
230+
self._self_message = None
231+
self._self_capture_content = capture_content
232+
self._self_message_telemetry_finalized = False
258233

259234
@property
260235
def response(self) -> Any:
261236
return _AsyncResponseProxy(self.stream.response, self._stop)
262237

263-
async def __anext__(
238+
@property
239+
def stream(
264240
self,
265-
) -> RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]:
266-
try:
267-
chunk = await self.stream.__anext__()
268-
except StopAsyncIteration:
269-
self._stop()
270-
raise
271-
except Exception as exc:
272-
self._fail(exc)
273-
raise
274-
self._process_chunk(chunk)
275-
return chunk
241+
) -> (
242+
AsyncStream[RawMessageStreamEvent]
243+
| AsyncMessageStream[ResponseFormatT]
244+
):
245+
return self._self_stream
246+
247+
@stream.setter
248+
def stream(
249+
self,
250+
stream: AsyncStream[RawMessageStreamEvent]
251+
| AsyncMessageStream[ResponseFormatT],
252+
) -> None:
253+
self.__wrapped__ = stream
254+
self._self_stream = stream
255+
self._self_aiter = aiter(stream)
276256

277257

278258
class MessagesStreamManagerWrapper(Generic[ResponseFormatT]):

0 commit comments

Comments
 (0)