|
14 | 14 |
|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
17 | | -from dataclasses import asdict |
| 17 | +from dataclasses import asdict, dataclass, field |
18 | 18 | from typing import Any |
19 | 19 |
|
20 | 20 | from typing_extensions import deprecated |
|
24 | 24 | gen_ai_attributes as GenAI, |
25 | 25 | ) |
26 | 26 | from opentelemetry.semconv.attributes import server_attributes |
27 | | -from opentelemetry.trace import ( |
28 | | - INVALID_SPAN, |
29 | | - SpanKind, |
30 | | - Tracer, |
31 | | - set_span_in_context, |
32 | | -) |
| 27 | +from opentelemetry.trace import INVALID_SPAN, Span, SpanKind, Tracer |
33 | 28 | from opentelemetry.util.genai._invocation import Error, GenAIInvocation |
34 | 29 | from opentelemetry.util.genai.metrics import InvocationMetricsRecorder |
35 | 30 | from opentelemetry.util.genai.types import ( |
@@ -260,122 +255,107 @@ def _emit_event(self) -> None: |
260 | 255 |
|
261 | 256 |
|
@deprecated("LLMInvocation is deprecated. Use InferenceInvocation instead.")
@dataclass
class LLMInvocation:
    """Deprecated. Use InferenceInvocation instead.

    Plain data container describing a single LLM call. Hand it to
    handler.start_llm() to open the span, mutate fields as results arrive,
    then finish via handler.stop_llm() or handler.fail_llm().
    """

    request_model: str | None = None
    input_messages: list[InputMessage] = field(default_factory=list)
    output_messages: list[OutputMessage] = field(default_factory=list)
    system_instruction: list[MessagePart] = field(default_factory=list)
    provider: str | None = None
    response_model_name: str | None = None
    response_id: str | None = None
    finish_reasons: list[str] | None = None
    input_tokens: int | None = None
    output_tokens: int | None = None
    attributes: dict[str, Any] = field(default_factory=dict)
    """Additional attributes to set on spans and/or events. Not set on metrics."""
    metric_attributes: dict[str, Any] = field(default_factory=dict)
    """Additional attributes to set on metrics. Must be low cardinality. Not set on spans or events."""
    temperature: float | None = None
    top_p: float | None = None
    frequency_penalty: float | None = None
    presence_penalty: float | None = None
    max_tokens: int | None = None
    stop_sequences: list[str] | None = None
    seed: int | None = None
    server_address: str | None = None
    server_port: int | None = None

    # Backing InferenceInvocation; created by _start_with_handler(), None before
    # the handler starts this invocation. Excluded from __init__ and repr.
    _inference_invocation: InferenceInvocation | None = field(
        default=None, init=False, repr=False
    )

    def _start_with_handler(
        self,
        tracer: Tracer,
        metrics_recorder: InvocationMetricsRecorder,
        logger: Logger,
    ) -> None:
        """Create and start an InferenceInvocation from this data container. Called by handler.start_llm()."""
        # Forward every user-facing field by keyword; ``provider`` is passed
        # positionally and normalized to "" because the backing invocation
        # takes a plain str.
        forwarded = {
            name: getattr(self, name)
            for name in (
                "request_model",
                "input_messages",
                "output_messages",
                "system_instruction",
                "response_model_name",
                "response_id",
                "finish_reasons",
                "input_tokens",
                "output_tokens",
                "temperature",
                "top_p",
                "frequency_penalty",
                "presence_penalty",
                "max_tokens",
                "stop_sequences",
                "seed",
                "server_address",
                "server_port",
                "attributes",
                "metric_attributes",
            )
        }
        self._inference_invocation = InferenceInvocation(
            tracer,
            metrics_recorder,
            logger,
            self.provider or "",
            **forwarded,
        )

    def _sync_to_invocation(self) -> None:
        """Push this container's current field values onto the backing invocation.

        No-op when the invocation has not been started yet.
        """
        target = self._inference_invocation
        if target is None:
            return
        # ``provider`` needs the same str normalization as at start time.
        target.provider = self.provider or ""
        for name in (
            "request_model",
            "input_messages",
            "output_messages",
            "system_instruction",
            "response_model_name",
            "response_id",
            "finish_reasons",
            "input_tokens",
            "output_tokens",
            "temperature",
            "top_p",
            "frequency_penalty",
            "presence_penalty",
            "max_tokens",
            "stop_sequences",
            "seed",
            "server_address",
            "server_port",
            "attributes",
            "metric_attributes",
        ):
            setattr(target, name, getattr(self, name))

    @property
    def span(self) -> Span:
        """The underlying span, for back-compat with code that checks span.is_recording()."""
        inv = self._inference_invocation
        if inv is None:
            return INVALID_SPAN
        return inv.span
0 commit comments