|
9 | 9 | Any, |
10 | 10 | Callable, |
11 | 11 | Generic, |
12 | | - Iterator, |
| 12 | + Protocol, |
13 | 13 | TypeVar, |
14 | 14 | cast, |
15 | 15 | ) |
16 | 16 |
|
| 17 | +from opentelemetry.util.genai.stream import ( |
| 18 | + AsyncStreamWrapper, |
| 19 | + SyncStreamWrapper, |
| 20 | +) |
| 21 | + |
17 | 22 | from .messages_extractors import set_invocation_response_attributes |
18 | 23 |
|
19 | 24 | try: |
|
48 | 53 | accumulate_event = cast("Callable[..., Message] | None", _sdk_accumulate_event) |
49 | 54 |
|
50 | 55 |
|
| 56 | +class _StreamWrapperWithStream(Protocol): |
| 57 | + @property |
| 58 | + def stream(self) -> object: ... |
| 59 | + |
| 60 | + |
51 | 61 | def _set_response_attributes( |
52 | 62 | invocation: InferenceInvocation, |
53 | 63 | result: Message | None, |
@@ -105,174 +115,144 @@ def message(self) -> Message: |
105 | 115 | return self._message |
106 | 116 |
|
107 | 117 |
|
108 | | -class MessagesStreamWrapper( |
109 | | - Generic[ResponseFormatT], |
110 | | - Iterator[ |
111 | | - "RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]" |
112 | | - ], |
113 | | -): |
114 | | - """Wrapper for Anthropic Stream that handles telemetry.""" |
115 | | - |
116 | | - def __init__( |
117 | | - self, |
118 | | - stream: Stream[RawMessageStreamEvent] | MessageStream[ResponseFormatT], |
119 | | - invocation: InferenceInvocation, |
120 | | - capture_content: bool, |
121 | | - ): |
122 | | - self.stream = stream |
123 | | - self.invocation = invocation |
124 | | - self._message: Message | ParsedMessage[ResponseFormatT] | None = None |
125 | | - self._capture_content = capture_content |
126 | | - self._finalized = False |
127 | | - |
128 | | - def __enter__(self) -> MessagesStreamWrapper[ResponseFormatT]: |
129 | | - return self |
130 | | - |
131 | | - def __exit__( |
132 | | - self, |
133 | | - exc_type: type[BaseException] | None, |
134 | | - exc_val: BaseException | None, |
135 | | - exc_tb: TracebackType | None, |
136 | | - ) -> bool: |
137 | | - try: |
138 | | - if exc_val is not None: |
139 | | - self._fail(exc_val) |
140 | | - finally: |
141 | | - self.close() |
142 | | - return False |
143 | | - |
144 | | - def close(self) -> None: |
145 | | - try: |
146 | | - self.stream.close() |
147 | | - except Exception as exc: |
148 | | - self._fail(exc) |
149 | | - raise |
150 | | - self._stop() |
151 | | - |
152 | | - def __iter__(self) -> MessagesStreamWrapper[ResponseFormatT]: |
153 | | - return self |
154 | | - |
155 | | - def __next__( |
156 | | - self, |
157 | | - ) -> RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]: |
158 | | - try: |
159 | | - chunk = next(self.stream) |
160 | | - except StopIteration: |
161 | | - self._stop() |
162 | | - raise |
163 | | - except Exception as exc: |
164 | | - self._fail(exc) |
165 | | - raise |
166 | | - self._process_chunk(chunk) |
167 | | - return chunk |
168 | | - |
169 | | - def __getattr__(self, name: str) -> object: |
170 | | - return getattr(self.stream, name) |
171 | | - |
172 | | - @property |
173 | | - def response(self): |
174 | | - return _ResponseProxy(self.stream.response, self._stop) |
| 118 | +class _MessagesStreamMixin(Generic[ResponseFormatT]): |
| 119 | + _self_invocation: InferenceInvocation |
| 120 | + _self_message: Message | ParsedMessage[ResponseFormatT] | None |
| 121 | + _self_capture_content: bool |
| 122 | + _self_message_telemetry_finalized: bool |
175 | 123 |
|
176 | 124 | def _stop(self) -> None: |
177 | | - if self._finalized: |
| 125 | + if self._self_message_telemetry_finalized: |
178 | 126 | return |
179 | 127 | _set_response_attributes( |
180 | | - self.invocation, self._message, self._capture_content |
| 128 | + self._self_invocation, |
| 129 | + self._self_message, |
| 130 | + self._self_capture_content, |
181 | 131 | ) |
182 | | - self.invocation.stop() |
183 | | - self._finalized = True |
| 132 | + self._self_invocation.stop() |
| 133 | + self._self_message_telemetry_finalized = True |
184 | 134 |
|
185 | 135 | def _fail(self, exc: BaseException) -> None: |
186 | | - if self._finalized: |
| 136 | + if self._self_message_telemetry_finalized: |
187 | 137 | return |
188 | | - self.invocation.fail(exc) |
189 | | - self._finalized = True |
| 138 | + self._self_invocation.fail(exc) |
| 139 | + self._self_message_telemetry_finalized = True |
| 140 | + |
| 141 | + def _on_stream_end(self) -> None: |
| 142 | + self._stop() |
| 143 | + |
| 144 | + def _on_stream_error(self, error: BaseException) -> None: |
| 145 | + self._fail(error) |
190 | 146 |
|
191 | 147 | def _process_chunk( |
192 | 148 | self, |
193 | 149 | chunk: RawMessageStreamEvent |
194 | 150 | | ParsedMessageStreamEvent[ResponseFormatT], |
195 | 151 | ) -> None: |
196 | 152 | """Accumulate a final message snapshot from a streaming chunk.""" |
| 153 | + stream = cast(_StreamWrapperWithStream, self).stream |
197 | 154 | snapshot = cast( |
198 | 155 | "ParsedMessage[ResponseFormatT] | None", |
199 | | - getattr(self.stream, "current_message_snapshot", None), |
| 156 | + getattr(stream, "current_message_snapshot", None), |
200 | 157 | ) |
201 | 158 | if snapshot is not None: |
202 | | - self._message = snapshot |
| 159 | + self._self_message = snapshot |
203 | 160 | return |
204 | 161 | if accumulate_event is None: |
205 | 162 | return |
206 | | - self._message = accumulate_event( |
| 163 | + self._self_message = accumulate_event( |
207 | 164 | event=cast("RawMessageStreamEvent", chunk), |
208 | 165 | current_snapshot=cast( |
209 | | - "ParsedMessage[ResponseFormatT] | None", self._message |
| 166 | + "ParsedMessage[ResponseFormatT] | None", self._self_message |
210 | 167 | ), |
211 | 168 | ) |
212 | 169 |
|
213 | 170 |
|
214 | | -class AsyncMessagesStreamWrapper(MessagesStreamWrapper[ResponseFormatT]): |
215 | | - """Wrapper for async Anthropic Stream that handles telemetry.""" |
| 171 | +class MessagesStreamWrapper( |
| 172 | + _MessagesStreamMixin[ResponseFormatT], |
| 173 | + SyncStreamWrapper[ |
| 174 | + "RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]" |
| 175 | + ], |
| 176 | + Generic[ResponseFormatT], |
| 177 | +): |
| 178 | + """Wrapper for Anthropic Stream that handles telemetry.""" |
216 | 179 |
|
217 | 180 | def __init__( |
218 | 181 | self, |
219 | | - stream: AsyncStream[RawMessageStreamEvent] |
220 | | - | AsyncMessageStream[ResponseFormatT], |
| 182 | + stream: Stream[RawMessageStreamEvent] | MessageStream[ResponseFormatT], |
221 | 183 | invocation: InferenceInvocation, |
222 | 184 | capture_content: bool, |
223 | 185 | ): |
224 | | - self.stream = stream |
225 | | - self.invocation = invocation |
226 | | - self._message: Message | ParsedMessage[ResponseFormatT] | None = None |
227 | | - self._capture_content = capture_content |
228 | | - self._finalized = False |
| 186 | + super().__init__(stream) |
| 187 | + self._self_invocation = invocation |
| 188 | + self._self_message = None |
| 189 | + self._self_capture_content = capture_content |
| 190 | + self._self_message_telemetry_finalized = False |
229 | 191 |
|
230 | | - async def __aenter__( |
| 192 | + @property |
| 193 | + def response(self) -> _ResponseProxy[object]: |
| 194 | + return _ResponseProxy(self.stream.response, self._stop) |
| 195 | + |
| 196 | + @property |
| 197 | + def stream( |
231 | 198 | self, |
232 | | - ) -> AsyncMessagesStreamWrapper[ResponseFormatT]: |
233 | | - return self |
| 199 | + ) -> Stream[RawMessageStreamEvent] | MessageStream[ResponseFormatT]: |
| 200 | + return self._self_stream |
234 | 201 |
|
235 | | - async def __aexit__( |
| 202 | + @stream.setter |
| 203 | + def stream( |
236 | 204 | self, |
237 | | - exc_type: type[BaseException] | None, |
238 | | - exc_val: BaseException | None, |
239 | | - exc_tb: TracebackType | None, |
240 | | - ) -> bool: |
241 | | - try: |
242 | | - if exc_val is not None: |
243 | | - self._fail(exc_val) |
244 | | - finally: |
245 | | - await self.close() |
246 | | - return False |
| 205 | + stream: Stream[RawMessageStreamEvent] | MessageStream[ResponseFormatT], |
| 206 | + ) -> None: |
| 207 | + self.__wrapped__ = stream |
| 208 | + self._self_stream = stream |
| 209 | + self._self_iterator = iter(stream) |
247 | 210 |
|
248 | | - async def close(self) -> None: # type: ignore[override] |
249 | | - try: |
250 | | - await self.stream.close() |
251 | | - except Exception as exc: |
252 | | - self._fail(exc) |
253 | | - raise |
254 | | - self._stop() |
255 | 211 |
|
256 | | - def __aiter__(self) -> AsyncMessagesStreamWrapper[ResponseFormatT]: |
257 | | - return self |
| 212 | +class AsyncMessagesStreamWrapper( |
| 213 | + _MessagesStreamMixin[ResponseFormatT], |
| 214 | + AsyncStreamWrapper[ |
| 215 | + "RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]" |
| 216 | + ], |
| 217 | + Generic[ResponseFormatT], |
| 218 | +): |
| 219 | + """Wrapper for async Anthropic Stream that handles telemetry.""" |
| 220 | + |
| 221 | + def __init__( |
| 222 | + self, |
| 223 | + stream: AsyncStream[RawMessageStreamEvent] |
| 224 | + | AsyncMessageStream[ResponseFormatT], |
| 225 | + invocation: InferenceInvocation, |
| 226 | + capture_content: bool, |
| 227 | + ): |
| 228 | + super().__init__(stream) |
| 229 | + self._self_invocation = invocation |
| 230 | + self._self_message = None |
| 231 | + self._self_capture_content = capture_content |
| 232 | + self._self_message_telemetry_finalized = False |
258 | 233 |
|
259 | 234 | @property |
260 | 235 | def response(self) -> Any: |
261 | 236 | return _AsyncResponseProxy(self.stream.response, self._stop) |
262 | 237 |
|
263 | | - async def __anext__( |
| 238 | + @property |
| 239 | + def stream( |
264 | 240 | self, |
265 | | - ) -> RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]: |
266 | | - try: |
267 | | - chunk = await self.stream.__anext__() |
268 | | - except StopAsyncIteration: |
269 | | - self._stop() |
270 | | - raise |
271 | | - except Exception as exc: |
272 | | - self._fail(exc) |
273 | | - raise |
274 | | - self._process_chunk(chunk) |
275 | | - return chunk |
| 241 | + ) -> ( |
| 242 | + AsyncStream[RawMessageStreamEvent] |
| 243 | + | AsyncMessageStream[ResponseFormatT] |
| 244 | + ): |
| 245 | + return self._self_stream |
| 246 | + |
| 247 | + @stream.setter |
| 248 | + def stream( |
| 249 | + self, |
| 250 | + stream: AsyncStream[RawMessageStreamEvent] |
| 251 | + | AsyncMessageStream[ResponseFormatT], |
| 252 | + ) -> None: |
| 253 | + self.__wrapped__ = stream |
| 254 | + self._self_stream = stream |
| 255 | + self._self_aiter = aiter(stream) |
276 | 256 |
|
277 | 257 |
|
278 | 258 | class MessagesStreamManagerWrapper(Generic[ResponseFormatT]): |
|
0 commit comments