Skip to content

Commit 8bca50a

Browse files
committed
common code for context manager
1 parent 35f74be commit 8bca50a

2 files changed

Lines changed: 25 additions & 44 deletions

File tree

util/opentelemetry-util-genai/src/opentelemetry/util/genai/_invocation.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616

1717
import timeit
1818
from abc import ABC, abstractmethod
19+
from contextlib import contextmanager
1920
from contextvars import Token
20-
from typing import TYPE_CHECKING, Any
21+
from typing import TYPE_CHECKING, Any, Iterator
2122

22-
from typing_extensions import TypeAlias
23+
from typing_extensions import Self, TypeAlias
2324

2425
from opentelemetry._logs import Logger
2526
from opentelemetry.context import Context, attach, detach
@@ -127,3 +128,13 @@ def fail(self, error: Error | BaseException) -> None:
127128
if isinstance(error, BaseException):
128129
error = Error(type=type(error), message=str(error))
129130
self._finish(error)
131+
132+
@contextmanager
133+
def _managed(self) -> Iterator[Self]:
134+
"""Context manager that calls stop() on success or fail() on exception."""
135+
try:
136+
yield self
137+
except Exception as exc:
138+
self.fail(exc)
139+
raise
140+
self.stop()

util/opentelemetry-util-genai/src/opentelemetry/util/genai/handler.py

Lines changed: 12 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@
4848

4949
from __future__ import annotations
5050

51-
from contextlib import contextmanager
52-
from typing import Iterator
51+
from contextlib import AbstractContextManager
5352

5453
from opentelemetry._logs import (
5554
LoggerProvider,
@@ -229,15 +228,14 @@ def fail_llm( # pylint: disable=no-self-use
229228
invocation._inference_invocation.fail(error)
230229
return invocation
231230

232-
@contextmanager
233231
def inference(
234232
self,
235233
provider: str,
236234
*,
237235
request_model: str | None = None,
238236
server_address: str | None = None,
239237
server_port: int | None = None,
240-
) -> Iterator[InferenceInvocation]:
238+
) -> AbstractContextManager[InferenceInvocation]:
241239
"""Context manager for LLM inference invocations.
242240
243241
Only set data attributes on the invocation object, do not modify the span or context.
@@ -246,28 +244,21 @@ def inference(
246244
If an exception occurs inside the context, marks the span as error, ends it, and
247245
re-raises the original exception.
248246
"""
249-
invocation = self.start_inference(
247+
return self.start_inference(
250248
provider=provider,
251249
request_model=request_model,
252250
server_address=server_address,
253251
server_port=server_port,
254-
)
255-
try:
256-
yield invocation
257-
except Exception as exc:
258-
invocation.fail(exc)
259-
raise
260-
invocation.stop()
252+
)._managed()
261253

262-
@contextmanager
263254
def embedding(
264255
self,
265256
provider: str,
266257
*,
267258
request_model: str | None = None,
268259
server_address: str | None = None,
269260
server_port: int | None = None,
270-
) -> Iterator[EmbeddingInvocation]:
261+
) -> AbstractContextManager[EmbeddingInvocation]:
271262
"""Context manager for Embedding invocations.
272263
273264
Only set data attributes on the invocation object, do not modify the span or context.
@@ -276,20 +267,13 @@ def embedding(
276267
If an exception occurs inside the context, marks the span as error, ends it, and
277268
re-raises the original exception.
278269
"""
279-
invocation = self.start_embedding(
270+
return self.start_embedding(
280271
provider=provider,
281272
request_model=request_model,
282273
server_address=server_address,
283274
server_port=server_port,
284-
)
285-
try:
286-
yield invocation
287-
except Exception as exc:
288-
invocation.fail(exc)
289-
raise
290-
invocation.stop()
275+
)._managed()
291276

292-
@contextmanager
293277
def tool(
294278
self,
295279
name: str,
@@ -298,7 +282,7 @@ def tool(
298282
tool_call_id: str | None = None,
299283
tool_type: str | None = None,
300284
tool_description: str | None = None,
301-
) -> Iterator[ToolInvocation]:
285+
) -> AbstractContextManager[ToolInvocation]:
302286
"""Context manager for Tool invocations.
303287
304288
Only set data attributes on the invocation object, do not modify the span or context.
@@ -307,25 +291,18 @@ def tool(
307291
If an exception occurs inside the context, marks the span as error, ends it, and
308292
re-raises the original exception.
309293
"""
310-
invocation = self.start_tool(
294+
return self.start_tool(
311295
name,
312296
arguments=arguments,
313297
tool_call_id=tool_call_id,
314298
tool_type=tool_type,
315299
tool_description=tool_description,
316-
)
317-
try:
318-
yield invocation
319-
except Exception as exc:
320-
invocation.fail(exc)
321-
raise
322-
invocation.stop()
300+
)._managed()
323301

324-
@contextmanager
325302
def workflow(
326303
self,
327304
name: str | None = None,
328-
) -> Iterator[WorkflowInvocation]:
305+
) -> AbstractContextManager[WorkflowInvocation]:
329306
"""Context manager for Workflow invocations.
330307
331308
Only set data attributes on the invocation object, do not modify the span or context.
@@ -334,14 +311,7 @@ def workflow(
334311
If an exception occurs inside the context, marks the span as error, ends it, and
335312
re-raises the original exception.
336313
"""
337-
invocation = self.start_workflow(name=name)
338-
339-
try:
340-
yield invocation
341-
except Exception as exc:
342-
invocation.fail(exc)
343-
raise
344-
invocation.stop()
314+
return self.start_workflow(name=name)._managed()
345315

346316

347317
def get_telemetry_handler(

0 commit comments

Comments
 (0)