Skip to content

Commit b4a9084

Browse files
authored
refactor(genai-util): pass sampling attributes at span creation time (#4538)
1 parent 0293197 commit b4a9084

4 files changed

Lines changed: 155 additions & 63 deletions

File tree

util/opentelemetry-util-genai/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## Unreleased
99

10+
- Change `InferenceInvocation` init params to only accept base params
11+
- Pass `attributes` into the invocation's `_start` so samplers can access them at span creation time.
12+
([#4538](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4538))
13+
1014
## Version 0.4b0 (2026-05-01)
1115

1216
- Add `AgentInvocation` type with `invoke_agent` span lifecycle

util/opentelemetry-util-genai/src/opentelemetry/util/genai/_inference_invocation.py

Lines changed: 37 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class InferenceInvocation(GenAIInvocation):
5555
context manager rather than constructing this directly.
5656
"""
5757

58-
def __init__( # pylint: disable=too-many-locals
58+
def __init__(
5959
self,
6060
tracer: Tracer,
6161
metrics_recorder: InvocationMetricsRecorder,
@@ -64,25 +64,8 @@ def __init__( # pylint: disable=too-many-locals
6464
provider: str,
6565
*,
6666
request_model: str | None = None,
67-
input_messages: list[InputMessage] | None = None,
68-
output_messages: list[OutputMessage] | None = None,
69-
system_instruction: list[MessagePart] | None = None,
70-
response_model_name: str | None = None,
71-
response_id: str | None = None,
72-
finish_reasons: list[str] | None = None,
73-
input_tokens: int | None = None,
74-
output_tokens: int | None = None,
75-
temperature: float | None = None,
76-
top_p: float | None = None,
77-
frequency_penalty: float | None = None,
78-
presence_penalty: float | None = None,
79-
max_tokens: int | None = None,
80-
stop_sequences: list[str] | None = None,
81-
seed: int | None = None,
8267
server_address: str | None = None,
8368
server_port: int | None = None,
84-
attributes: dict[str, Any] | None = None,
85-
metric_attributes: dict[str, Any] | None = None,
8669
) -> None:
8770
"""Use handler.start_inference(provider) or handler.inference(provider) instead of calling this directly."""
8871
_operation_name = GenAI.GenAiOperationNameValues.CHAT.value
@@ -96,38 +79,31 @@ def __init__( # pylint: disable=too-many-locals
9679
if request_model
9780
else _operation_name,
9881
span_kind=SpanKind.CLIENT,
99-
attributes=attributes,
100-
metric_attributes=metric_attributes,
10182
)
10283
self.provider = provider
10384
self.request_model = request_model
104-
self.input_messages: list[InputMessage] = (
105-
[] if input_messages is None else input_messages
106-
)
107-
self.output_messages: list[OutputMessage] = (
108-
[] if output_messages is None else output_messages
109-
)
110-
self.system_instruction: list[MessagePart] = (
111-
[] if system_instruction is None else system_instruction
112-
)
113-
self.response_model_name = response_model_name
114-
self.response_id = response_id
115-
self.finish_reasons = finish_reasons
116-
self.input_tokens = input_tokens
117-
self.output_tokens = output_tokens
118-
self.temperature = temperature
119-
self.top_p = top_p
120-
self.frequency_penalty = frequency_penalty
121-
self.presence_penalty = presence_penalty
122-
self.max_tokens = max_tokens
123-
self.stop_sequences = stop_sequences
124-
self.seed = seed
12585
self.server_address = server_address
12686
self.server_port = server_port
87+
88+
self.input_messages: list[InputMessage] = []
89+
self.output_messages: list[OutputMessage] = []
90+
self.system_instruction: list[MessagePart] = []
91+
self.response_model_name: str | None = None
92+
self.response_id: str | None = None
93+
self.finish_reasons: list[str] | None = None
94+
self.input_tokens: int | None = None
95+
self.output_tokens: int | None = None
96+
self.temperature: float | None = None
97+
self.top_p: float | None = None
98+
self.frequency_penalty: float | None = None
99+
self.presence_penalty: float | None = None
100+
self.max_tokens: int | None = None
101+
self.stop_sequences: list[str] | None = None
102+
self.seed: int | None = None
127103
self.cache_creation_input_tokens: int | None = None
128104
self.cache_read_input_tokens: int | None = None
129105
self.tool_definitions: list[ToolDefinition] | None = None
130-
self._start()
106+
self._start(self._get_base_attributes())
131107

132108
def _get_message_attributes(self, *, for_span: bool) -> dict[str, Any]:
133109
return get_content_attributes(
@@ -288,33 +264,34 @@ def _start_with_handler(
288264
completion_hook: CompletionHook,
289265
) -> None:
290266
"""Create and start an InferenceInvocation from this data container. Called by handler.start_llm()."""
291-
self._inference_invocation = InferenceInvocation(
267+
inv = InferenceInvocation(
292268
tracer,
293269
metrics_recorder,
294270
logger,
295271
completion_hook,
296272
self.provider or "",
297273
request_model=self.request_model,
298-
input_messages=self.input_messages,
299-
output_messages=self.output_messages,
300-
system_instruction=self.system_instruction,
301-
response_model_name=self.response_model_name,
302-
response_id=self.response_id,
303-
finish_reasons=self.finish_reasons,
304-
input_tokens=self.input_tokens,
305-
output_tokens=self.output_tokens,
306-
temperature=self.temperature,
307-
top_p=self.top_p,
308-
frequency_penalty=self.frequency_penalty,
309-
presence_penalty=self.presence_penalty,
310-
max_tokens=self.max_tokens,
311-
stop_sequences=self.stop_sequences,
312-
seed=self.seed,
313274
server_address=self.server_address,
314275
server_port=self.server_port,
315-
attributes=self.attributes,
316-
metric_attributes=self.metric_attributes,
317276
)
277+
inv.input_messages = self.input_messages
278+
inv.output_messages = self.output_messages
279+
inv.system_instruction = self.system_instruction
280+
inv.response_model_name = self.response_model_name
281+
inv.response_id = self.response_id
282+
inv.finish_reasons = self.finish_reasons
283+
inv.input_tokens = self.input_tokens
284+
inv.output_tokens = self.output_tokens
285+
inv.temperature = self.temperature
286+
inv.top_p = self.top_p
287+
inv.frequency_penalty = self.frequency_penalty
288+
inv.presence_penalty = self.presence_penalty
289+
inv.max_tokens = self.max_tokens
290+
inv.stop_sequences = self.stop_sequences
291+
inv.seed = self.seed
292+
inv.attributes.update(self.attributes)
293+
inv.metric_attributes.update(self.metric_attributes)
294+
self._inference_invocation = inv
318295

319296
def _sync_to_invocation(self) -> None:
320297
inv = self._inference_invocation

util/opentelemetry-util-genai/src/opentelemetry/util/genai/_invocation.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,16 @@ def __init__(
9696
self._context_token: ContextToken | None = None
9797
self._monotonic_start_s: float | None = None
9898

99-
def _start(self) -> None:
100-
"""Start the invocation span and attach it to the current context."""
99+
def _start(self, attributes: dict[str, Any] | None = None) -> None:
100+
"""Start the invocation span and attach it to the current context.
101+
102+
Args:
103+
attributes: Initial span attributes available for sampling decisions.
104+
"""
101105
self.span = self._tracer.start_span(
102106
name=self._span_name,
103107
kind=self._span_kind,
108+
attributes=attributes,
104109
)
105110
self._span_context = set_span_in_context(self.span)
106111
self._monotonic_start_s = timeit.default_timer()

util/opentelemetry-util-genai/tests/test_utils.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
3434
InMemorySpanExporter,
3535
)
36+
from opentelemetry.sdk.trace.sampling import Decision, SamplingResult
3637
from opentelemetry.semconv._incubating.attributes import (
3738
gen_ai_attributes as GenAI,
3839
)
@@ -46,7 +47,10 @@
4647
OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT,
4748
OTEL_INSTRUMENTATION_GENAI_EMIT_EVENT,
4849
)
49-
from opentelemetry.util.genai.handler import get_telemetry_handler
50+
from opentelemetry.util.genai.handler import (
51+
TelemetryHandler,
52+
get_telemetry_handler,
53+
)
5054
from opentelemetry.util.genai.types import (
5155
ContentCapturingMode,
5256
InputMessage,
@@ -357,6 +361,108 @@ def test_llm_manual_start_and_stop_creates_span(self):
357361
},
358362
)
359363

364+
def test_start_inference_passes_sampling_attributes_at_span_creation(self):
365+
"""Verify that sampling-relevant attributes are available at start_span() time."""
366+
captured_attributes = {}
367+
368+
class AttributeCapturingSampler: # pylint: disable=no-self-use
369+
"""A sampler that records the attributes passed to should_sample."""
370+
371+
def should_sample(
372+
self,
373+
parent_context,
374+
trace_id,
375+
name,
376+
kind=None,
377+
attributes=None,
378+
links=None,
379+
):
380+
captured_attributes.update(attributes or {})
381+
382+
return SamplingResult(Decision.RECORD_AND_SAMPLE, attributes)
383+
384+
def get_description(self):
385+
return "AttributeCapturingSampler"
386+
387+
sampler_provider = TracerProvider(sampler=AttributeCapturingSampler())
388+
sampler_provider.add_span_processor(
389+
SimpleSpanProcessor(self.span_exporter)
390+
)
391+
392+
handler = TelemetryHandler(tracer_provider=sampler_provider)
393+
394+
invocation = handler.start_inference(
395+
"test-provider",
396+
request_model="sampler-model",
397+
server_address="api.example.com",
398+
server_port=8080,
399+
)
400+
invocation.stop()
401+
402+
assert captured_attributes[GenAI.GEN_AI_OPERATION_NAME] == "chat"
403+
assert (
404+
captured_attributes[GenAI.GEN_AI_REQUEST_MODEL] == "sampler-model"
405+
)
406+
assert (
407+
captured_attributes[GenAI.GEN_AI_PROVIDER_NAME] == "test-provider"
408+
)
409+
assert (
410+
captured_attributes[server_attributes.SERVER_ADDRESS]
411+
== "api.example.com"
412+
)
413+
assert captured_attributes[server_attributes.SERVER_PORT] == 8080
414+
415+
def test_start_inference_sampler_can_drop_span_based_on_attributes(self):
416+
"""Verify that a sampler can reject spans based on attributes passed at creation time."""
417+
418+
class ModelRejectingSampler: # pylint: disable=no-self-use
419+
"""Drops spans whose gen_ai.request.model matches the reject list."""
420+
421+
def __init__(self, reject_models):
422+
self._reject_models = reject_models
423+
424+
def should_sample(
425+
self,
426+
parent_context,
427+
trace_id,
428+
name,
429+
kind=None,
430+
attributes=None,
431+
links=None,
432+
):
433+
model = (attributes or {}).get(GenAI.GEN_AI_REQUEST_MODEL)
434+
if model in self._reject_models:
435+
return SamplingResult(Decision.DROP)
436+
return SamplingResult(Decision.RECORD_AND_SAMPLE, attributes)
437+
438+
def get_description(self):
439+
return "ModelRejectingSampler"
440+
441+
sampler_provider = TracerProvider(
442+
sampler=ModelRejectingSampler(reject_models={"rejected-model"})
443+
)
444+
sampler_provider.add_span_processor(
445+
SimpleSpanProcessor(self.span_exporter)
446+
)
447+
448+
handler = TelemetryHandler(tracer_provider=sampler_provider)
449+
450+
# This invocation should be dropped
451+
invocation = handler.start_inference(
452+
"test-provider", request_model="rejected-model"
453+
)
454+
invocation.stop()
455+
456+
# This invocation should be recorded
457+
invocation = handler.start_inference(
458+
"test-provider", request_model="accepted-model"
459+
)
460+
invocation.stop()
461+
462+
spans = self.span_exporter.get_finished_spans()
463+
assert len(spans) == 1
464+
assert spans[0].name == "chat accepted-model"
465+
360466
def test_llm_span_finish_reasons_without_output_messages(self):
361467
invocation = self.telemetry_handler.start_inference(
362468
"test-provider", request_model="model-without-output"

0 commit comments

Comments
 (0)