diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/CHANGELOG.md b/instrumentation-genai/opentelemetry-instrumentation-langchain/CHANGELOG.md index 3db24e484e..4b05f909b0 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-langchain/CHANGELOG.md +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +- Enhanced the LangChain instrumentor with semconv-aligned model, agent, + workflow, tool, and retriever tracing plus richer event and W3C propagation + handling. ([#4389](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4389)) - Added span support for genAI langchain llm invocation. ([#3665](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3665)) - Added support to call genai utils handler for langchain LLM invocations. diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/__init__.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/__init__.py index acb9a9bf7d..635d23083b 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/__init__.py +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/__init__.py @@ -45,8 +45,13 @@ from opentelemetry.instrumentation.langchain.callback_handler import ( OpenTelemetryLangChainCallbackHandler, ) +from opentelemetry.instrumentation.langchain.event_emitter import EventEmitter from opentelemetry.instrumentation.langchain.package import _instruments +from opentelemetry.instrumentation.langchain.span_manager import _SpanManager +from opentelemetry.instrumentation.langchain.version import __version__ from opentelemetry.instrumentation.utils import unwrap +from opentelemetry.semconv.schemas 
import Schemas +from opentelemetry.trace import get_tracer from opentelemetry.util.genai.handler import get_telemetry_handler @@ -78,8 +83,20 @@ def _instrument(self, **kwargs: Any): meter_provider=meter_provider, logger_provider=logger_provider, ) + + span_manager = _SpanManager( + tracer=get_tracer( + __name__, + __version__, + tracer_provider, + schema_url=Schemas.V1_37_0.value, + ), + ) + otel_callback_handler = OpenTelemetryLangChainCallbackHandler( telemetry_handler=telemetry_handler, + span_manager=span_manager, + event_emitter=EventEmitter(logger_provider=logger_provider), ) wrap_function_wrapper( diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/callback_handler.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/callback_handler.py index d694857da4..80d1059729 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/callback_handler.py +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/callback_handler.py @@ -14,16 +14,60 @@ from __future__ import annotations +import json +import logging +import threading +import timeit from typing import Any, Optional, cast from uuid import UUID +from langchain_core.agents import AgentAction, AgentFinish from langchain_core.callbacks import BaseCallbackHandler from langchain_core.messages import BaseMessage from langchain_core.outputs import LLMResult +from opentelemetry.instrumentation.langchain.content_recording import ( + get_content_policy, + should_record_retriever_content, + should_record_tool_content, +) +from opentelemetry.instrumentation.langchain.event_emitter import EventEmitter from opentelemetry.instrumentation.langchain.invocation_manager import ( _InvocationManager, ) +from opentelemetry.instrumentation.langchain.message_formatting import ( + format_documents, 
+ prepare_messages, + serialize_tool_result, +) +from opentelemetry.instrumentation.langchain.operation_mapping import ( + OperationName, + classify_chain_run, + resolve_agent_name, +) +from opentelemetry.instrumentation.langchain.semconv_attributes import ( + GEN_AI_WORKFLOW_NAME, + METRIC_TIME_PER_OUTPUT_CHUNK, + METRIC_TIME_TO_FIRST_CHUNK, + OP_EXECUTE_TOOL, +) +from opentelemetry.instrumentation.langchain.span_manager import ( + _SpanManager, +) +from opentelemetry.instrumentation.langchain.utils import ( + extract_propagation_context, + infer_provider_name, + infer_server_address, + infer_server_port, + propagated_context, +) +from opentelemetry.metrics import get_meter +from opentelemetry.semconv._incubating.attributes import ( + gen_ai_attributes as GenAI, +) +from opentelemetry.semconv.attributes import server_attributes +from opentelemetry.trace import SpanKind +from opentelemetry.trace.status import Status, StatusCode from opentelemetry.util.genai.handler import TelemetryHandler from opentelemetry.util.genai.types import ( Error, @@ -34,56 +78,150 @@ Text, ) +logger = logging.getLogger(__name__) + + +def _as_dict(value: Any) -> Optional[dict[str, Any]]: + if isinstance(value, dict): + return cast(dict[str, Any], value) + return None + + +def _has_goto(output: Any) -> bool: + """Detect LangGraph ``Command(goto=...)`` patterns in chain/tool output. + + LangGraph ``Command`` objects have both ``.goto`` and ``.update`` + attributes. The output may be the object directly, wrapped in a + dict, or inside a list/tuple. 
+ """ + if output is None: + return False + + # Direct Command-like object + if hasattr(output, "goto") and hasattr(output, "update"): + return bool(getattr(output, "goto", None)) + + # Dict — check for "goto" key or Command-like values + output_dict = _as_dict(output) + if output_dict is not None: + if output_dict.get("goto"): + return True + for val in output_dict.values(): + if hasattr(val, "goto") and hasattr(val, "update"): + if getattr(val, "goto", None): + return True + + # List/tuple — check elements + if isinstance(output, (list, tuple)): + for item in cast(Any, output): + if hasattr(item, "goto") and hasattr(item, "update"): + if getattr(item, "goto", None): + return True + + return False + + +def _extract_chain_messages(data: Any) -> Any: + """Extract message content from chain inputs or outputs. + + LangChain stores messages under various keys depending on the + chain type. Returns the first non-None value found, or ``None``. + """ + data_dict = _as_dict(data) + if data_dict is None: + return data + + for key in ( + "messages", + "input", + "output", + "question", + "query", + "result", + "answer", + "response", + ): + value = data_dict.get(key) + if value is not None: + return value + + return None + class OpenTelemetryLangChainCallbackHandler(BaseCallbackHandler): """ A callback handler for LangChain that uses OpenTelemetry to create spans for LLM calls and chains, tools etc,. in future. 
""" - def __init__(self, telemetry_handler: TelemetryHandler) -> None: + def __init__( + self, + telemetry_handler: TelemetryHandler, + span_manager: Optional[_SpanManager] = None, + event_emitter: Optional[EventEmitter] = None, + ) -> None: super().__init__() self._telemetry_handler = telemetry_handler + self._span_manager = span_manager + self._event_emitter = event_emitter self._invocation_manager = _InvocationManager() - def on_chat_model_start( + # Streaming state: str(run_id) → monotonic timestamp of the last chunk + self._streaming_state: dict[str, float] = {} + self._streaming_lock = threading.Lock() + + # The TelemetryHandler handles duration and token usage metrics. + # Streaming metrics are not yet in the shared handler, so we create + # them here using the same meter. + meter = get_meter(__name__) + self._ttfc_histogram = meter.create_histogram( + name=METRIC_TIME_TO_FIRST_CHUNK, + description="Time to generate first chunk in a streaming response", + unit="s", + ) + self._tpoc_histogram = meter.create_histogram( + name=METRIC_TIME_PER_OUTPUT_CHUNK, + description="Time between consecutive chunks in a streaming response", + unit="s", + ) + + def _handle_model_start( self, serialized: dict[str, Any], - messages: list[list[BaseMessage]], + input_messages: list[InputMessage], + operation_name: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> None: - # Other providers/LLMs may be supported in the future and telemetry for them is skipped for now. 
- if serialized.get("name") not in ("ChatOpenAI", "ChatBedrock"): - return - - if "invocation_params" in kwargs: + """Shared logic for on_chat_model_start and on_llm_start.""" + invocation_params = _as_dict(kwargs.get("invocation_params")) + if invocation_params is not None: params = ( - kwargs["invocation_params"].get("params") - or kwargs["invocation_params"] + _as_dict(invocation_params.get("params")) or invocation_params ) else: params = kwargs request_model = "unknown" for model_tag in ( - "model_name", # ChatOpenAI - "model_id", # ChatBedrock + "model_name", + "model_id", + "model", + "engine", + "deployment_name", ): - if (model := (params or {}).get(model_tag)) is not None: + if (model := params.get(model_tag)) is not None: request_model = model break - elif (model := (metadata or {}).get(model_tag)) is not None: + if ( + metadata is not None + and (model := metadata.get(model_tag)) is not None + ): request_model = model break - # Skip telemetry for unsupported request models - if request_model == "unknown": - return - # Initialize variables with default values to avoid "possibly unbound" errors top_p = None frequency_penalty = None @@ -93,25 +231,93 @@ def on_chat_model_start( temperature = None max_tokens = None - if params is not None: - top_p = params.get("top_p") - frequency_penalty = params.get("frequency_penalty") - presence_penalty = params.get("presence_penalty") - stop_sequences = params.get("stop") - seed = params.get("seed") - temperature = params.get("temperature") - max_tokens = params.get("max_completion_tokens") + top_p = params.get("top_p") + frequency_penalty = params.get("frequency_penalty") + presence_penalty = params.get("presence_penalty") + stop_sequences = params.get("stop") + seed = params.get("seed") + temperature = params.get("temperature") + max_tokens = params.get("max_completion_tokens") - provider = "unknown" - if metadata is not None: - provider = metadata.get("ls_provider", "unknown") + provider = 
infer_provider_name(serialized, metadata, invocation_params) + if provider is None: + provider = "unknown" - # Override with ChatBedrock values if present + if metadata is not None: + # Override with metadata values if present (e.g. ChatBedrock) if "ls_temperature" in metadata: temperature = metadata.get("ls_temperature") if "ls_max_tokens" in metadata: max_tokens = metadata.get("ls_max_tokens") + server_address = infer_server_address(serialized, invocation_params) + server_port = infer_server_port(serialized, invocation_params) + + # Additional semconv request attributes + extra_attrs: dict[str, Any] = {} + + top_k = params.get("top_k") + if top_k is not None: + extra_attrs[GenAI.GEN_AI_REQUEST_TOP_K] = top_k + + # Choice count (n) — only set if != 1 + choice_count = params.get("n") + if isinstance(choice_count, int) and choice_count != 1: + extra_attrs[GenAI.GEN_AI_REQUEST_CHOICE_COUNT] = choice_count + + # Output type from response_format + response_format = params.get("response_format") + response_format_dict = _as_dict(response_format) + if response_format_dict is not None: + output_type = response_format_dict.get("type") + if output_type is not None: + extra_attrs[GenAI.GEN_AI_OUTPUT_TYPE] = output_type + elif isinstance(response_format, str): + extra_attrs[GenAI.GEN_AI_OUTPUT_TYPE] = response_format + + # Encoding formats + encoding_format = params.get("encoding_format") + if encoding_format is not None: + extra_attrs[GenAI.GEN_AI_REQUEST_ENCODING_FORMATS] = [ + encoding_format + ] + + llm_invocation = LLMInvocation( + operation_name=operation_name, + request_model=request_model, + input_messages=input_messages, + provider=provider, + top_p=top_p, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + stop_sequences=stop_sequences, + seed=seed, + temperature=temperature, + max_tokens=max_tokens, + server_address=server_address, + server_port=server_port, + attributes=extra_attrs, + ) + llm_invocation = 
self._telemetry_handler.start_llm( + invocation=llm_invocation + ) + self._invocation_manager.add_invocation_state( + run_id=run_id, + parent_run_id=parent_run_id, + invocation=llm_invocation, + ) + + def on_chat_model_start( + self, + serialized: dict[str, Any], + messages: list[list[BaseMessage]], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + **kwargs: Any, + ) -> None: input_messages: list[InputMessage] = [] for sub_messages in messages: for message in sub_messages: @@ -140,27 +346,103 @@ def on_chat_model_start( ) ) - llm_invocation = LLMInvocation( - request_model=request_model, - input_messages=input_messages, - provider=provider, - top_p=top_p, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - stop_sequences=stop_sequences, - seed=seed, - temperature=temperature, - max_tokens=max_tokens, - ) - llm_invocation = self._telemetry_handler.start_llm( - invocation=llm_invocation + self._handle_model_start( + serialized, + input_messages, + "chat", + run_id=run_id, + parent_run_id=parent_run_id, + metadata=metadata, + **kwargs, ) - self._invocation_manager.add_invocation_state( + + def on_llm_start( + self, + serialized: dict[str, Any], + prompts: list[str], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + input_messages: list[InputMessage] = [ + InputMessage( + role="user", + parts=cast( + list[MessagePart], [Text(content=prompt, type="text")] + ), + ) + for prompt in prompts + ] + + self._handle_model_start( + serialized, + input_messages, + "text_completion", run_id=run_id, parent_run_id=parent_run_id, - invocation=llm_invocation, + metadata=metadata, + **kwargs, ) + def on_llm_new_token( + self, + token: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + invocation = 
self._invocation_manager.get_invocation(run_id) + if invocation is None or not isinstance(invocation, LLMInvocation): + return + + now = timeit.default_timer() + started_at = invocation.monotonic_start_s + if started_at is None: + return + + # Build metric attributes matching InvocationMetricsRecorder's pattern + metric_attrs: dict[str, Any] = {} + if invocation.operation_name: + metric_attrs[GenAI.GEN_AI_OPERATION_NAME] = ( + invocation.operation_name + ) + if invocation.request_model: + metric_attrs[GenAI.GEN_AI_REQUEST_MODEL] = invocation.request_model + if invocation.provider: + metric_attrs[GenAI.GEN_AI_PROVIDER_NAME] = invocation.provider + if invocation.response_model_name: + metric_attrs[GenAI.GEN_AI_RESPONSE_MODEL] = ( + invocation.response_model_name + ) + if invocation.server_address: + metric_attrs[server_attributes.SERVER_ADDRESS] = ( + invocation.server_address + ) + if invocation.server_port is not None: + metric_attrs[server_attributes.SERVER_PORT] = ( + invocation.server_port + ) + + run_key = str(run_id) + with self._streaming_lock: + last_chunk_at = self._streaming_state.get(run_key) + self._streaming_state[run_key] = now + + if last_chunk_at is None: + # First token — record time to first chunk + self._ttfc_histogram.record( + max(now - started_at, 0.0), attributes=metric_attrs + ) + else: + # Subsequent token — record time per output chunk + self._tpoc_histogram.record( + max(now - last_chunk_at, 0.0), attributes=metric_attrs + ) + def on_llm_end( self, response: LLMResult, @@ -184,22 +466,25 @@ def on_llm_end( generation_info = getattr( chat_generation, "generation_info", None ) - if generation_info is not None: - finish_reason = generation_info.get( + generation_info_dict = _as_dict(generation_info) + if generation_info_dict is not None: + finish_reason = generation_info_dict.get( "finish_reason", "unknown" ) if chat_generation.message: # Get finish reason if generation_info is None above if ( - generation_info is None + generation_info_dict 
is None and chat_generation.message.response_metadata ): - finish_reason = ( - chat_generation.message.response_metadata.get( + response_metadata = _as_dict( + chat_generation.message.response_metadata + ) + if response_metadata is not None: + finish_reason = response_metadata.get( "stopReason", "unknown" ) - ) # Get message content parts = [ @@ -217,24 +502,55 @@ def on_llm_end( output_messages.append(output_message) # Get token usage if available - if chat_generation.message.usage_metadata: - input_tokens = ( - chat_generation.message.usage_metadata.get( - "input_tokens", 0 - ) - ) + usage_metadata = _as_dict( + chat_generation.message.usage_metadata + ) + if usage_metadata is not None: + input_tokens = usage_metadata.get("input_tokens", 0) llm_invocation.input_tokens = input_tokens - output_tokens = ( - chat_generation.message.usage_metadata.get( - "output_tokens", 0 + output_tokens = usage_metadata.get("output_tokens", 0) + llm_invocation.output_tokens = output_tokens + + # Cache token attributes (when provider exposes them) + # Check direct keys (Anthropic-style) and input_token_details (LangChain-style) + cache_read = usage_metadata.get( + "cache_read_input_tokens" + ) + if cache_read is None: + input_token_details = _as_dict( + usage_metadata.get("input_token_details") ) + if input_token_details is not None: + cache_read = input_token_details.get( + "cache_read" + ) + if cache_read is not None and llm_invocation.span: + llm_invocation.span.set_attribute( + "gen_ai.usage.cache_read.input_tokens", + int(cache_read), + ) + + cache_creation = usage_metadata.get( + "cache_creation_input_tokens" ) - llm_invocation.output_tokens = output_tokens + if cache_creation is None: + input_token_details = _as_dict( + usage_metadata.get("input_token_details") + ) + if input_token_details is not None: + cache_creation = input_token_details.get( + "cache_creation" + ) + if cache_creation is not None and llm_invocation.span: + llm_invocation.span.set_attribute( + 
"gen_ai.usage.cache_creation.input_tokens", + int(cache_creation), + ) llm_invocation.output_messages = output_messages - llm_output = getattr(response, "llm_output", None) + llm_output = _as_dict(getattr(response, "llm_output", None)) if llm_output is not None: response_model = llm_output.get("model_name") or llm_output.get( "model" @@ -246,9 +562,39 @@ def on_llm_end( if response_id is not None: llm_invocation.response_id = str(response_id) + # OpenAI-specific response attributes + if llm_output is not None: + system_fingerprint = llm_output.get("system_fingerprint") + if system_fingerprint: + if llm_invocation.span: + llm_invocation.span.set_attribute( + "openai.response.system_fingerprint", + str(system_fingerprint), + ) + + service_tier = llm_output.get("service_tier") + if service_tier: + if llm_invocation.span: + llm_invocation.span.set_attribute( + "openai.response.service_tier", + str(service_tier), + ) + + with self._streaming_lock: + self._streaming_state.pop(str(run_id), None) + llm_invocation = self._telemetry_handler.stop_llm( invocation=llm_invocation ) + + # Propagate token usage to the nearest ancestor agent span. 
+ if self._span_manager and parent_run_id: + self._span_manager.accumulate_llm_usage_to_agent( + parent_run_id, + llm_invocation.input_tokens, + llm_invocation.output_tokens, + ) + if llm_invocation.span and not llm_invocation.span.is_recording(): self._invocation_manager.delete_invocation_state(run_id=run_id) @@ -268,8 +614,537 @@ def on_llm_error( return error_otel = Error(message=str(error), type=type(error)) + with self._streaming_lock: + self._streaming_state.pop(str(run_id), None) + llm_invocation = self._telemetry_handler.fail_llm( invocation=llm_invocation, error=error_otel ) if llm_invocation.span and not llm_invocation.span.is_recording(): self._invocation_manager.delete_invocation_state(run_id=run_id) + + # ------------------------------------------------------------------ + # Chain callbacks (agent / workflow spans) + # ------------------------------------------------------------------ + + def on_chain_start( + self, + serialized: dict[str, Any], + inputs: dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + if self._span_manager is None: + return + + operation = classify_chain_run( + serialized, metadata, kwargs, parent_run_id + ) + if operation is None: + self._span_manager.ignore_run(run_id, parent_run_id) + return + + attributes: dict[str, Any] = { + GenAI.GEN_AI_OPERATION_NAME: operation, + } + span_name = operation + + if operation == OperationName.INVOKE_AGENT: + agent_name = resolve_agent_name(serialized, metadata, kwargs) + if agent_name: + attributes[GenAI.GEN_AI_AGENT_NAME] = agent_name + span_name = f"{operation} {agent_name}" + + if metadata: + agent_id = metadata.get("agent_id") + if agent_id: + attributes[GenAI.GEN_AI_AGENT_ID] = str(agent_id) + + agent_desc = metadata.get("agent_description") + if agent_desc: + attributes[GenAI.GEN_AI_AGENT_DESCRIPTION] = str( + agent_desc + ) + + for key in 
("thread_id", "session_id", "conversation_id"): + conv_id = metadata.get(key) + if conv_id: + attributes[GenAI.GEN_AI_CONVERSATION_ID] = str(conv_id) + break + + provider = infer_provider_name(serialized, metadata, None) + if provider: + attributes[GenAI.GEN_AI_PROVIDER_NAME] = provider + elif metadata: + provider_name = metadata.get("provider_name") + if provider_name: + attributes[GenAI.GEN_AI_PROVIDER_NAME] = str(provider_name) + + elif operation == OperationName.INVOKE_WORKFLOW: + workflow_name = kwargs.get("name") or serialized.get("name") + if workflow_name: + attributes[GEN_AI_WORKFLOW_NAME] = str(workflow_name) + span_name = f"{operation} {workflow_name}" + + # Content recording (opt-in) + policy = get_content_policy() + formatted_input_messages = None + system_instructions = None + if policy.record_content: + raw_messages = _extract_chain_messages(inputs) + if raw_messages: + formatted_input_messages, system_instructions = ( + prepare_messages( + raw_messages, + record_content=True, + ) + ) + if policy.should_record_content_on_spans: + if formatted_input_messages: + attributes[GenAI.GEN_AI_INPUT_MESSAGES] = ( + formatted_input_messages + ) + if system_instructions: + attributes[GenAI.GEN_AI_SYSTEM_INSTRUCTIONS] = ( + system_instructions + ) + + headers = extract_propagation_context(metadata, inputs, kwargs) + + # Prefer the LangGraph logical thread_id from metadata; fall back + # to the OS thread identifier for non-LangGraph chains. + thread_key = ( + str(metadata["thread_id"]) + if metadata and metadata.get("thread_id") + else str(threading.get_ident()) + ) + + # For agent nodes, check for a goto-parent override produced by a + # preceding LangGraph Command(goto=...) transition. 
+ effective_parent = parent_run_id + if operation == OperationName.INVOKE_AGENT: + goto_parent = self._span_manager.pop_goto_parent(thread_key) + if goto_parent: + effective_parent = goto_parent + + with propagated_context(headers): + record = self._span_manager.start_span( + run_id=run_id, + name=span_name, + operation=operation, + parent_run_id=effective_parent, + attributes=attributes, + thread_key=thread_key, + ) + + if ( + self._event_emitter is not None + and operation == OperationName.INVOKE_AGENT + ): + self._event_emitter.emit_agent_start_event( + record.span, + attributes.get(GenAI.GEN_AI_AGENT_NAME, span_name), + formatted_input_messages, + ) + + def on_chain_end( + self, + outputs: dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + if self._span_manager is None: + return + + if self._span_manager.is_ignored(run_id): + self._span_manager.clear_ignored_run(run_id) + return + + record = self._span_manager.get_record(run_id) + if record is None: + return + + policy = get_content_policy() + formatted_output_messages = None + if policy.record_content: + raw_messages = _extract_chain_messages(outputs) + if raw_messages: + formatted_output_messages, _ = prepare_messages( + raw_messages, + record_content=True, + ) + if policy.should_record_content_on_spans and formatted_output_messages: + record.span.set_attribute( + GenAI.GEN_AI_OUTPUT_MESSAGES, formatted_output_messages + ) + record.attributes[GenAI.GEN_AI_OUTPUT_MESSAGES] = ( + formatted_output_messages + ) + + if ( + self._event_emitter is not None + and record.operation == OperationName.INVOKE_AGENT + ): + self._event_emitter.emit_agent_end_event( + record.span, + record.attributes.get( + GenAI.GEN_AI_AGENT_NAME, record.operation + ), + formatted_output_messages, + ) + + # Detect LangGraph Command(goto=...) — the goto target should + # become a child of this node's nearest agent ancestor. 
+ if _has_goto(outputs): + thread_key = record.stash.get("thread_key") + if thread_key: + agent_parent = self._span_manager.nearest_agent_parent(record) + if agent_parent: + self._span_manager.push_goto_parent( + thread_key, agent_parent + ) + + self._span_manager.end_span(run_id, status=StatusCode.OK) + + def on_chain_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + if self._span_manager is None: + return + + if self._span_manager.is_ignored(run_id): + self._span_manager.clear_ignored_run(run_id) + return + + self._span_manager.end_span(run_id, error=error) + + def on_agent_action( + self, + action: AgentAction, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + if self._span_manager and parent_run_id: + parent_key = self._span_manager.resolve_parent_id(parent_run_id) + record = ( + self._span_manager.get_record(parent_key) + if parent_key + else None + ) + if record: + tool_input: Any = getattr(action, "tool_input", None) + record.stash.setdefault("pending_actions", {})[str(run_id)] = { + "tool": action.tool, + "tool_input": tool_input, + "log": action.log, + } + + def on_agent_finish( + self, + finish: AgentFinish, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + if not self._span_manager: + return + record = self._span_manager.get_record(str(run_id)) + if not record: + return + return_values: Any = getattr(finish, "return_values", None) + if return_values: + record.span.set_attribute( + GenAI.GEN_AI_OUTPUT_MESSAGES, + json.dumps(return_values, default=str), + ) + record.span.set_status(Status(StatusCode.OK)) + self._span_manager.end_span(run_id) + + # ------------------------------------------------------------------ + # Tool callbacks + # ------------------------------------------------------------------ + + def on_tool_start( + self, + serialized: dict[str, Any], + input_str: str, + *, + 
run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + inputs: Optional[dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + if self._span_manager is None: + return + + metadata = metadata or {} + inputs = inputs or {} + + # Resolve tool name from multiple sources + tool_name = ( + serialized.get("name") + or metadata.get("tool_name") + or kwargs.get("name") + or "unknown_tool" + ) + + span_name = f"{OP_EXECUTE_TOOL} {tool_name}" + + # Build initial attributes + attributes: dict[str, Any] = { + GenAI.GEN_AI_OPERATION_NAME: OP_EXECUTE_TOOL, + GenAI.GEN_AI_TOOL_NAME: tool_name, + } + + # Tool description + description = serialized.get("description") + if description: + attributes[GenAI.GEN_AI_TOOL_DESCRIPTION] = str(description) + + # Tool type: from serialized or infer from definition + tool_type = serialized.get("type") + if tool_type: + attributes[GenAI.GEN_AI_TOOL_TYPE] = str(tool_type) + + # Tool call ID + tool_call_id = inputs.get("tool_call_id") or metadata.get( + "tool_call_id" + ) + if tool_call_id: + attributes[GenAI.GEN_AI_TOOL_CALL_ID] = str(tool_call_id) + + # Inherit provider from parent span if available + resolved_parent_id = self._span_manager.resolve_parent_id( + parent_run_id + ) + if resolved_parent_id is not None: + parent_record = self._span_manager.get_record(resolved_parent_id) + if parent_record is not None: + provider = parent_record.attributes.get( + GenAI.GEN_AI_PROVIDER_NAME + ) + if provider: + attributes[GenAI.GEN_AI_PROVIDER_NAME] = provider + + # Tool call arguments (opt-in content) + policy = get_content_policy() + arguments = None + if policy.record_content: + if inputs: + arg_data = { + k: v for k, v in inputs.items() if k != "tool_call_id" + } + if arg_data: + arguments = json.dumps(arg_data, default=str) + if not arguments and input_str: + arguments = input_str + if should_record_tool_content(policy) and arguments: + 
attributes[GenAI.GEN_AI_TOOL_CALL_ARGUMENTS] = arguments + + # Thread key for agent stack tracking + thread_key = metadata.get("thread_id") + + record = self._span_manager.start_span( + run_id=str(run_id), + name=span_name, + operation=OP_EXECUTE_TOOL, + parent_run_id=str(parent_run_id) if parent_run_id else None, + attributes=attributes, + thread_key=str(thread_key) if thread_key else None, + ) + + if self._event_emitter is not None: + self._event_emitter.emit_tool_call_event( + record.span, + tool_name, + arguments, + str(tool_call_id) if tool_call_id else None, + ) + + def on_tool_end( + self, + output: Any, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + if self._span_manager is None: + return + + record = self._span_manager.get_record(str(run_id)) + if record is None: + return + + # Record tool result (opt-in content) + policy = get_content_policy() + result_str = None + if policy.record_content and output is not None: + result_str = serialize_tool_result(output, record_content=True) + if should_record_tool_content(policy) and result_str is not None: + record.span.set_attribute( + GenAI.GEN_AI_TOOL_CALL_RESULT, result_str + ) + record.attributes[GenAI.GEN_AI_TOOL_CALL_RESULT] = result_str + + if self._event_emitter is not None: + self._event_emitter.emit_tool_result_event( + record.span, + record.attributes.get(GenAI.GEN_AI_TOOL_NAME, "tool"), + result_str, + record.attributes.get(GenAI.GEN_AI_TOOL_CALL_ID), + ) + + # Detect LangGraph Command(goto=...) from tool output. 
+ if _has_goto(output): + thread_key = record.stash.get("thread_key") + if thread_key: + agent_parent = self._span_manager.nearest_agent_parent(record) + if agent_parent: + self._span_manager.push_goto_parent( + thread_key, agent_parent + ) + + self._span_manager.end_span(run_id=str(run_id), status=StatusCode.OK) + + def on_tool_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + if self._span_manager is None: + return + + self._span_manager.end_span(run_id=str(run_id), error=error) + + # ------------------------------------------------------------------ + # Retriever callbacks + # ------------------------------------------------------------------ + + def on_retriever_start( + self, + serialized: dict[str, Any], + query: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + if self._span_manager is None: + return + + metadata = metadata or {} + + tool_name = serialized.get("name", "retriever") + span_name = f"{OP_EXECUTE_TOOL} {tool_name}" + + attributes: dict[str, Any] = { + GenAI.GEN_AI_OPERATION_NAME: OP_EXECUTE_TOOL, + GenAI.GEN_AI_TOOL_NAME: tool_name, + GenAI.GEN_AI_TOOL_DESCRIPTION: serialized.get( + "description", "retriever" + ), + GenAI.GEN_AI_TOOL_TYPE: "retriever", + } + + # Inherit provider from parent span if available + resolved_parent_id = self._span_manager.resolve_parent_id( + parent_run_id + ) + if resolved_parent_id is not None: + parent_record = self._span_manager.get_record(resolved_parent_id) + if parent_record is not None: + provider = parent_record.attributes.get( + GenAI.GEN_AI_PROVIDER_NAME + ) + if provider: + attributes[GenAI.GEN_AI_PROVIDER_NAME] = provider + + # Query text (opt-in content) + policy = get_content_policy() + if should_record_retriever_content(policy): + attributes["gen_ai.retrieval.query.text"] = query + + thread_key = 
metadata.get("thread_id") + + record = self._span_manager.start_span( + run_id=str(run_id), + name=span_name, + operation=OP_EXECUTE_TOOL, + kind=SpanKind.INTERNAL, + parent_run_id=str(parent_run_id) if parent_run_id else None, + attributes=attributes, + thread_key=str(thread_key) if thread_key else None, + ) + + if self._event_emitter is not None: + self._event_emitter.emit_retriever_query_event( + record.span, + tool_name, + query if policy.record_content else None, + ) + + def on_retriever_end( + self, + documents: Any, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + if self._span_manager is None: + return + + record = self._span_manager.get_record(str(run_id)) + if record is None: + return + + policy = get_content_policy() + record_content = should_record_retriever_content(policy) + formatted = format_documents(documents, record_content=record_content) + if formatted is not None: + record.span.set_attribute("gen_ai.retrieval.documents", formatted) + record.attributes["gen_ai.retrieval.documents"] = formatted + + if self._event_emitter is not None: + self._event_emitter.emit_retriever_result_event( + record.span, + record.attributes.get(GenAI.GEN_AI_TOOL_NAME, "retriever"), + formatted, + ) + + self._span_manager.end_span(run_id=str(run_id), status=StatusCode.OK) + + def on_retriever_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + if self._span_manager is None: + return + + self._span_manager.end_span(run_id=str(run_id), error=error) diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/content_recording.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/content_recording.py new file mode 100644 index 0000000000..caf86b4ee3 --- /dev/null +++ 
b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/content_recording.py @@ -0,0 +1,109 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Thin integration layer over the shared genai content capture utilities. + +Provides clear APIs for the LangChain callback handler to decide what +content should be recorded on spans and events. +""" + +from typing import Optional + +from opentelemetry.util.genai.types import ContentCapturingMode +from opentelemetry.util.genai.utils import ( + get_content_capturing_mode, + is_experimental_mode, + should_emit_event, +) + + +class ContentPolicy: + """Determines what content should be recorded on spans and events. + + Wraps the shared genai utility functions to provide a clean API + for the callback handler. All properties are evaluated lazily so + that environment variable changes are picked up immediately. 
+ """ + + @property + def should_record_content_on_spans(self) -> bool: + """Whether message/tool content should be recorded as span attributes.""" + return self.mode in ( + ContentCapturingMode.SPAN_ONLY, + ContentCapturingMode.SPAN_AND_EVENT, + ) + + @property + def should_emit_events(self) -> bool: + """Whether content events should be emitted.""" + return should_emit_event() + + @property + def record_content(self) -> bool: + """Whether content should be recorded at all (spans or events).""" + return self.should_record_content_on_spans or self.should_emit_events + + @property + def mode(self) -> ContentCapturingMode: + """The current content capturing mode. + + Returns ``NO_CONTENT`` when not running in experimental mode. + """ + if not is_experimental_mode(): + return ContentCapturingMode.NO_CONTENT + return get_content_capturing_mode() + + +# -- Helper functions for specific content types ------------------------------ +# All opt-in content types follow the same underlying policy today. Separate +# helpers are provided so call-sites read clearly and so that per-type +# overrides can be added later without changing every caller. 
+ + +def should_record_messages(policy: ContentPolicy) -> bool: + """Whether input/output messages should be recorded on spans.""" + return policy.should_record_content_on_spans + + +def should_record_tool_content(policy: ContentPolicy) -> bool: + """Whether tool arguments and results should be recorded on spans.""" + return policy.should_record_content_on_spans + + +def should_record_retriever_content(policy: ContentPolicy) -> bool: + """Whether retriever queries and document content should be recorded.""" + return policy.should_record_content_on_spans + + +def should_record_system_instructions(policy: ContentPolicy) -> bool: + """Whether system instructions should be recorded on spans.""" + return policy.should_record_content_on_spans + + +# -- Default singleton -------------------------------------------------------- + +_default_policy: Optional[ContentPolicy] = None + + +def get_content_policy() -> ContentPolicy: + """Get the content policy based on current environment configuration. + + Returns a module-level singleton. Because the policy reads + environment variables lazily on every property access, a single + instance is sufficient. + """ + global _default_policy # noqa: PLW0603 + if _default_policy is None: + _default_policy = ContentPolicy() + return _default_policy diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/event_emitter.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/event_emitter.py new file mode 100644 index 0000000000..c4b931d08f --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/event_emitter.py @@ -0,0 +1,208 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Event emission for non-LLM GenAI operations in LangChain. + +Emits semantic-convention-aligned log-record events for tool, agent, and +retriever spans. LLM event emission is handled by the shared +``TelemetryHandler`` and is **not** duplicated here. + +All event emission is gated behind the content policy so that events are +only produced when the user opts in via the +``OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT`` / +``OTEL_INSTRUMENTATION_GENAI_EMIT_EVENT`` environment variables. +""" + +from __future__ import annotations + +from typing import Any, Optional + +from opentelemetry._logs import Logger, LoggerProvider, LogRecord, get_logger +from opentelemetry.context import get_current +from opentelemetry.instrumentation.langchain.content_recording import ( + get_content_policy, +) +from opentelemetry.instrumentation.langchain.version import __version__ +from opentelemetry.semconv.schemas import Schemas +from opentelemetry.trace import Span +from opentelemetry.trace.propagation import set_span_in_context + +_REDACTED = "[redacted]" + + +class EventEmitter: + """Emits GenAI semantic convention events for LangChain operations. + + Events are emitted as ``LogRecord`` instances linked to the active span + context, following the same pattern used by the OpenAI v2 instrumentor + and the shared ``_maybe_emit_llm_event`` helper in ``span_utils``. 
+ """ + + def __init__( + self, logger_provider: Optional[LoggerProvider] = None + ) -> None: + self._logger: Logger = get_logger( + __name__, + __version__, + logger_provider, + schema_url=Schemas.V1_37_0.value, + ) + + # ------------------------------------------------------------------ + # Tool events + # ------------------------------------------------------------------ + + def emit_tool_call_event( + self, + span: Span, + tool_name: str, + arguments: Optional[str] = None, + tool_call_id: Optional[str] = None, + ) -> None: + """Emit a ``gen_ai.tool.call`` event when a tool is invoked.""" + if not self._should_emit(): + return + + policy = get_content_policy() + body: dict[str, Any] = {"name": tool_name} + if tool_call_id: + body["id"] = tool_call_id + if arguments is not None: + body["arguments"] = ( + arguments if policy.record_content else _REDACTED + ) + + self._emit(span, "gen_ai.tool.call", body) + + def emit_tool_result_event( + self, + span: Span, + tool_name: str, + result: Optional[str] = None, + tool_call_id: Optional[str] = None, + ) -> None: + """Emit a ``gen_ai.tool.result`` event when a tool returns.""" + if not self._should_emit(): + return + + policy = get_content_policy() + body: dict[str, Any] = {"name": tool_name} + if tool_call_id: + body["id"] = tool_call_id + if result is not None: + body["result"] = result if policy.record_content else _REDACTED + + self._emit(span, "gen_ai.tool.result", body) + + # ------------------------------------------------------------------ + # Agent events + # ------------------------------------------------------------------ + + def emit_agent_start_event( + self, + span: Span, + agent_name: str, + input_messages: Optional[str] = None, + ) -> None: + """Emit a ``gen_ai.agent.start`` event when an agent begins.""" + if not self._should_emit(): + return + + policy = get_content_policy() + body: dict[str, Any] = {"name": agent_name} + if input_messages is not None: + body["input"] = ( + input_messages if 
policy.record_content else _REDACTED + ) + + self._emit(span, "gen_ai.agent.start", body) + + def emit_agent_end_event( + self, + span: Span, + agent_name: str, + output_messages: Optional[str] = None, + ) -> None: + """Emit a ``gen_ai.agent.end`` event when an agent completes.""" + if not self._should_emit(): + return + + policy = get_content_policy() + body: dict[str, Any] = {"name": agent_name} + if output_messages is not None: + body["output"] = ( + output_messages if policy.record_content else _REDACTED + ) + + self._emit(span, "gen_ai.agent.end", body) + + # ------------------------------------------------------------------ + # Retriever events + # ------------------------------------------------------------------ + + def emit_retriever_query_event( + self, + span: Span, + retriever_name: str, + query: Optional[str] = None, + ) -> None: + """Emit a ``gen_ai.retriever.query`` event for a retriever query.""" + if not self._should_emit(): + return + + policy = get_content_policy() + body: dict[str, Any] = {"name": retriever_name} + if query is not None: + body["query"] = query if policy.record_content else _REDACTED + + self._emit(span, "gen_ai.retriever.query", body) + + def emit_retriever_result_event( + self, + span: Span, + retriever_name: str, + documents: Optional[str] = None, + ) -> None: + """Emit a ``gen_ai.retriever.result`` event with retrieved docs.""" + if not self._should_emit(): + return + + policy = get_content_policy() + body: dict[str, Any] = {"name": retriever_name} + if documents is not None: + body["documents"] = ( + documents if policy.record_content else _REDACTED + ) + + self._emit(span, "gen_ai.retriever.result", body) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _should_emit() -> bool: + """Check whether event emission is enabled via content policy.""" + return 
get_content_policy().should_emit_events + + def _emit(self, span: Span, event_name: str, body: dict[str, Any]) -> None: + """Create a ``LogRecord`` linked to *span* and emit it.""" + context = set_span_in_context(span, get_current()) + self._logger.emit( + LogRecord( + event_name=event_name, + body=body, + context=context, + ) + ) diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/message_formatting.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/message_formatting.py new file mode 100644 index 0000000000..30918c51ab --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/message_formatting.py @@ -0,0 +1,541 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Message, tool, and document serialization with content-redaction support. + +Converts LangChain message objects into the compact JSON format expected by +OpenTelemetry GenAI semantic convention span attributes +(``gen_ai.input_messages``, ``gen_ai.output_messages``, +``gen_ai.system_instructions``, ``gen_ai.tool_definitions``, etc.). 
+ +Redaction behaviour +------------------- +When *record_content* is ``False``: + +* Text content → ``"[redacted]"`` +* Tool call arguments → ``"[redacted]"`` +* Tool call results → ``"[redacted]"`` +* Document page content → omitted (only metadata is kept) +* System instruction content → ``"[redacted]"`` +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, cast + +from opentelemetry.util.genai.utils import gen_ai_json_dumps + +logger = logging.getLogger(__name__) + +_REDACTED = "[redacted]" + + +def _as_dict(value: Any) -> Optional[Dict[str, Any]]: + if isinstance(value, dict): + return cast(Dict[str, Any], value) + return None + + +def _as_sequence(value: Any) -> Optional[Sequence[Any]]: + if isinstance(value, (list, tuple)): + return cast(Sequence[Any], value) + return None + + +# --------------------------------------------------------------------------- +# Role mapping +# --------------------------------------------------------------------------- + +# LangChain message type → OpenTelemetry GenAI role +_ROLE_MAP: Dict[str, str] = { + "human": "user", + "HumanMessage": "user", + "ai": "assistant", + "AIMessage": "assistant", + "AIMessageChunk": "assistant", + "system": "system", + "SystemMessage": "system", + "tool": "tool", + "ToolMessage": "tool", + "function": "tool", + "FunctionMessage": "tool", + "chat": "user", + "ChatMessage": "user", +} + + +def message_role(message: Any) -> str: + """Map a LangChain message to its GenAI role. + + Handles ``BaseMessage`` subclasses (via ``.type``), plain dicts + (via ``"role"`` or ``"type"`` keys), and falls back to ``"user"``. 
+ """ + # BaseMessage subclass + msg_type = getattr(message, "type", None) + if isinstance(msg_type, str): + mapped = _ROLE_MAP.get(msg_type) + if mapped is not None: + return mapped + + # Dict-like message + message_dict = _as_dict(message) + if message_dict is not None: + for key in ("role", "type"): + value = message_dict.get(key) + if isinstance(value, str): + mapped = _ROLE_MAP.get(value) + if mapped is not None: + return mapped + # If the value itself is already a canonical role, accept it + if value in ("user", "assistant", "system", "tool"): + return value + + # Class-name fallback + cls_name = type(message).__name__ + mapped = _ROLE_MAP.get(cls_name) + if mapped is not None: + return mapped + + return "user" + + +# --------------------------------------------------------------------------- +# Content extraction +# --------------------------------------------------------------------------- + + +def message_content(message: Any) -> Optional[str]: + """Extract text content from a LangChain message. + + Returns ``None`` when no text content is available. Multi-part content + lists are concatenated with newlines. 
+ """ + raw: Any = getattr(message, "content", None) + message_dict = _as_dict(message) + if raw is None and message_dict is not None: + raw = message_dict.get("content") + + if raw is None: + return None + + if isinstance(raw, str): + return raw if raw else None + + # Multi-part content (list of strings / dicts with "text" key) + raw_parts = _as_sequence(raw) + if raw_parts is not None: + parts: list[str] = [] + for item in raw_parts: + if isinstance(item, str): + parts.append(item) + else: + item_dict = _as_dict(item) + if item_dict is None: + continue + text_value = item_dict.get("text") + if isinstance(text_value, str) and text_value: + parts.append(text_value) + return "\n".join(parts) if parts else None + + return str(raw) if raw else None + + +# --------------------------------------------------------------------------- +# Tool-call extraction +# --------------------------------------------------------------------------- + + +def extract_tool_calls(message: Any) -> List[Dict[str, Any]]: + """Extract tool calls from an ``AIMessage`` or dict. + + Returns a (possibly empty) list of dicts, each with keys + ``"id"``, ``"name"``, and ``"arguments"``. 
+ """ + tool_calls: Any = getattr(message, "tool_calls", None) + message_dict = _as_dict(message) + if tool_calls is None and message_dict is not None: + tool_calls = message_dict.get("tool_calls") + + tool_call_items = _as_sequence(tool_calls) + if not tool_call_items: + return [] + + result: List[Dict[str, Any]] = [] + for tc in tool_call_items: + entry: Dict[str, Any] = {} + + tc_dict = _as_dict(tc) + if tc_dict is not None: + entry["id"] = tc_dict.get("id") or "" + entry["name"] = tc_dict.get("name") or "" + entry["arguments"] = tc_dict.get("args") or tc_dict.get( + "arguments" + ) + else: + entry["id"] = getattr(tc, "id", "") or "" + entry["name"] = getattr(tc, "name", "") or "" + entry["arguments"] = getattr(tc, "args", None) or getattr( + tc, "arguments", None + ) + + result.append(entry) + return result + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _format_tool_call_part( + tc: Dict[str, Any], record_content: bool +) -> Dict[str, Any]: + """Build a serialised tool-call part dict.""" + part: Dict[str, Any] = {"type": "tool_call"} + if tc.get("id"): + part["id"] = tc["id"] + if tc.get("name"): + part["name"] = tc["name"] + + args = tc.get("arguments") + if record_content: + if args is not None: + part["arguments"] = args + else: + part["arguments"] = _REDACTED + + return part + + +def _format_tool_response_part( + message: Any, record_content: bool +) -> Dict[str, Any]: + """Build a serialised tool-call-response part dict.""" + part: Dict[str, Any] = {"type": "tool_call_response"} + + tool_call_id = getattr(message, "tool_call_id", None) + message_dict = _as_dict(message) + if tool_call_id is None and message_dict is not None: + tool_call_id = message_dict.get("tool_call_id") + if tool_call_id: + part["id"] = tool_call_id + + if record_content: + content = message_content(message) + if content is not None: + 
part["result"] = content + else: + part["result"] = _REDACTED + + return part + + +def _format_text_parts( + message: Any, record_content: bool +) -> List[Dict[str, Any]]: + """Build text-content part dicts for a message.""" + content = message_content(message) + if content is None: + return [] + + return [ + { + "type": "text", + "content": content if record_content else _REDACTED, + } + ] + + +def _format_single_message( + message: Any, record_content: bool +) -> Dict[str, Any]: + """Serialise one LangChain message into the GenAI convention dict.""" + role = message_role(message) + parts: List[Dict[str, Any]] = [] + + if role == "assistant": + # Tool calls first, then text + for tc in extract_tool_calls(message): + parts.append(_format_tool_call_part(tc, record_content)) + parts.extend(_format_text_parts(message, record_content)) + + elif role == "tool": + parts.append(_format_tool_response_part(message, record_content)) + + else: + # user, system, or any other role + parts.extend(_format_text_parts(message, record_content)) + + result: Dict[str, Any] = {"role": role} + if parts: + result["parts"] = parts + return result + + +def _flatten_messages(raw_messages: Any) -> List[Any]: + """Accept messages in multiple shapes and return a flat list. + + LangChain callbacks may pass ``list[list[BaseMessage]]`` (grouped by + prompt) or a simple ``list[BaseMessage]``. 
+ """ + if not raw_messages: + return [] + + raw_sequence = _as_sequence(raw_messages) + if raw_sequence is None: + return [raw_messages] + + # Check for nested lists (list[list[BaseMessage]]) + flat: list[Any] = [] + for item in raw_sequence: + nested_items = _as_sequence(item) + if nested_items is not None: + flat.extend(nested_items) + else: + flat.append(item) + return flat + + +# --------------------------------------------------------------------------- +# Public API – prepare_messages +# --------------------------------------------------------------------------- + + +def prepare_messages( + raw_messages: Any, + *, + record_content: bool, + include_roles: Optional[Set[str]] = None, +) -> Tuple[Optional[str], Optional[str]]: + """Serialise LangChain messages to JSON strings for span attributes. + + Returns ``(formatted_json, system_instructions_json)``: + + * *formatted_json* – JSON array of non-system messages, suitable for + ``gen_ai.input_messages`` / ``gen_ai.output_messages``. + * *system_instructions_json* – JSON array of system-message *parts* + only, suitable for ``gen_ai.system_instructions``. + + Either value may be ``None`` when no messages of that kind exist. + + Parameters + ---------- + raw_messages: + Messages as received from LangChain callbacks. May be a flat list or + a nested ``list[list[BaseMessage]]``. + record_content: + When ``False``, text payloads and tool arguments/results are replaced + with ``"[redacted]"``. + include_roles: + Optional filter. When provided, only messages whose mapped role is in + the set are included. 
+ """ + messages = _flatten_messages(raw_messages) + if not messages: + return None, None + + formatted: List[Dict[str, Any]] = [] + system_parts: List[Dict[str, Any]] = [] + + for msg in messages: + role = message_role(msg) + + if include_roles is not None and role not in include_roles: + continue + + if role == "system": + # System messages contribute to system_instructions only + content = message_content(msg) + if content is not None: + system_parts.append( + { + "type": "text", + "content": content if record_content else _REDACTED, + } + ) + continue + + formatted.append(_format_single_message(msg, record_content)) + + formatted_json = gen_ai_json_dumps(formatted) if formatted else None + system_json = gen_ai_json_dumps(system_parts) if system_parts else None + + return formatted_json, system_json + + +# --------------------------------------------------------------------------- +# Document formatting (for retrievers) +# --------------------------------------------------------------------------- + + +def format_documents( + documents: Optional[Sequence[Any]], *, record_content: bool +) -> Optional[str]: + """Format retrieved documents as a JSON string for span attributes. + + Each document is serialised as a dict with optional ``page_content`` + (when *record_content* is ``True``) and ``metadata`` fields. + + Returns ``None`` when *documents* is empty or ``None``. 
+ """ + if not documents: + return None + + result: List[Dict[str, Any]] = [] + for doc in documents: + entry: Dict[str, Any] = {} + doc_dict = _as_dict(doc) + + # page_content + page_content = getattr(doc, "page_content", None) + if page_content is None and doc_dict is not None: + page_content = doc_dict.get("page_content") + + if record_content and page_content is not None: + entry["page_content"] = str(page_content) + + # metadata + metadata = getattr(doc, "metadata", None) + if metadata is None and doc_dict is not None: + metadata = doc_dict.get("metadata") + if metadata: + entry["metadata"] = metadata + + if entry: + result.append(entry) + + return gen_ai_json_dumps(result) if result else None + + +# --------------------------------------------------------------------------- +# Tool result serialization +# --------------------------------------------------------------------------- + + +def serialize_tool_result(output: Any, record_content: bool) -> str: + """Serialise a tool result for span attributes. + + When *record_content* is ``False`` the literal ``"[redacted]"`` is + returned. 
+ """ + if not record_content: + return _REDACTED + + if isinstance(output, str): + return output + + # Try common attribute shapes produced by LangChain tools + content = getattr(output, "content", None) + if content is not None: + return str(content) + + output_dict = _as_dict(output) + if output_dict is not None: + content = output_dict.get("content") or output_dict.get("output") + if content is not None: + return str(content) + + # Fallback: JSON-encode arbitrary values + try: + return gen_ai_json_dumps(output) + except (TypeError, ValueError): + return str(output) + + +# --------------------------------------------------------------------------- +# Tool definitions formatting +# --------------------------------------------------------------------------- + + +def format_tool_definitions(definitions: Optional[Any]) -> Optional[str]: + """Format tool definitions for ``gen_ai.tool_definitions`` span attribute. + + Accepts a list of LangChain tool objects, dicts, or any mix thereof and + returns a compact JSON string. Returns ``None`` when *definitions* is + empty or ``None``. 
+ """ + if not definitions: + return None + + definition_items = _as_sequence(definitions) + if definition_items is None: + definition_items = [definitions] + + result: List[Dict[str, Any]] = [] + for defn in definition_items: + entry: Dict[str, Any] = {} + defn_dict = _as_dict(defn) + + if defn_dict is not None: + # Already a dict – keep recognised keys + if "name" in defn_dict: + entry["name"] = defn_dict["name"] + if "description" in defn_dict: + entry["description"] = defn_dict["description"] + if "parameters" in defn_dict: + entry["parameters"] = defn_dict["parameters"] + + func_dict = _as_dict(defn_dict.get("function")) + if func_dict is not None: + func_name = func_dict.get("name") + if "name" not in entry and func_name is not None: + entry["name"] = func_name + func_description = func_dict.get("description") + if "description" not in entry and func_description is not None: + entry["description"] = func_description + func_parameters = func_dict.get("parameters") + if "parameters" not in entry and func_parameters is not None: + entry["parameters"] = func_parameters + + entry.setdefault("type", defn_dict.get("type", "function")) + else: + # Object with attributes (e.g. 
a LangChain BaseTool) + name = getattr(defn, "name", None) + if name is not None: + entry["name"] = str(name) + + description = getattr(defn, "description", None) + if description is not None: + entry["description"] = str(description) + + args_schema = getattr(defn, "args_schema", None) + if args_schema is not None: + schema_method = getattr(args_schema, "schema", None) + if callable(schema_method): + try: + entry["parameters"] = schema_method() + except Exception: # noqa: BLE001 + pass + + entry.setdefault("type", "function") + + if entry: + result.append(entry) + + return gen_ai_json_dumps(result) if result else None + + +# --------------------------------------------------------------------------- +# JSON helper +# --------------------------------------------------------------------------- + + +def as_json_attribute(value: Any) -> str: + """Return a JSON string suitable for OpenTelemetry string attributes. + + Uses the same compact encoder (no whitespace, base64 for bytes) as + the rest of the GenAI instrumentation. + """ + return gen_ai_json_dumps(value) diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/operation_mapping.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/operation_mapping.py new file mode 100644 index 0000000000..6402a5a93e --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/operation_mapping.py @@ -0,0 +1,256 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Callback-to-semconv operation mapping for LangChain callbacks. + +Maps each LangChain callback to the correct GenAI semantic convention +operation name. Direct callbacks (``on_chat_model_start``, +``on_llm_start``, ``on_tool_start``, ``on_retriever_start``) have a +fixed 1-to-1 mapping. ``on_chain_start`` requires heuristic +classification because LangChain emits this callback for agents, +workflows, and internal plumbing alike. +""" + +from __future__ import annotations + +from typing import Any, Optional +from uuid import UUID + +from opentelemetry.semconv._incubating.attributes import ( + gen_ai_attributes as GenAI, +) + +__all__ = [ + "OperationName", + "classify_chain_run", + "resolve_agent_name", + "should_ignore_chain", +] + +# --------------------------------------------------------------------------- +# Operation name constants (sourced from the GenAI semconv enum where +# available, with string fallbacks for values not yet in the enum). +# --------------------------------------------------------------------------- + + +class OperationName: + """Canonical GenAI semantic convention operation names.""" + + CHAT: str = GenAI.GenAiOperationNameValues.CHAT.value + TEXT_COMPLETION: str = GenAI.GenAiOperationNameValues.TEXT_COMPLETION.value + INVOKE_AGENT: str = GenAI.GenAiOperationNameValues.INVOKE_AGENT.value + EXECUTE_TOOL: str = GenAI.GenAiOperationNameValues.EXECUTE_TOOL.value + # invoke_workflow is not yet in the semconv enum; use the expected + # string value so the mapping is forward-compatible. 
+ INVOKE_WORKFLOW: str = "invoke_workflow" + + +# --------------------------------------------------------------------------- +# LangGraph markers – names and prefixes produced by LangGraph that must +# be recognized when classifying ``on_chain_start`` callbacks. +# --------------------------------------------------------------------------- + +LANGGRAPH_NODE_KEY = "langgraph_node" +LANGGRAPH_START_NODE = "__start__" +MIDDLEWARE_PREFIX = "Middleware." +LANGGRAPH_IDENTIFIER = "LangGraph" + +# Metadata keys used by callers to override classification. +_META_AGENT_SPAN = "otel_agent_span" +_META_WORKFLOW_SPAN = "otel_workflow_span" +_META_AGENT_NAME = "agent_name" +_META_AGENT_TYPE = "agent_type" +_META_OTEL_TRACE = "otel_trace" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def resolve_agent_name( + serialized: dict[str, Any], + metadata: Optional[dict[str, Any]], + kwargs: dict[str, Any], +) -> Optional[str]: + """Derive the best-effort agent name from callback arguments. + + Checks (in priority order): + 1. ``metadata["agent_name"]`` + 2. ``kwargs["name"]`` + 3. ``serialized["name"]`` + 4. 
``metadata["langgraph_node"]`` (if present and not a start node) + """ + if metadata: + name = metadata.get(_META_AGENT_NAME) + if name: + return str(name) + + name = kwargs.get("name") + if name: + return str(name) + + name = serialized.get("name") + if name: + return str(name) + + if metadata: + node = metadata.get(LANGGRAPH_NODE_KEY) + if node and node != LANGGRAPH_START_NODE: + return str(node) + + return None + + +def _has_agent_signals(metadata: Optional[dict[str, Any]]) -> bool: + """Return True when metadata contains any signal that the chain is an agent.""" + if not metadata: + return False + return bool( + metadata.get(_META_AGENT_SPAN) + or metadata.get(_META_AGENT_NAME) + or metadata.get(_META_AGENT_TYPE) + ) + + +def _is_langgraph_agent_node( + serialized: dict[str, Any], + metadata: Optional[dict[str, Any]], + kwargs: dict[str, Any], +) -> bool: + """Detect a LangGraph agent node that is not a start/middleware node.""" + if not metadata: + return False + + node = metadata.get(LANGGRAPH_NODE_KEY) + if not node: + return False + + # Exclude start and middleware nodes. + if node == LANGGRAPH_START_NODE: + return False + + name = resolve_agent_name(serialized, metadata, kwargs) + if name and name.startswith(MIDDLEWARE_PREFIX): + return False + + return True + + +def _looks_like_workflow( + serialized: dict[str, Any], + metadata: Optional[dict[str, Any]], + parent_run_id: Optional[UUID], +) -> bool: + """Return True if the chain looks like a top-level workflow/graph.""" + if parent_run_id is not None: + return False + + # An explicit workflow override is authoritative. + if metadata and metadata.get(_META_WORKFLOW_SPAN): + return True + + # Heuristic: check for LangGraph identifier in the serialized repr. 
+ name = serialized.get("name", "") + graph_id = ( + serialized.get("graph", {}).get("id", "") + if isinstance(serialized.get("graph"), dict) + else "" + ) + return LANGGRAPH_IDENTIFIER in name or LANGGRAPH_IDENTIFIER in graph_id + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def should_ignore_chain( + metadata: Optional[dict[str, Any]], + agent_name: Optional[str], + parent_run_id: Optional[UUID], + kwargs: dict[str, Any], +) -> bool: + """Return True if the chain callback should be silently suppressed. + + Suppression happens when: + * The node is the LangGraph ``__start__`` node. + * The name carries the ``Middleware.`` prefix. + * ``metadata["otel_trace"]`` is explicitly ``False``. + * ``metadata["otel_agent_span"]`` is explicitly ``False`` and no other + agent signals are present. + """ + if metadata: + node = metadata.get(LANGGRAPH_NODE_KEY) + if node == LANGGRAPH_START_NODE: + return True + + if metadata.get(_META_OTEL_TRACE) is False: + return True + + if ( + metadata.get(_META_AGENT_SPAN) is False + and not metadata.get(_META_AGENT_NAME) + and not metadata.get(_META_AGENT_TYPE) + ): + return True + + if agent_name and agent_name.startswith(MIDDLEWARE_PREFIX): + return True + + name_from_kwargs = kwargs.get("name", "") + if isinstance(name_from_kwargs, str) and name_from_kwargs.startswith( + MIDDLEWARE_PREFIX + ): + return True + + return False + + +def classify_chain_run( + serialized: dict[str, Any], + metadata: Optional[dict[str, Any]], + kwargs: dict[str, Any], + parent_run_id: Optional[UUID] = None, +) -> Optional[str]: + """Classify a ``on_chain_start`` callback into a semconv operation. + + Returns one of the :class:`OperationName` constants, or ``None`` when + the chain should be suppressed (no span emitted). + + Classification order: + 1. Check for explicit suppression signals. + 2. 
Check for agent signals → ``invoke_agent``. + 3. Check for workflow signals → ``invoke_workflow``. + 4. Default: ``None`` (suppress – unclassified chains are not emitted). + """ + agent_name = resolve_agent_name(serialized, metadata, kwargs) + + # 1. Suppress known noise. + if should_ignore_chain(metadata, agent_name, parent_run_id, kwargs): + return None + + # 2. Agent detection. + if _has_agent_signals(metadata): + return OperationName.INVOKE_AGENT + + if _is_langgraph_agent_node(serialized, metadata, kwargs): + return OperationName.INVOKE_AGENT + + # 3. Workflow / orchestration detection. + if _looks_like_workflow(serialized, metadata, parent_run_id): + return OperationName.INVOKE_WORKFLOW + + # 4. Default: suppress unclassified chains. + return None diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/semconv_attributes.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/semconv_attributes.py new file mode 100644 index 0000000000..a5e5249eeb --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/semconv_attributes.py @@ -0,0 +1,313 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Per-operation attribute matrix based on OTel GenAI semantic conventions. 
+ +Single source of truth for which attributes apply to which operations +in the LangChain instrumentor. Attribute requirement levels follow: +https://opentelemetry.io/docs/specs/semconv/gen-ai/ +""" + +from __future__ import annotations + +from opentelemetry.semconv._incubating.attributes import ( + gen_ai_attributes as GenAI, +) +from opentelemetry.semconv._incubating.attributes import ( + server_attributes as Server, +) +from opentelemetry.semconv.attributes import ( + error_attributes as Error, +) +from opentelemetry.trace import SpanKind + +# --------------------------------------------------------------------------- +# Operation name constants +# --------------------------------------------------------------------------- + +OP_CHAT = GenAI.GenAiOperationNameValues.CHAT.value # "chat" +OP_TEXT_COMPLETION = ( + GenAI.GenAiOperationNameValues.TEXT_COMPLETION.value +) # "text_completion" +OP_INVOKE_AGENT = ( + GenAI.GenAiOperationNameValues.INVOKE_AGENT.value +) # "invoke_agent" +OP_EXECUTE_TOOL = ( + GenAI.GenAiOperationNameValues.EXECUTE_TOOL.value +) # "execute_tool" + +# These operations are not yet in the semconv enum; define as literals. +OP_INVOKE_WORKFLOW = "invoke_workflow" +OP_RETRIEVAL = "retrieval" + +# --------------------------------------------------------------------------- +# Attribute key aliases (not yet in the released semconv package) +# --------------------------------------------------------------------------- + +GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS = "gen_ai.usage.cache_read.input_tokens" +GEN_AI_USAGE_CACHE_CREATION_INPUT_TOKENS = ( + "gen_ai.usage.cache_creation.input_tokens" +) +GEN_AI_AGENT_VERSION = "gen_ai.agent.version" +GEN_AI_WORKFLOW_NAME = "gen_ai.workflow.name" + +# --------------------------------------------------------------------------- +# Attribute sets per operation, grouped by requirement level +# +# Requirement levels (per OpenTelemetry specification): +# REQUIRED – MUST be provided. 
+# CONDITIONALLY_REQ – MUST be provided when the stated condition is met. +# RECOMMENDED – SHOULD be provided. +# OPT_IN – MAY be provided; typically gated by a config flag. +# --------------------------------------------------------------------------- + +# ---- chat / text_completion (inference client spans) ---------------------- + +INFERENCE_REQUIRED: frozenset[str] = frozenset( + { + GenAI.GEN_AI_OPERATION_NAME, + GenAI.GEN_AI_PROVIDER_NAME, + } +) + +INFERENCE_CONDITIONALLY_REQUIRED: frozenset[str] = frozenset( + { + GenAI.GEN_AI_REQUEST_MODEL, # if available + Error.ERROR_TYPE, # if response is an error + Server.SERVER_PORT, # if server.address is set + GenAI.GEN_AI_REQUEST_SEED, # if present in request + GenAI.GEN_AI_REQUEST_CHOICE_COUNT, # if != 1 + GenAI.GEN_AI_OUTPUT_TYPE, # if applicable + GenAI.GEN_AI_CONVERSATION_ID, # if available + } +) + +INFERENCE_RECOMMENDED: frozenset[str] = frozenset( + { + GenAI.GEN_AI_REQUEST_MAX_TOKENS, + GenAI.GEN_AI_REQUEST_TEMPERATURE, + GenAI.GEN_AI_REQUEST_TOP_P, + GenAI.GEN_AI_REQUEST_TOP_K, + GenAI.GEN_AI_REQUEST_STOP_SEQUENCES, + GenAI.GEN_AI_REQUEST_FREQUENCY_PENALTY, + GenAI.GEN_AI_REQUEST_PRESENCE_PENALTY, + GenAI.GEN_AI_RESPONSE_ID, + GenAI.GEN_AI_RESPONSE_MODEL, + GenAI.GEN_AI_RESPONSE_FINISH_REASONS, + GenAI.GEN_AI_USAGE_INPUT_TOKENS, + GenAI.GEN_AI_USAGE_OUTPUT_TOKENS, + GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS, + GEN_AI_USAGE_CACHE_CREATION_INPUT_TOKENS, + Server.SERVER_ADDRESS, + } +) + +INFERENCE_OPT_IN: frozenset[str] = frozenset( + { + GenAI.GEN_AI_SYSTEM_INSTRUCTIONS, + GenAI.GEN_AI_INPUT_MESSAGES, + GenAI.GEN_AI_OUTPUT_MESSAGES, + GenAI.GEN_AI_TOOL_DEFINITIONS, + } +) + +# ---- invoke_agent -------------------------------------------------------- + +AGENT_REQUIRED: frozenset[str] = frozenset( + { + GenAI.GEN_AI_OPERATION_NAME, + GenAI.GEN_AI_PROVIDER_NAME, + } +) + +AGENT_CONDITIONALLY_REQUIRED: frozenset[str] = frozenset( + { + GenAI.GEN_AI_AGENT_ID, + GenAI.GEN_AI_AGENT_NAME, + 
GenAI.GEN_AI_AGENT_DESCRIPTION, + GEN_AI_AGENT_VERSION, + GenAI.GEN_AI_REQUEST_MODEL, + GenAI.GEN_AI_DATA_SOURCE_ID, + Error.ERROR_TYPE, # if response is an error + GenAI.GEN_AI_CONVERSATION_ID, + } +) + +AGENT_RECOMMENDED: frozenset[str] = frozenset( + { + Server.SERVER_ADDRESS, + # All inference request/response attributes are also recommended + GenAI.GEN_AI_REQUEST_MAX_TOKENS, + GenAI.GEN_AI_REQUEST_TEMPERATURE, + GenAI.GEN_AI_REQUEST_TOP_P, + GenAI.GEN_AI_REQUEST_TOP_K, + GenAI.GEN_AI_REQUEST_STOP_SEQUENCES, + GenAI.GEN_AI_REQUEST_FREQUENCY_PENALTY, + GenAI.GEN_AI_REQUEST_PRESENCE_PENALTY, + GenAI.GEN_AI_RESPONSE_ID, + GenAI.GEN_AI_RESPONSE_MODEL, + GenAI.GEN_AI_RESPONSE_FINISH_REASONS, + GenAI.GEN_AI_USAGE_INPUT_TOKENS, + GenAI.GEN_AI_USAGE_OUTPUT_TOKENS, + GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS, + GEN_AI_USAGE_CACHE_CREATION_INPUT_TOKENS, + } +) + +AGENT_OPT_IN: frozenset[str] = frozenset( + { + GenAI.GEN_AI_SYSTEM_INSTRUCTIONS, + GenAI.GEN_AI_INPUT_MESSAGES, + GenAI.GEN_AI_OUTPUT_MESSAGES, + } +) + +# ---- execute_tool -------------------------------------------------------- + +TOOL_REQUIRED: frozenset[str] = frozenset( + { + GenAI.GEN_AI_OPERATION_NAME, + } +) + +TOOL_CONDITIONALLY_REQUIRED: frozenset[str] = frozenset( + { + Error.ERROR_TYPE, # if response is an error + } +) + +TOOL_RECOMMENDED: frozenset[str] = frozenset( + { + GenAI.GEN_AI_TOOL_NAME, + GenAI.GEN_AI_TOOL_CALL_ID, + GenAI.GEN_AI_TOOL_DESCRIPTION, + GenAI.GEN_AI_TOOL_TYPE, + } +) + +TOOL_OPT_IN: frozenset[str] = frozenset( + { + GenAI.GEN_AI_TOOL_CALL_ARGUMENTS, + GenAI.GEN_AI_TOOL_CALL_RESULT, + } +) + +# ---- invoke_workflow ----------------------------------------------------- + +WORKFLOW_REQUIRED: frozenset[str] = frozenset( + { + GenAI.GEN_AI_OPERATION_NAME, + } +) + +WORKFLOW_CONDITIONALLY_REQUIRED: frozenset[str] = frozenset( + { + Error.ERROR_TYPE, # if response is an error + GEN_AI_WORKFLOW_NAME, # if available + } +) + +WORKFLOW_RECOMMENDED: frozenset[str] = frozenset() + 
+WORKFLOW_OPT_IN: frozenset[str] = frozenset( + { + GenAI.GEN_AI_INPUT_MESSAGES, + GenAI.GEN_AI_OUTPUT_MESSAGES, + } +) + +# --------------------------------------------------------------------------- +# Aggregate lookup: operation → (required, conditionally_required, +# recommended, opt_in) +# --------------------------------------------------------------------------- + +OPERATION_ATTRIBUTES: dict[ + str, + tuple[ + frozenset[str], + frozenset[str], + frozenset[str], + frozenset[str], + ], +] = { + OP_CHAT: ( + INFERENCE_REQUIRED, + INFERENCE_CONDITIONALLY_REQUIRED, + INFERENCE_RECOMMENDED, + INFERENCE_OPT_IN, + ), + OP_TEXT_COMPLETION: ( + INFERENCE_REQUIRED, + INFERENCE_CONDITIONALLY_REQUIRED, + INFERENCE_RECOMMENDED, + INFERENCE_OPT_IN, + ), + OP_INVOKE_AGENT: ( + AGENT_REQUIRED, + AGENT_CONDITIONALLY_REQUIRED, + AGENT_RECOMMENDED, + AGENT_OPT_IN, + ), + OP_EXECUTE_TOOL: ( + TOOL_REQUIRED, + TOOL_CONDITIONALLY_REQUIRED, + TOOL_RECOMMENDED, + TOOL_OPT_IN, + ), + OP_INVOKE_WORKFLOW: ( + WORKFLOW_REQUIRED, + WORKFLOW_CONDITIONALLY_REQUIRED, + WORKFLOW_RECOMMENDED, + WORKFLOW_OPT_IN, + ), +} + +# --------------------------------------------------------------------------- +# SpanKind helper +# --------------------------------------------------------------------------- + +_CLIENT_OPERATIONS: frozenset[str] = frozenset( + {OP_CHAT, OP_TEXT_COMPLETION, OP_INVOKE_AGENT} +) + + +def get_operation_span_kind(operation: str) -> SpanKind: + """Return the correct SpanKind for the given operation. + + * ``chat``, ``text_completion``, ``invoke_agent`` → ``SpanKind.CLIENT`` + * ``execute_tool``, ``invoke_workflow``, and others → ``SpanKind.INTERNAL`` + """ + if operation in _CLIENT_OPERATIONS: + return SpanKind.CLIENT + return SpanKind.INTERNAL + + +# --------------------------------------------------------------------------- +# Metric applicability +# +# Maps metric instrument names to the set of operations they apply to. 
+# --------------------------------------------------------------------------- + +METRIC_OPERATION_DURATION = "gen_ai.client.operation.duration" +METRIC_TOKEN_USAGE = "gen_ai.client.token.usage" +METRIC_TIME_TO_FIRST_CHUNK = "gen_ai.client.operation.time_to_first_chunk" +METRIC_TIME_PER_OUTPUT_CHUNK = "gen_ai.client.operation.time_per_output_chunk" + +METRIC_APPLICABLE_OPERATIONS: dict[str, frozenset[str]] = { + METRIC_OPERATION_DURATION: frozenset({OP_CHAT, OP_TEXT_COMPLETION}), + METRIC_TOKEN_USAGE: frozenset({OP_CHAT, OP_TEXT_COMPLETION}), + # Streaming-only metrics (chat / text_completion) + METRIC_TIME_TO_FIRST_CHUNK: frozenset({OP_CHAT, OP_TEXT_COMPLETION}), + METRIC_TIME_PER_OUTPUT_CHUNK: frozenset({OP_CHAT, OP_TEXT_COMPLETION}), +} diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/span_manager.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/span_manager.py index 7ce588b618..a9c6758aa6 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/span_manager.py +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/span_manager.py @@ -12,106 +12,417 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Span lifecycle manager for the LangChain instrumentor. + +Manages creation, parent-context resolution, ignored-run walk-through, +per-thread agent stacks, and clean teardown for all GenAI operation types. 
+""" + +from __future__ import annotations + +import threading +import time from dataclasses import dataclass, field -from typing import Dict, List, Optional -from uuid import UUID +from typing import Any, Dict, List, Optional +from opentelemetry.instrumentation.langchain.semconv_attributes import ( + OP_CHAT, + OP_INVOKE_AGENT, + OP_TEXT_COMPLETION, + get_operation_span_kind, +) from opentelemetry.semconv._incubating.attributes import ( gen_ai_attributes as GenAI, ) -from opentelemetry.semconv.attributes import ( - error_attributes, -) +from opentelemetry.semconv.attributes import error_attributes from opentelemetry.trace import Span, SpanKind, Tracer, set_span_in_context from opentelemetry.trace.status import Status, StatusCode __all__ = ["_SpanManager"] +# Operations that produce model-level duration metrics. +_MODEL_OPERATIONS: frozenset[str] = frozenset({OP_CHAT, OP_TEXT_COMPLETION}) + + +def _empty_attributes() -> Dict[str, Any]: + return {} + @dataclass -class _SpanState: +class SpanRecord: + """Rich record stored for every active span.""" + + run_id: str span: Span - children: List[UUID] = field(default_factory=lambda: list()) + operation: str + parent_run_id: Optional[str] = None + attributes: Dict[str, Any] = field(default_factory=_empty_attributes) + # Mutable scratch space for streaming timing, thread keys, etc. + stash: Dict[str, Any] = field(default_factory=_empty_attributes) class _SpanManager: - def __init__( + """Thread-safe span lifecycle manager for every GenAI operation type.""" + + def __init__(self, tracer: Tracer) -> None: + self._tracer = tracer + self._lock = threading.Lock() + + # run_id (str) → SpanRecord + self._spans: Dict[str, SpanRecord] = {} + + # Runs we decided to skip (e.g. internal LangChain plumbing) but + # whose children should still be linked to the correct parent. + self._ignored_runs: set[str] = set() + # Maps an ignored run_id to the parent_run_id it was called with, + # so children can walk through to the real ancestor. 
+ self._run_parent_override: Dict[str, Optional[str]] = {} + + # Per-thread stacks of invoke_agent run_ids for hierarchy tracking + # in concurrent execution. key = thread_key (str). + self._agent_stack_by_thread: Dict[str, List[str]] = {} + + # Per-thread stacks for LangGraph Command(goto=...) transitions. + # key = thread_key (str), value = stack of parent_run_ids. + self._goto_parent_stack: Dict[str, List[str]] = {} + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def start_span( + self, + run_id: str | object, + name: str, + operation: str, + kind: Optional[SpanKind] = None, + parent_run_id: Optional[str | object] = None, + attributes: Optional[Dict[str, Any]] = None, + thread_key: Optional[str] = None, + ) -> SpanRecord: + """Create and register a new span. + + Parameters + ---------- + run_id: + Unique identifier for this run (UUID or str). + name: + Human-readable span name (e.g. ``"chat gpt-4o"``). + operation: + One of the ``OP_*`` constants from ``semconv_attributes``. + kind: + Override the SpanKind. When *None* the kind is derived from + *operation* via ``get_operation_span_kind``. + parent_run_id: + The run_id of the logical parent (may be an ignored run). + attributes: + Initial span attributes to set immediately. + thread_key: + Identifies the execution thread; used for agent stack tracking. + """ + rid = str(run_id) + prid = str(parent_run_id) if parent_run_id is not None else None + + if kind is None: + kind = get_operation_span_kind(operation) + + # Walk through ignored runs so children attach to the correct + # visible ancestor. + resolved_prid = self._resolve_parent_id(prid) + + # Build parent context. 
+ ctx = None + with self._lock: + if resolved_prid is not None: + parent_record = self._spans.get(resolved_prid) + if parent_record is not None: + ctx = set_span_in_context(parent_record.span) + + span = self._tracer.start_span(name=name, kind=kind, context=ctx) + + attrs = attributes or {} + for attr_key, attr_val in attrs.items(): + span.set_attribute(attr_key, attr_val) + + stash: Dict[str, Any] = {} + if operation in _MODEL_OPERATIONS: + stash["started_at"] = time.perf_counter() + if thread_key is not None: + stash["thread_key"] = thread_key + + record = SpanRecord( + run_id=rid, + span=span, + operation=operation, + parent_run_id=prid, + attributes=attrs, + stash=stash, + ) + + with self._lock: + self._spans[rid] = record + + # Maintain per-thread agent stack. + if operation == OP_INVOKE_AGENT and thread_key is not None: + self._agent_stack_by_thread.setdefault(thread_key, []).append( + rid + ) + + return record + + def end_span( self, - tracer: Tracer, + run_id: str | object, + status: Optional[StatusCode] = None, + error: Optional[BaseException] = None, ) -> None: - self._tracer = tracer + """Finalise and end the span identified by *run_id*. + + Parameters + ---------- + run_id: + The run whose span should be ended. + status: + Explicit status code. When *error* is provided this defaults to + ``StatusCode.ERROR``. + error: + If supplied the span is marked as failed with ``error.type`` + recorded as an attribute. + """ + rid = str(run_id) + + with self._lock: + record = self._spans.pop(rid, None) + if record is None: + return + + span = record.span - # Map from run_id -> _SpanState, to keep track of spans and parent/child relationships - # TODO: Use weak references or a TTL cache to avoid memory leaks in long-running processes. 
See #3735 - self.spans: Dict[UUID, _SpanState] = {} + if error is not None: + span.set_attribute( + error_attributes.ERROR_TYPE, type(error).__qualname__ + ) + span.set_status(Status(StatusCode.ERROR, str(error))) + elif status is not None: + span.set_status(Status(status)) + + # Pop from agent stack if applicable. + thread_key = record.stash.get("thread_key") + if record.operation == OP_INVOKE_AGENT and thread_key is not None: + with self._lock: + stack = self._agent_stack_by_thread.get(thread_key) + if stack: + try: + stack.remove(rid) + except ValueError: + pass + if not stack: + del self._agent_stack_by_thread[thread_key] + + span.end() + + def get_record(self, run_id: str | object) -> Optional[SpanRecord]: + """Return the ``SpanRecord`` for *run_id*, or ``None``.""" + rid = str(run_id) + with self._lock: + return self._spans.get(rid) + + # ------------------------------------------------------------------ + # Ignored-run management + # ------------------------------------------------------------------ + + def ignore_run( + self, + run_id: str | object, + parent_run_id: Optional[str | object] = None, + ) -> None: + """Mark *run_id* as ignored. + + Any future child whose ``parent_run_id`` points at an ignored run + will be re-parented to the ignored run's own parent via + ``resolve_parent_id``. 
+ """ + rid = str(run_id) + prid = str(parent_run_id) if parent_run_id is not None else None + with self._lock: + self._ignored_runs.add(rid) + self._run_parent_override[rid] = prid - def _create_span( + def is_ignored(self, run_id: str | object) -> bool: + rid = str(run_id) + with self._lock: + return rid in self._ignored_runs + + def clear_ignored_run(self, run_id: str | object) -> None: + """Remove ignored-run bookkeeping for *run_id*.""" + rid = str(run_id) + with self._lock: + self._ignored_runs.discard(rid) + self._run_parent_override.pop(rid, None) + + def resolve_parent_id( + self, parent_run_id: Optional[str | object] + ) -> Optional[str]: + """Public wrapper around the internal resolver.""" + prid = str(parent_run_id) if parent_run_id is not None else None + return self._resolve_parent_id(prid) + + # ------------------------------------------------------------------ + # Token usage accumulation + # ------------------------------------------------------------------ + + def _accumulate_on_record( self, - run_id: UUID, - parent_run_id: Optional[UUID], - span_name: str, - kind: SpanKind = SpanKind.INTERNAL, - ) -> Span: - if parent_run_id is not None and parent_run_id in self.spans: - parent_state = self.spans[parent_run_id] - parent_span = parent_state.span - ctx = set_span_in_context(parent_span) - span = self._tracer.start_span( - name=span_name, kind=kind, context=ctx + record: SpanRecord, + input_tokens: Optional[int], + output_tokens: Optional[int], + ) -> None: + """Add token counts to *record*. 
Caller **must** hold ``self._lock``.""" + if input_tokens is not None: + existing = record.attributes.get( + GenAI.GEN_AI_USAGE_INPUT_TOKENS, 0 ) - parent_state.children.append(run_id) - else: - # top-level or missing parent - span = self._tracer.start_span(name=span_name, kind=kind) - set_span_in_context(span) + new_val = (existing or 0) + input_tokens + record.span.set_attribute(GenAI.GEN_AI_USAGE_INPUT_TOKENS, new_val) + record.attributes[GenAI.GEN_AI_USAGE_INPUT_TOKENS] = new_val + if output_tokens is not None: + existing = record.attributes.get( + GenAI.GEN_AI_USAGE_OUTPUT_TOKENS, 0 + ) + new_val = (existing or 0) + output_tokens + record.span.set_attribute( + GenAI.GEN_AI_USAGE_OUTPUT_TOKENS, new_val + ) + record.attributes[GenAI.GEN_AI_USAGE_OUTPUT_TOKENS] = new_val - span_state = _SpanState(span=span) - self.spans[run_id] = span_state + def accumulate_usage_to_parent( + self, + record: SpanRecord, + input_tokens: Optional[int], + output_tokens: Optional[int], + ) -> None: + """Propagate token usage from a model span to its parent agent span.""" + if input_tokens is None and output_tokens is None: + return - return span + parent_key = record.parent_run_id + visited: set[str] = set() + with self._lock: + while parent_key: + if parent_key in visited: + break + visited.add(parent_key) + parent_record = self._spans.get(parent_key) + if not parent_record: + break + if parent_record.operation == OP_INVOKE_AGENT: + self._accumulate_on_record( + parent_record, input_tokens, output_tokens + ) + break + parent_key = parent_record.parent_run_id - def create_chat_span( + def accumulate_llm_usage_to_agent( self, - run_id: UUID, - parent_run_id: Optional[UUID], - request_model: str, - ) -> Span: - span = self._create_span( - run_id=run_id, - parent_run_id=parent_run_id, - span_name=f"{GenAI.GenAiOperationNameValues.CHAT.value} {request_model}", - kind=SpanKind.CLIENT, - ) - span.set_attribute( - GenAI.GEN_AI_OPERATION_NAME, - GenAI.GenAiOperationNameValues.CHAT.value, - ) 
- if request_model: - span.set_attribute(GenAI.GEN_AI_REQUEST_MODEL, request_model) - - return span - - def end_span(self, run_id: UUID) -> None: - state = self.spans[run_id] - for child_id in state.children: - child_state = self.spans.get(child_id) - if child_state: - child_state.span.end() - del self.spans[child_id] - state.span.end() - del self.spans[run_id] - - def get_span(self, run_id: UUID) -> Optional[Span]: - state = self.spans.get(run_id) - return state.span if state else None - - def handle_error(self, error: BaseException, run_id: UUID): - span = self.get_span(run_id) - if span is None: - # If the span does not exist, we cannot set the error status + parent_run_id: Optional[str | object], + input_tokens: Optional[int], + output_tokens: Optional[int], + ) -> None: + """Propagate LLM token usage up to the nearest agent span. + + Unlike ``accumulate_usage_to_parent`` (which starts from a + ``SpanRecord``'s parent), this resolves through ignored runs first + and then walks up to find the nearest ``invoke_agent`` ancestor. + Designed to be called from ``on_llm_end`` where the LLM span is + managed by :class:`TelemetryHandler`, not :class:`_SpanManager`. 
+ """ + if input_tokens is None and output_tokens is None: return - span.set_status(Status(StatusCode.ERROR, str(error))) - span.set_attribute( - error_attributes.ERROR_TYPE, type(error).__qualname__ - ) - self.end_span(run_id) + + prid = str(parent_run_id) if parent_run_id is not None else None + resolved = self._resolve_parent_id(prid) + if resolved is None: + return + + visited: set[str] = set() + current = resolved + with self._lock: + while current: + if current in visited: + break + visited.add(current) + record = self._spans.get(current) + if not record: + break + if record.operation == OP_INVOKE_AGENT: + self._accumulate_on_record( + record, input_tokens, output_tokens + ) + return + current = record.parent_run_id + + def nearest_agent_parent(self, record: SpanRecord) -> Optional[str]: + """Walk up the parent chain to find the nearest invoke_agent ancestor. + + Returns the run_id of the nearest agent span, or *None*. + """ + parent_key = record.parent_run_id + visited: set[str] = set() + with self._lock: + while parent_key: + if parent_key in visited: + break + visited.add(parent_key) + parent_record = self._spans.get(parent_key) + if not parent_record: + break + if parent_record.operation == OP_INVOKE_AGENT: + return parent_key + parent_key = parent_record.parent_run_id + return None + + # ------------------------------------------------------------------ + # LangGraph goto support + # ------------------------------------------------------------------ + + def push_goto_parent(self, thread_key: str, parent_run_id: str) -> None: + """Push a goto parent onto the per-thread stack.""" + with self._lock: + self._goto_parent_stack.setdefault(thread_key, []).append( + parent_run_id + ) + + def pop_goto_parent(self, thread_key: str) -> Optional[str]: + """Pop and return the most recent goto parent, or *None*.""" + with self._lock: + stack = self._goto_parent_stack.get(thread_key) + if stack: + val = stack.pop() + if not stack: + del 
self._goto_parent_stack[thread_key] + return val + return None + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _resolve_parent_id( + self, parent_run_id: Optional[str] + ) -> Optional[str]: + """Walk through ignored runs to find the nearest visible ancestor.""" + if parent_run_id is None: + return None + + visited: set[str] = set() + current = parent_run_id + with self._lock: + while current in self._ignored_runs: + if current in visited: + # Cycle guard. + return None + visited.add(current) + current = self._run_parent_override.get(current) + if current is None: + return None + return current diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/utils.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/utils.py new file mode 100644 index 0000000000..23345b38cd --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/utils.py @@ -0,0 +1,354 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import logging +from contextlib import contextmanager +from typing import Any, Dict, Iterator, List, Mapping, Optional, cast +from urllib.parse import urlparse + +from opentelemetry.context import attach, detach +from opentelemetry.propagate import extract + +# Provider name constants aligned with OpenTelemetry semantic conventions +_PROVIDER_AZURE_OPENAI = "azure.ai.openai" +_PROVIDER_OPENAI = "openai" +_PROVIDER_AWS_BEDROCK = "aws.bedrock" +_PROVIDER_GCP_GEN_AI = "gcp.gen_ai" +_PROVIDER_ANTHROPIC = "anthropic" +_PROVIDER_COHERE = "cohere" +_PROVIDER_OLLAMA = "ollama" + +# Mapping from LangChain ls_provider values to normalized provider names +_LS_PROVIDER_MAP: Dict[str, str] = { + "azure": _PROVIDER_AZURE_OPENAI, + "azure_openai": _PROVIDER_AZURE_OPENAI, + "azure-openai": _PROVIDER_AZURE_OPENAI, + "openai": _PROVIDER_OPENAI, + "github": _PROVIDER_AZURE_OPENAI, + "google": _PROVIDER_GCP_GEN_AI, + "google_genai": _PROVIDER_GCP_GEN_AI, + "anthropic": _PROVIDER_ANTHROPIC, + "cohere": _PROVIDER_COHERE, + "ollama": _PROVIDER_OLLAMA, +} + +# Substrings in base_url mapped to provider names (checked in order) +_URL_PROVIDER_RULES: List[tuple[str, str]] = [ + ("azure", _PROVIDER_AZURE_OPENAI), + ("openai", _PROVIDER_OPENAI), + ("ollama", _PROVIDER_OLLAMA), + ("bedrock", _PROVIDER_AWS_BEDROCK), + ("amazonaws.com", _PROVIDER_AWS_BEDROCK), + ("anthropic", _PROVIDER_ANTHROPIC), + ("googleapis", _PROVIDER_GCP_GEN_AI), +] + +# Substrings in serialized class identifiers mapped to provider names +_CLASS_PROVIDER_RULES: List[tuple[str, str]] = [ + ("ChatOpenAI", _PROVIDER_OPENAI), + ("ChatBedrock", _PROVIDER_AWS_BEDROCK), + ("Bedrock", _PROVIDER_AWS_BEDROCK), + ("ChatAnthropic", _PROVIDER_ANTHROPIC), + ("ChatGoogleGenerativeAI", _PROVIDER_GCP_GEN_AI), + ("ChatVertexAI", _PROVIDER_GCP_GEN_AI), + ("Ollama", _PROVIDER_OLLAMA), +] + + +def _as_dict(value: Any) -> Optional[Dict[str, Any]]: + if isinstance(value, dict): + return cast(Dict[str, 
Any], value) + return None + + +def _get_class_identifier(serialized: Dict[str, Any]) -> Optional[str]: + """Extract a class identifier string from serialized data. + + Checks ``serialized["id"]`` (a list of path components) first, + then falls back to ``serialized["name"]``. + """ + id_parts = serialized.get("id") + if isinstance(id_parts, list) and id_parts: + return str(cast(List[Any], id_parts)[-1]) + name = serialized.get("name") + if name: + return str(name) + return None + + +def _infer_from_ls_provider( + metadata: Optional[Dict[str, Any]], +) -> Optional[str]: + """Infer provider from LangChain's ls_provider metadata hint.""" + if metadata is None: + return None + ls_provider = metadata.get("ls_provider") + if ls_provider is None: + return None + + ls_lower = str(ls_provider).lower() + + # Direct map lookup + mapped = _LS_PROVIDER_MAP.get(ls_lower) + if mapped is not None: + return mapped + + # Substring check for bedrock variants (e.g. "amazon_bedrock") + if "bedrock" in ls_lower: + return _PROVIDER_AWS_BEDROCK + + return None + + +def _infer_from_url( + invocation_params: Optional[Dict[str, Any]], +) -> Optional[str]: + """Infer provider from a base URL in invocation params.""" + if invocation_params is None: + return None + base_url = invocation_params.get("base_url") or invocation_params.get( + "openai_api_base" + ) + if not base_url: + return None + + url_lower = str(base_url).lower() + for substring, provider in _URL_PROVIDER_RULES: + if substring in url_lower: + return provider + return None + + +def _infer_from_class( + serialized: Optional[Dict[str, Any]], +) -> Optional[str]: + """Infer provider from the serialized class name or id.""" + if serialized is None: + return None + class_id = _get_class_identifier(serialized) + if class_id is None: + return None + + for substring, provider in _CLASS_PROVIDER_RULES: + if substring in class_id: + return provider + return None + + +def _infer_from_kwargs( + serialized: Optional[Dict[str, Any]], +) -> 
Optional[str]: + """Infer provider from serialized kwargs (endpoint fields).""" + if serialized is None: + return None + ser_kwargs = _as_dict(serialized.get("kwargs")) + if ser_kwargs is None: + return None + + if ser_kwargs.get("azure_endpoint"): + return _PROVIDER_AZURE_OPENAI + + openai_api_base = ser_kwargs.get("openai_api_base") + if isinstance(openai_api_base, str) and openai_api_base.endswith( + ".azure.com" + ): + return _PROVIDER_AZURE_OPENAI + + return None + + +def infer_provider_name( + serialized: Optional[Dict[str, Any]], + metadata: Optional[Dict[str, Any]], + invocation_params: Optional[Dict[str, Any]], +) -> Optional[str]: + """Infer the GenAI provider name from available LangChain callback data. + + Sources are checked in decreasing order of specificity: + 1. ``metadata["ls_provider"]`` — LangChain's own provider hint + 2. ``invocation_params["base_url"]`` — URL-based inference + 3. ``serialized["id"]`` / ``serialized["name"]`` — class name based + 4. ``serialized["kwargs"]`` — endpoint-based + + Returns ``None`` if the provider cannot be determined. 
+ """ + return ( + _infer_from_ls_provider(metadata) + or _infer_from_url(invocation_params) + or _infer_from_class(serialized) + or _infer_from_kwargs(serialized) + ) + + +def _extract_url( + serialized: Optional[Dict[str, Any]], + invocation_params: Optional[Dict[str, Any]], +) -> Optional[str]: + """Find the first available URL from invocation params or serialized kwargs.""" + if invocation_params: + url = invocation_params.get("base_url") or invocation_params.get( + "openai_api_base" + ) + if url: + return str(url) + + if serialized: + ser_kwargs = _as_dict(serialized.get("kwargs")) + if ser_kwargs is not None: + url = ser_kwargs.get("openai_api_base") or ser_kwargs.get( + "azure_endpoint" + ) + if url: + return str(url) + + return None + + +def infer_server_address( + serialized: Optional[Dict[str, Any]], + invocation_params: Optional[Dict[str, Any]], +) -> Optional[str]: + """Extract the server hostname from available URL sources. + + Checks ``invocation_params["base_url"]``, + ``invocation_params["openai_api_base"]``, + ``serialized["kwargs"]["openai_api_base"]``, and + ``serialized["kwargs"]["azure_endpoint"]``. + """ + url = _extract_url(serialized, invocation_params) + if url is None: + return None + + parsed = urlparse(url) + return parsed.hostname or None + + +def infer_server_port( + serialized: Optional[Dict[str, Any]], + invocation_params: Optional[Dict[str, Any]], +) -> Optional[int]: + """Extract the server port from available URL sources. + + Only returns a value when the port is explicitly specified in the URL + (not inferred default ports). + """ + url = _extract_url(serialized, invocation_params) + if url is None: + return None + + parsed = urlparse(url) + return parsed.port # None when port is not explicitly set + + +_logger = logging.getLogger(__name__) + +# Header keys recognised by the W3C Trace Context specification. 
+_TRACE_HEADER_KEYS = ("traceparent", "tracestate") + +# Common nested attribute names where HTTP / trace headers may reside. +_NESTED_HEADER_KEYS = ( + "headers", + "header", + "http_headers", + "request_headers", + "metadata", + "request", +) + + +def extract_trace_headers(container: Any) -> Optional[Dict[str, str]]: + """Extract W3C trace context headers from a container. + + Looks for traceparent/tracestate at the top level and in common + nested locations (headers, metadata, request, etc.). + """ + container_dict = _as_dict(container) + if container_dict is None: + return None + + # 1. Check top-level keys + found: Dict[str, str] = {} + for key in _TRACE_HEADER_KEYS: + value = container_dict.get(key) + if isinstance(value, str) and value: + found[key] = value + + if found: + return found + + # 2. Check nested containers + for nested_key in _NESTED_HEADER_KEYS: + nested = _as_dict(container_dict.get(nested_key)) + if nested is not None: + for key in _TRACE_HEADER_KEYS: + value = nested.get(key) + if isinstance(value, str) and value: + found[key] = value + if found: + return found + + return None + + +@contextmanager +def propagated_context( + headers: Optional[Mapping[str, str]], +) -> Iterator[None]: + """Temporarily adopt an upstream trace context extracted from W3C headers. + + Uses OpenTelemetry's extract() to deserialize W3C trace context, + then attaches it for the duration of the context manager. 
+ """ + if not headers: + yield + return + + token = None + try: + ctx = extract(headers) + token = attach(ctx) + except Exception: # noqa: BLE001 + _logger.debug( + "Failed to extract/attach propagation context", exc_info=True + ) + + try: + yield + finally: + if token is not None: + try: + detach(token) + except Exception: # noqa: BLE001 + _logger.debug( + "Failed to detach propagation context", exc_info=True + ) + + +def extract_propagation_context( + metadata: Optional[Dict[str, Any]], + inputs: Any, + kwargs: Dict[str, Any], +) -> Optional[Dict[str, str]]: + """Try to extract W3C trace headers from callback arguments. + + Checks metadata, inputs, and kwargs in order. + """ + for source in (metadata, inputs, kwargs): + if source is not None: + headers = extract_trace_headers(source) + if headers: + return headers + return None diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/conftest.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/conftest.py index 7d608ea8b3..6b0598309e 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/conftest.py +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/conftest.py @@ -9,6 +9,7 @@ from langchain_aws import ChatBedrock from langchain_google_genai import ChatGoogleGenerativeAI from langchain_openai import ChatOpenAI +from vcr import VCR from opentelemetry.instrumentation.langchain import LangChainInstrumentor from opentelemetry.sdk._logs import LoggerProvider @@ -112,12 +113,24 @@ def fixture_meter_provider(metric_reader): @pytest.fixture(scope="function") def start_instrumentation( + request, + monkeypatch, tracer_provider, meter_provider, logger_provider, ): + if "capture_content" in request.fixturenames: + monkeypatch.setenv( + "OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental" + ) + monkeypatch.setenv( + "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT", + request.getfixturevalue("capture_content"), + ) + 
instrumentor = LangChainInstrumentor() instrumentor.instrument( + skip_dep_check=True, tracer_provider=tracer_provider, meter_provider=meter_provider, logger_provider=logger_provider, @@ -208,10 +221,26 @@ def deserialize(cassette_string): return yaml.load(cassette_string, Loader=yaml.Loader) -@pytest.fixture(scope="module", autouse=True) -def fixture_vcr(vcr): - vcr.register_serializer("yaml", PrettyPrintJSONBody) - return vcr +@pytest.fixture(scope="function", name="vcr") +def fixture_vcr(vcr_config, record_mode): + """Provide a configurable VCR object for tests that call use_cassette(). + + pytest-recording's built-in ``vcr`` fixture now yields an active cassette + instead of the older ``VCR`` object. LangChain's tests still use the + explicit ``with vcr.use_cassette(...)`` style, so this local fixture + restores that interface while keeping the same cassette directory and + record mode behavior. + """ + vcr_instance = VCR( + path_transformer=VCR.ensure_suffix(".yaml"), + cassette_library_dir=os.path.join( + os.path.dirname(__file__), "cassettes" + ), + record_mode=record_mode, + **vcr_config, + ) + vcr_instance.register_serializer("yaml", PrettyPrintJSONBody) + return vcr_instance def scrub_response_headers(response): diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_agent_lifecycle.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_agent_lifecycle.py new file mode 100644 index 0000000000..b8b0abe940 --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_agent_lifecycle.py @@ -0,0 +1,185 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from unittest.mock import MagicMock +from uuid import uuid4 + +from langchain_core.agents import AgentAction, AgentFinish + +from opentelemetry.instrumentation.langchain.callback_handler import ( + OpenTelemetryLangChainCallbackHandler, +) +from opentelemetry.instrumentation.langchain.span_manager import SpanRecord +from opentelemetry.trace.status import StatusCode + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + + +def _make_handler(span_manager=None): + telemetry_handler = MagicMock() + telemetry_handler.meter = MagicMock() + return OpenTelemetryLangChainCallbackHandler( + telemetry_handler=telemetry_handler, + span_manager=span_manager, + ) + + +def _make_record(run_id=None): + record = MagicMock(spec=SpanRecord) + record.run_id = str(run_id or uuid4()) + record.stash = {} + record.span = MagicMock() + return record + + +# ------------------------------------------------------------------ +# on_agent_action +# ------------------------------------------------------------------ + + +class TestOnAgentAction: + def test_stashes_action_on_parent_pending_actions(self): + run_id = uuid4() + parent_run_id = uuid4() + parent_key = str(parent_run_id) + record = _make_record(parent_run_id) + + span_manager = MagicMock() + span_manager.resolve_parent_id.return_value = parent_key + span_manager.get_record.return_value = record + + handler = _make_handler(span_manager) + + action = AgentAction( + tool="search", tool_input={"query": "weather"}, 
log="Searching…" + ) + handler.on_agent_action( + action, run_id=run_id, parent_run_id=parent_run_id + ) + + span_manager.resolve_parent_id.assert_called_once_with(parent_run_id) + span_manager.get_record.assert_called_once_with(parent_key) + + pending = record.stash["pending_actions"] + assert str(run_id) in pending + entry = pending[str(run_id)] + assert entry["tool"] == "search" + assert entry["tool_input"] == {"query": "weather"} + assert entry["log"] == "Searching…" + + def test_noop_when_span_manager_is_none(self): + handler = _make_handler(span_manager=None) + action = AgentAction(tool="t", tool_input="i", log="l") + # Should not raise + handler.on_agent_action(action, run_id=uuid4(), parent_run_id=uuid4()) + + def test_noop_when_parent_run_id_is_none(self): + span_manager = MagicMock() + handler = _make_handler(span_manager) + + action = AgentAction(tool="t", tool_input="i", log="l") + handler.on_agent_action(action, run_id=uuid4(), parent_run_id=None) + + span_manager.resolve_parent_id.assert_not_called() + span_manager.get_record.assert_not_called() + + def test_noop_when_parent_record_not_found(self): + run_id = uuid4() + parent_run_id = uuid4() + parent_key = str(parent_run_id) + + span_manager = MagicMock() + span_manager.resolve_parent_id.return_value = parent_key + span_manager.get_record.return_value = None + + handler = _make_handler(span_manager) + + action = AgentAction(tool="t", tool_input="i", log="l") + handler.on_agent_action( + action, run_id=run_id, parent_run_id=parent_run_id + ) + + # resolve_parent_id was called, but nothing was stashed + span_manager.resolve_parent_id.assert_called_once_with(parent_run_id) + span_manager.get_record.assert_called_once_with(parent_key) + + +# ------------------------------------------------------------------ +# on_agent_finish +# ------------------------------------------------------------------ + + +class TestOnAgentFinish: + def test_sets_output_messages_and_ok_status(self): + run_id = uuid4() + record 
= _make_record(run_id) + + span_manager = MagicMock() + span_manager.get_record.return_value = record + + handler = _make_handler(span_manager) + + return_values = {"output": "The weather is sunny."} + finish = AgentFinish(return_values=return_values, log="done") + handler.on_agent_finish(finish, run_id=run_id) + + record.span.set_attribute.assert_called_once() + attr_name, attr_value = record.span.set_attribute.call_args[0] + assert "output" in attr_name.lower() or "message" in attr_name.lower() + assert json.loads(attr_value) == return_values + + record.span.set_status.assert_called_once() + status_arg = record.span.set_status.call_args[0][0] + assert status_arg.status_code is StatusCode.OK + span_manager.end_span.assert_called_once_with(run_id) + + def test_ok_status_when_no_return_values(self): + run_id = uuid4() + record = _make_record(run_id) + + span_manager = MagicMock() + span_manager.get_record.return_value = record + + handler = _make_handler(span_manager) + + finish = AgentFinish(return_values={}, log="") + handler.on_agent_finish(finish, run_id=run_id) + + record.span.set_attribute.assert_not_called() + record.span.set_status.assert_called_once() + status_arg = record.span.set_status.call_args[0][0] + assert status_arg.status_code is StatusCode.OK + span_manager.end_span.assert_called_once_with(run_id) + + def test_noop_when_span_manager_is_none(self): + handler = _make_handler(span_manager=None) + finish = AgentFinish(return_values={"output": "x"}, log="") + # Should not raise + handler.on_agent_finish(finish, run_id=uuid4()) + + def test_noop_when_record_not_found(self): + run_id = uuid4() + span_manager = MagicMock() + span_manager.get_record.return_value = None + + handler = _make_handler(span_manager) + + finish = AgentFinish(return_values={"output": "x"}, log="") + handler.on_agent_finish(finish, run_id=run_id) + + span_manager.get_record.assert_called_once_with(str(run_id)) + span_manager.end_span.assert_not_called() diff --git 
a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_chain_callbacks.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_chain_callbacks.py new file mode 100644 index 0000000000..d1c03721dd --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_chain_callbacks.py @@ -0,0 +1,456 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for on_chain_start / on_chain_end / on_chain_error callbacks.""" + +from __future__ import annotations + +from unittest import mock +from uuid import uuid4 + +import pytest + +from opentelemetry.instrumentation.langchain.callback_handler import ( + OpenTelemetryLangChainCallbackHandler, +) +from opentelemetry.instrumentation.langchain.operation_mapping import ( + OperationName, +) +from opentelemetry.instrumentation.langchain.span_manager import ( + SpanRecord, + _SpanManager, +) +from opentelemetry.semconv._incubating.attributes import ( + gen_ai_attributes as GenAI, +) +from opentelemetry.trace.status import StatusCode + +# --------------------------------------------------------------------------- +# Helpers & fixtures +# --------------------------------------------------------------------------- + + +def _make_mock_span(): + """Create a mock span with the interface used by the callback handler.""" + span = mock.MagicMock() + span.is_recording.return_value = True + span.set_attribute = mock.MagicMock() 
+ span.set_status = mock.MagicMock() + return span + + +def _make_span_record(run_id, span=None, operation="invoke_agent"): + return SpanRecord( + run_id=str(run_id), + span=span or _make_mock_span(), + operation=operation, + ) + + +@pytest.fixture +def mock_span_manager(): + mgr = mock.MagicMock(spec=_SpanManager) + # By default runs are not ignored. + mgr.is_ignored.return_value = False + mgr.resolve_parent_id.side_effect = lambda parent_run_id: ( + str(parent_run_id) if parent_run_id is not None else None + ) + # start_span returns a SpanRecord by default. + mgr.start_span.return_value = _make_span_record(uuid4()) + return mgr + + +@pytest.fixture +def handler(mock_span_manager): + telemetry_handler = mock.MagicMock() + return OpenTelemetryLangChainCallbackHandler( + telemetry_handler=telemetry_handler, + span_manager=mock_span_manager, + ) + + +# --------------------------------------------------------------------------- +# on_chain_start +# --------------------------------------------------------------------------- + + +class TestOnChainStartInvokeAgent: + """on_chain_start creates invoke_agent spans for agent signals.""" + + def test_creates_span_when_otel_agent_span_true( + self, handler, mock_span_manager + ): + run_id = uuid4() + handler.on_chain_start( + serialized={}, + inputs={}, + run_id=run_id, + metadata={"otel_agent_span": True}, + ) + mock_span_manager.start_span.assert_called_once() + call_kwargs = mock_span_manager.start_span.call_args + assert call_kwargs.kwargs["operation"] == OperationName.INVOKE_AGENT + assert call_kwargs.kwargs["name"].startswith("invoke_agent") + + def test_creates_span_when_agent_name_in_metadata( + self, handler, mock_span_manager + ): + run_id = uuid4() + handler.on_chain_start( + serialized={}, + inputs={}, + run_id=run_id, + metadata={"agent_name": "planner"}, + ) + mock_span_manager.start_span.assert_called_once() + call_kwargs = mock_span_manager.start_span.call_args + assert call_kwargs.kwargs["operation"] == 
OperationName.INVOKE_AGENT + assert "planner" in call_kwargs.kwargs["name"] + + +class TestOnChainStartInvokeWorkflow: + """on_chain_start creates invoke_workflow spans for top-level graphs.""" + + def test_creates_workflow_span_for_top_level_langgraph( + self, handler, mock_span_manager + ): + run_id = uuid4() + handler.on_chain_start( + serialized={"name": "LangGraph"}, + inputs={}, + run_id=run_id, + parent_run_id=None, + metadata={}, + ) + mock_span_manager.start_span.assert_called_once() + call_kwargs = mock_span_manager.start_span.call_args + assert call_kwargs.kwargs["operation"] == OperationName.INVOKE_WORKFLOW + + +class TestOnChainStartSuppression: + """on_chain_start suppresses (ignores) known noise chains.""" + + def test_suppresses_start_node(self, handler, mock_span_manager): + run_id = uuid4() + handler.on_chain_start( + serialized={}, + inputs={}, + run_id=run_id, + metadata={"langgraph_node": "__start__"}, + ) + mock_span_manager.ignore_run.assert_called_once() + mock_span_manager.start_span.assert_not_called() + + def test_suppresses_middleware_prefix(self, handler, mock_span_manager): + run_id = uuid4() + handler.on_chain_start( + serialized={"name": "Middleware.auth"}, + inputs={}, + run_id=run_id, + metadata={"langgraph_node": "Middleware.auth"}, + name="Middleware.auth", + ) + mock_span_manager.ignore_run.assert_called_once() + mock_span_manager.start_span.assert_not_called() + + def test_suppresses_unclassified_chain(self, handler, mock_span_manager): + """A generic chain with no agent/workflow signals is suppressed.""" + run_id = uuid4() + parent_id = uuid4() + handler.on_chain_start( + serialized={"name": "RunnableSequence"}, + inputs={}, + run_id=run_id, + parent_run_id=parent_id, + metadata={}, + ) + mock_span_manager.ignore_run.assert_called_once() + mock_span_manager.start_span.assert_not_called() + + +class TestOnChainStartAttributes: + """on_chain_start sets the correct span attributes.""" + + def test_sets_agent_name_attribute(self, 
handler, mock_span_manager): + run_id = uuid4() + handler.on_chain_start( + serialized={}, + inputs={}, + run_id=run_id, + metadata={"otel_agent_span": True, "agent_name": "researcher"}, + ) + call_kwargs = mock_span_manager.start_span.call_args + attrs = call_kwargs.kwargs["attributes"] + assert attrs[GenAI.GEN_AI_AGENT_NAME] == "researcher" + + def test_sets_agent_id_attribute(self, handler, mock_span_manager): + run_id = uuid4() + handler.on_chain_start( + serialized={}, + inputs={}, + run_id=run_id, + metadata={ + "otel_agent_span": True, + "agent_id": "agent-42", + }, + ) + call_kwargs = mock_span_manager.start_span.call_args + attrs = call_kwargs.kwargs["attributes"] + assert attrs[GenAI.GEN_AI_AGENT_ID] == "agent-42" + + def test_sets_conversation_id_from_thread_id( + self, handler, mock_span_manager + ): + run_id = uuid4() + handler.on_chain_start( + serialized={}, + inputs={}, + run_id=run_id, + metadata={ + "otel_agent_span": True, + "thread_id": "thread-abc", + }, + ) + call_kwargs = mock_span_manager.start_span.call_args + attrs = call_kwargs.kwargs["attributes"] + assert attrs[GenAI.GEN_AI_CONVERSATION_ID] == "thread-abc" + + def test_sets_conversation_id_from_session_id( + self, handler, mock_span_manager + ): + run_id = uuid4() + handler.on_chain_start( + serialized={}, + inputs={}, + run_id=run_id, + metadata={ + "otel_agent_span": True, + "session_id": "sess-123", + }, + ) + call_kwargs = mock_span_manager.start_span.call_args + attrs = call_kwargs.kwargs["attributes"] + assert attrs[GenAI.GEN_AI_CONVERSATION_ID] == "sess-123" + + def test_sets_conversation_id_from_conversation_id( + self, handler, mock_span_manager + ): + run_id = uuid4() + handler.on_chain_start( + serialized={}, + inputs={}, + run_id=run_id, + metadata={ + "otel_agent_span": True, + "conversation_id": "conv-789", + }, + ) + call_kwargs = mock_span_manager.start_span.call_args + attrs = call_kwargs.kwargs["attributes"] + assert attrs[GenAI.GEN_AI_CONVERSATION_ID] == "conv-789" + 
+ +class TestOnChainStartContentRecording: + """on_chain_start records input messages when content policy allows.""" + + def test_records_input_messages_when_policy_allows( + self, handler, mock_span_manager + ): + run_id = uuid4() + with mock.patch( + "opentelemetry.instrumentation.langchain.callback_handler.get_content_policy" + ) as mock_policy: + policy = mock.MagicMock() + policy.record_content = True + policy.should_record_content_on_spans = True + mock_policy.return_value = policy + + handler.on_chain_start( + serialized={}, + inputs={"input": "Hello, agent!"}, + run_id=run_id, + metadata={"otel_agent_span": True}, + ) + + call_kwargs = mock_span_manager.start_span.call_args + attrs = call_kwargs.kwargs["attributes"] + assert GenAI.GEN_AI_INPUT_MESSAGES in attrs + + def test_does_not_record_input_when_policy_disallows( + self, handler, mock_span_manager + ): + run_id = uuid4() + with mock.patch( + "opentelemetry.instrumentation.langchain.callback_handler.get_content_policy" + ) as mock_policy: + policy = mock.MagicMock() + policy.record_content = False + policy.should_record_content_on_spans = False + mock_policy.return_value = policy + + handler.on_chain_start( + serialized={}, + inputs={"input": "Hello, agent!"}, + run_id=run_id, + metadata={"otel_agent_span": True}, + ) + + call_kwargs = mock_span_manager.start_span.call_args + attrs = call_kwargs.kwargs["attributes"] + assert GenAI.GEN_AI_INPUT_MESSAGES not in attrs + + +class TestOnChainStartNoSpanManager: + """on_chain_start is a no-op when span_manager is None.""" + + def test_returns_early_when_no_span_manager(self): + telemetry_handler = mock.MagicMock() + handler = OpenTelemetryLangChainCallbackHandler( + telemetry_handler=telemetry_handler, + span_manager=None, + ) + # Should not raise. 
+ handler.on_chain_start( + serialized={}, + inputs={}, + run_id=uuid4(), + metadata={"otel_agent_span": True}, + ) + + +# --------------------------------------------------------------------------- +# on_chain_end +# --------------------------------------------------------------------------- + + +class TestOnChainEnd: + """on_chain_end ends the span with OK status.""" + + def test_ends_span_with_ok_status(self, handler, mock_span_manager): + run_id = uuid4() + record = _make_span_record(run_id) + mock_span_manager.get_record.return_value = record + + handler.on_chain_end( + outputs={"output": "done"}, + run_id=run_id, + ) + + mock_span_manager.end_span.assert_called_once_with( + run_id, status=StatusCode.OK + ) + + def test_sets_output_messages_when_policy_allows( + self, handler, mock_span_manager + ): + run_id = uuid4() + span = _make_mock_span() + record = _make_span_record(run_id, span=span) + mock_span_manager.get_record.return_value = record + + with mock.patch( + "opentelemetry.instrumentation.langchain.callback_handler.get_content_policy" + ) as mock_policy: + policy = mock.MagicMock() + policy.record_content = True + policy.should_record_content_on_spans = True + mock_policy.return_value = policy + + handler.on_chain_end( + outputs={"output": "Agent result"}, + run_id=run_id, + ) + + span.set_attribute.assert_any_call( + GenAI.GEN_AI_OUTPUT_MESSAGES, mock.ANY + ) + + def test_skips_ignored_runs(self, handler, mock_span_manager): + run_id = uuid4() + mock_span_manager.is_ignored.return_value = True + + handler.on_chain_end( + outputs={"output": "done"}, + run_id=run_id, + ) + + mock_span_manager.end_span.assert_not_called() + mock_span_manager.clear_ignored_run.assert_called_once_with(run_id) + + def test_returns_early_when_no_record(self, handler, mock_span_manager): + run_id = uuid4() + mock_span_manager.get_record.return_value = None + + handler.on_chain_end( + outputs={"output": "done"}, + run_id=run_id, + ) + + 
mock_span_manager.end_span.assert_not_called() + + def test_returns_early_when_no_span_manager(self): + telemetry_handler = mock.MagicMock() + handler = OpenTelemetryLangChainCallbackHandler( + telemetry_handler=telemetry_handler, + span_manager=None, + ) + # Should not raise. + handler.on_chain_end( + outputs={"output": "done"}, + run_id=uuid4(), + ) + + +# --------------------------------------------------------------------------- +# on_chain_error +# --------------------------------------------------------------------------- + + +class TestOnChainError: + """on_chain_error ends the span with error status.""" + + def test_ends_span_with_error(self, handler, mock_span_manager): + run_id = uuid4() + + error = RuntimeError("something went wrong") + handler.on_chain_error( + error=error, + run_id=run_id, + ) + + mock_span_manager.end_span.assert_called_once_with(run_id, error=error) + + def test_skips_ignored_runs(self, handler, mock_span_manager): + run_id = uuid4() + mock_span_manager.is_ignored.return_value = True + + handler.on_chain_error( + error=RuntimeError("boom"), + run_id=run_id, + ) + + mock_span_manager.end_span.assert_not_called() + mock_span_manager.clear_ignored_run.assert_called_once_with(run_id) + + def test_returns_early_when_no_span_manager(self): + telemetry_handler = mock.MagicMock() + handler = OpenTelemetryLangChainCallbackHandler( + telemetry_handler=telemetry_handler, + span_manager=None, + ) + # Should not raise. 
+ handler.on_chain_error( + error=RuntimeError("boom"), + run_id=uuid4(), + ) diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_content_recording.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_content_recording.py new file mode 100644 index 0000000000..df7864e355 --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_content_recording.py @@ -0,0 +1,263 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from opentelemetry.instrumentation._semconv import ( + _OpenTelemetrySemanticConventionStability, +) +from opentelemetry.instrumentation.langchain.content_recording import ( + ContentPolicy, + should_record_messages, + should_record_retriever_content, + should_record_system_instructions, + should_record_tool_content, +) +from opentelemetry.util.genai.types import ContentCapturingMode + + +@pytest.fixture(autouse=True) +def _reset_semconv_stability(monkeypatch): + """Reset semconv stability cache so each test can set its own env vars.""" + orig_initialized = _OpenTelemetrySemanticConventionStability._initialized + orig_mapping = _OpenTelemetrySemanticConventionStability._OTEL_SEMCONV_STABILITY_SIGNAL_MAPPING.copy() + + _OpenTelemetrySemanticConventionStability._initialized = False + _OpenTelemetrySemanticConventionStability._OTEL_SEMCONV_STABILITY_SIGNAL_MAPPING = {} + + monkeypatch.delenv("OTEL_SEMCONV_STABILITY_OPT_IN", raising=False) + monkeypatch.delenv( + "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT", raising=False + ) + monkeypatch.delenv("OTEL_INSTRUMENTATION_GENAI_EMIT_EVENT", raising=False) + + yield + + _OpenTelemetrySemanticConventionStability._initialized = orig_initialized + _OpenTelemetrySemanticConventionStability._OTEL_SEMCONV_STABILITY_SIGNAL_MAPPING = orig_mapping + + +def _enter_experimental(monkeypatch, capture_mode): + """Set env vars for experimental mode and re-initialize stability.""" + monkeypatch.setenv( + "OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental" + ) + monkeypatch.setenv( + "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT", capture_mode + ) + _OpenTelemetrySemanticConventionStability._initialize() + + +# --------------------------------------------------------------------------- +# ContentPolicy – experimental mode with each ContentCapturingMode +# --------------------------------------------------------------------------- + + +class TestContentPolicySpanOnly: + """SPAN_ONLY: content on 
spans, no events.""" + + def test_should_record_content_on_spans(self, monkeypatch): + _enter_experimental(monkeypatch, "SPAN_ONLY") + assert ContentPolicy().should_record_content_on_spans is True + + def test_should_emit_events(self, monkeypatch): + _enter_experimental(monkeypatch, "SPAN_ONLY") + assert ContentPolicy().should_emit_events is False + + def test_record_content(self, monkeypatch): + _enter_experimental(monkeypatch, "SPAN_ONLY") + assert ContentPolicy().record_content is True + + def test_mode(self, monkeypatch): + _enter_experimental(monkeypatch, "SPAN_ONLY") + assert ContentPolicy().mode == ContentCapturingMode.SPAN_ONLY + + +class TestContentPolicyEventOnly: + """EVENT_ONLY: events enabled without duplicating content on spans.""" + + def test_should_record_content_on_spans(self, monkeypatch): + _enter_experimental(monkeypatch, "EVENT_ONLY") + assert ContentPolicy().should_record_content_on_spans is False + + def test_should_emit_events(self, monkeypatch): + _enter_experimental(monkeypatch, "EVENT_ONLY") + assert ContentPolicy().should_emit_events is True + + def test_record_content(self, monkeypatch): + _enter_experimental(monkeypatch, "EVENT_ONLY") + assert ContentPolicy().record_content is True + + def test_mode(self, monkeypatch): + _enter_experimental(monkeypatch, "EVENT_ONLY") + assert ContentPolicy().mode == ContentCapturingMode.EVENT_ONLY + + +class TestContentPolicySpanAndEvent: + """SPAN_AND_EVENT: both spans and events active.""" + + def test_should_record_content_on_spans(self, monkeypatch): + _enter_experimental(monkeypatch, "SPAN_AND_EVENT") + assert ContentPolicy().should_record_content_on_spans is True + + def test_should_emit_events(self, monkeypatch): + _enter_experimental(monkeypatch, "SPAN_AND_EVENT") + assert ContentPolicy().should_emit_events is True + + def test_record_content(self, monkeypatch): + _enter_experimental(monkeypatch, "SPAN_AND_EVENT") + assert ContentPolicy().record_content is True + + def test_mode(self, 
monkeypatch): + _enter_experimental(monkeypatch, "SPAN_AND_EVENT") + assert ContentPolicy().mode == ContentCapturingMode.SPAN_AND_EVENT + + +class TestContentPolicyNoContent: + """NO_CONTENT: nothing recorded.""" + + def test_should_record_content_on_spans(self, monkeypatch): + _enter_experimental(monkeypatch, "NO_CONTENT") + assert ContentPolicy().should_record_content_on_spans is False + + def test_should_emit_events(self, monkeypatch): + _enter_experimental(monkeypatch, "NO_CONTENT") + assert ContentPolicy().should_emit_events is False + + def test_record_content(self, monkeypatch): + _enter_experimental(monkeypatch, "NO_CONTENT") + assert ContentPolicy().record_content is False + + def test_mode(self, monkeypatch): + _enter_experimental(monkeypatch, "NO_CONTENT") + assert ContentPolicy().mode == ContentCapturingMode.NO_CONTENT + + +class TestContentPolicyRecordContentCombined: + """record_content is True when either spans or events are enabled.""" + + @pytest.mark.parametrize( + "capture_mode, expected", + [ + ("SPAN_ONLY", True), + ("EVENT_ONLY", True), + ("SPAN_AND_EVENT", True), + ("NO_CONTENT", False), + ], + ) + def test_record_content(self, monkeypatch, capture_mode, expected): + _enter_experimental(monkeypatch, capture_mode) + assert ContentPolicy().record_content is expected + + +# --------------------------------------------------------------------------- +# ContentPolicy – outside experimental mode +# --------------------------------------------------------------------------- + + +class TestContentPolicyNonExperimental: + """Without experimental opt-in everything is disabled.""" + + def test_should_record_content_on_spans(self): + _OpenTelemetrySemanticConventionStability._initialize() + assert ContentPolicy().should_record_content_on_spans is False + + def test_should_emit_events(self): + _OpenTelemetrySemanticConventionStability._initialize() + assert ContentPolicy().should_emit_events is False + + def test_record_content(self): + 
_OpenTelemetrySemanticConventionStability._initialize() + assert ContentPolicy().record_content is False + + def test_mode_is_no_content(self): + _OpenTelemetrySemanticConventionStability._initialize() + assert ContentPolicy().mode == ContentCapturingMode.NO_CONTENT + + +# --------------------------------------------------------------------------- +# Helper functions – delegates to policy.should_record_content_on_spans +# --------------------------------------------------------------------------- + + +class _StubPolicy: + """Minimal stand-in for ContentPolicy with a fixed boolean.""" + + def __init__(self, value: bool): + self.should_record_content_on_spans = value + + +class TestShouldRecordMessages: + def test_true_when_policy_enabled(self): + assert should_record_messages(_StubPolicy(True)) is True + + def test_false_when_policy_disabled(self): + assert should_record_messages(_StubPolicy(False)) is False + + +class TestShouldRecordToolContent: + def test_true_when_policy_enabled(self): + assert should_record_tool_content(_StubPolicy(True)) is True + + def test_false_when_policy_disabled(self): + assert should_record_tool_content(_StubPolicy(False)) is False + + +class TestShouldRecordRetrieverContent: + def test_true_when_policy_enabled(self): + assert should_record_retriever_content(_StubPolicy(True)) is True + + def test_false_when_policy_disabled(self): + assert should_record_retriever_content(_StubPolicy(False)) is False + + +class TestShouldRecordSystemInstructions: + def test_true_when_policy_enabled(self): + assert should_record_system_instructions(_StubPolicy(True)) is True + + def test_false_when_policy_disabled(self): + assert should_record_system_instructions(_StubPolicy(False)) is False + + +# --------------------------------------------------------------------------- +# Helper functions – integration with real ContentPolicy via env vars +# --------------------------------------------------------------------------- + + +class 
TestHelperFunctionsIntegration: + """Verify helpers produce correct results with a real ContentPolicy.""" + + def test_all_helpers_true_with_span_only(self, monkeypatch): + _enter_experimental(monkeypatch, "SPAN_ONLY") + policy = ContentPolicy() + assert should_record_messages(policy) is True + assert should_record_tool_content(policy) is True + assert should_record_retriever_content(policy) is True + assert should_record_system_instructions(policy) is True + + def test_all_helpers_false_with_no_content(self, monkeypatch): + _enter_experimental(monkeypatch, "NO_CONTENT") + policy = ContentPolicy() + assert should_record_messages(policy) is False + assert should_record_tool_content(policy) is False + assert should_record_retriever_content(policy) is False + assert should_record_system_instructions(policy) is False + + def test_all_helpers_false_outside_experimental(self): + _OpenTelemetrySemanticConventionStability._initialize() + policy = ContentPolicy() + assert should_record_messages(policy) is False + assert should_record_tool_content(policy) is False + assert should_record_retriever_content(policy) is False + assert should_record_system_instructions(policy) is False diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_e2e_scenarios.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_e2e_scenarios.py new file mode 100644 index 0000000000..ecc4e1cfa3 --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_e2e_scenarios.py @@ -0,0 +1,656 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end scenario tests that simulate realistic LangChain/LangGraph +callback sequences. + +These tests feed events directly into the callback handler and verify +the resulting span structure managed by a real ``_SpanManager`` backed +by a mock tracer. No actual LangChain runtime is invoked. +""" + +from __future__ import annotations + +from unittest import mock +from uuid import uuid4 + +from opentelemetry.instrumentation.langchain.callback_handler import ( + OpenTelemetryLangChainCallbackHandler, +) +from opentelemetry.instrumentation.langchain.semconv_attributes import ( + OP_EXECUTE_TOOL, + OP_INVOKE_AGENT, +) +from opentelemetry.instrumentation.langchain.span_manager import ( + _SpanManager, +) +from opentelemetry.semconv._incubating.attributes import ( + gen_ai_attributes as GenAI, +) +from opentelemetry.semconv.attributes import error_attributes +from opentelemetry.trace import SpanKind +from opentelemetry.trace.status import StatusCode + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + + +def _mock_tracer(): + """Return a mock Tracer whose ``start_span`` produces unique mock spans.""" + tracer = mock.MagicMock() + + def _new_span(**kwargs): + span = mock.MagicMock() + span.is_recording.return_value = True + span._name = kwargs.get("name", "unnamed") + span._attributes = {} + + original_set_attr = span.set_attribute + + def _track_attr(key, value): + span._attributes[key] = value + return original_set_attr(key, value) + + span.set_attribute 
= mock.MagicMock(side_effect=_track_attr) + return span + + tracer.start_span = mock.MagicMock(side_effect=_new_span) + return tracer + + +def _make_handler_and_manager(): + """Create a handler backed by a real ``_SpanManager`` with a mock tracer.""" + tracer = _mock_tracer() + span_manager = _SpanManager(tracer) + telemetry_handler = mock.MagicMock() + handler = OpenTelemetryLangChainCallbackHandler( + telemetry_handler=telemetry_handler, + span_manager=span_manager, + ) + return handler, span_manager, tracer + + +def _agent_metadata(name: str, **extra) -> dict: + """Metadata dict that causes ``classify_chain_run`` to emit an agent span.""" + meta = {"otel_agent_span": True, "agent_name": name} + meta.update(extra) + return meta + + +def _get_span(span_manager: _SpanManager, run_id) -> mock.MagicMock: + """Return the mock span for *run_id*, which must still be active.""" + record = span_manager.get_record(run_id) + assert record is not None, f"No active record for {run_id}" + return record.span + + +def _get_set_attributes(span: mock.MagicMock) -> dict: + """Collect all attributes set on a mock span via ``set_attribute``.""" + return dict(span._attributes) + + +# ------------------------------------------------------------------ +# Scenario 1: Simple agent with tool call +# ------------------------------------------------------------------ + + +class TestSimpleAgentWithToolCall: + """Agent starts → LLM called → tool called → LLM called again → agent ends. + + Verifies: + - Agent span parents LLM and tool spans. + - Token usage accumulates on agent span. 
+ """ + + def test_span_hierarchy_and_token_accumulation(self, monkeypatch): + monkeypatch.setattr( + "opentelemetry.instrumentation.langchain.callback_handler.should_record_tool_content", + lambda policy: False, + ) + handler, sm, tracer = _make_handler_and_manager() + + agent_id = uuid4() + tool_id = uuid4() + + # 1) Agent starts + handler.on_chain_start( + serialized={"name": "my_agent"}, + inputs={"messages": []}, + run_id=agent_id, + parent_run_id=None, + metadata=_agent_metadata("my_agent"), + ) + + agent_record = sm.get_record(agent_id) + assert agent_record is not None + assert agent_record.operation == OP_INVOKE_AGENT + + # 2) First LLM call (child of agent via an ignored intermediate chain) + # LLM spans are managed by TelemetryHandler, but token accumulation + # flows through SpanManager. Simulate by calling + # accumulate_llm_usage_to_agent directly. + sm.accumulate_llm_usage_to_agent( + agent_id, input_tokens=100, output_tokens=50 + ) + + # 3) Tool call + handler.on_tool_start( + serialized={"name": "web_search", "description": "Search the web"}, + input_str="latest news", + run_id=tool_id, + parent_run_id=agent_id, + ) + + tool_record = sm.get_record(tool_id) + assert tool_record is not None + assert tool_record.operation == OP_EXECUTE_TOOL + assert tool_record.parent_run_id == str(agent_id) + + handler.on_tool_end(output="some results", run_id=tool_id) + assert sm.get_record(tool_id) is None # span ended, record removed + + # 4) Second LLM call + sm.accumulate_llm_usage_to_agent( + agent_id, input_tokens=200, output_tokens=80 + ) + + # 5) Agent ends + handler.on_chain_end(outputs={"output": "done"}, run_id=agent_id) + assert sm.get_record(agent_id) is None # span ended + + # Verify token accumulation on the agent span. + # Token accumulation calls set_attribute on agent_record.span. 
+ agent_span_obj = agent_record.span + attrs = _get_set_attributes(agent_span_obj) + + assert attrs[GenAI.GEN_AI_USAGE_INPUT_TOKENS] == 300 # 100 + 200 + assert attrs[GenAI.GEN_AI_USAGE_OUTPUT_TOKENS] == 130 # 50 + 80 + + def test_tool_span_created_with_correct_kind(self, monkeypatch): + monkeypatch.setattr( + "opentelemetry.instrumentation.langchain.callback_handler.should_record_tool_content", + lambda policy: False, + ) + handler, sm, tracer = _make_handler_and_manager() + + agent_id = uuid4() + tool_id = uuid4() + + handler.on_chain_start( + serialized={"name": "agent"}, + inputs={}, + run_id=agent_id, + metadata=_agent_metadata("agent"), + ) + + handler.on_tool_start( + serialized={"name": "calc"}, + input_str="2+2", + run_id=tool_id, + parent_run_id=agent_id, + ) + + # Agent span → SpanKind.CLIENT, Tool span → SpanKind.INTERNAL + calls = tracer.start_span.call_args_list + agent_call_kwargs = calls[0][1] + tool_call_kwargs = calls[1][1] + assert agent_call_kwargs["kind"] == SpanKind.CLIENT + assert tool_call_kwargs["kind"] == SpanKind.INTERNAL + + handler.on_tool_end(output="4", run_id=tool_id) + handler.on_chain_end(outputs={}, run_id=agent_id) + + +# ------------------------------------------------------------------ +# Scenario 2: Nested agents +# ------------------------------------------------------------------ + + +class TestNestedAgents: + """Outer agent → inner agent → LLM → inner agent ends → outer agent ends. + + Verifies: + - Inner agent is a child of outer agent. + - Token usage propagates up through both levels. 
+ """ + + def test_nested_agent_parenting(self): + handler, sm, tracer = _make_handler_and_manager() + + outer_id = uuid4() + inner_id = uuid4() + + # 1) Outer agent starts + handler.on_chain_start( + serialized={"name": "outer_agent"}, + inputs={}, + run_id=outer_id, + parent_run_id=None, + metadata=_agent_metadata("outer_agent"), + ) + + # 2) Inner agent starts (child of outer) + handler.on_chain_start( + serialized={"name": "inner_agent"}, + inputs={}, + run_id=inner_id, + parent_run_id=outer_id, + metadata=_agent_metadata("inner_agent"), + ) + + inner_record = sm.get_record(inner_id) + assert inner_record is not None + assert inner_record.parent_run_id == str(outer_id) + + # The tracer should have been called with the outer's span as context + # for the inner span. + assert tracer.start_span.call_count == 2 + + # 3) LLM call inside inner agent — accumulate tokens on inner agent + sm.accumulate_llm_usage_to_agent( + inner_id, input_tokens=50, output_tokens=25 + ) + + inner_attrs = _get_set_attributes(inner_record.span) + assert inner_attrs[GenAI.GEN_AI_USAGE_INPUT_TOKENS] == 50 + assert inner_attrs[GenAI.GEN_AI_USAGE_OUTPUT_TOKENS] == 25 + + # 4) Inner agent ends + handler.on_chain_end(outputs={}, run_id=inner_id) + assert sm.get_record(inner_id) is None + + # 5) Outer agent ends + handler.on_chain_end(outputs={}, run_id=outer_id) + assert sm.get_record(outer_id) is None + + def test_token_usage_propagates_to_nearest_agent(self): + """Token usage from an LLM call accumulates on the nearest agent ancestor.""" + handler, sm, tracer = _make_handler_and_manager() + + outer_id = uuid4() + inner_id = uuid4() + + handler.on_chain_start( + serialized={"name": "outer"}, + inputs={}, + run_id=outer_id, + metadata=_agent_metadata("outer"), + ) + handler.on_chain_start( + serialized={"name": "inner"}, + inputs={}, + run_id=inner_id, + parent_run_id=outer_id, + metadata=_agent_metadata("inner"), + ) + + outer_record = sm.get_record(outer_id) + inner_record = 
sm.get_record(inner_id) + + # Accumulate on the inner agent (the nearest) + sm.accumulate_llm_usage_to_agent( + inner_id, input_tokens=10, output_tokens=5 + ) + + inner_attrs = _get_set_attributes(inner_record.span) + assert inner_attrs.get(GenAI.GEN_AI_USAGE_INPUT_TOKENS) == 10 + + # Outer should NOT have received tokens (accumulation stops at nearest agent) + outer_attrs = _get_set_attributes(outer_record.span) + assert GenAI.GEN_AI_USAGE_INPUT_TOKENS not in outer_attrs + + handler.on_chain_end(outputs={}, run_id=inner_id) + handler.on_chain_end(outputs={}, run_id=outer_id) + + +# ------------------------------------------------------------------ +# Scenario 3: Ignored intermediate chain +# ------------------------------------------------------------------ + + +class TestIgnoredIntermediateChain: + """Agent → internal chain (ignored) → LLM → internal chain ends → agent ends. + + Verifies: + - Internal chain produces no span. + - LLM-related operations walk through the ignored run to the agent. 
+ """ + + def test_ignored_chain_produces_no_span(self): + handler, sm, tracer = _make_handler_and_manager() + + agent_id = uuid4() + chain_id = uuid4() + + # 1) Agent starts + handler.on_chain_start( + serialized={"name": "my_agent"}, + inputs={}, + run_id=agent_id, + metadata=_agent_metadata("my_agent"), + ) + + # 2) Internal chain starts — no agent/workflow signals → ignored + handler.on_chain_start( + serialized={"name": "RunnableSequence"}, + inputs={}, + run_id=chain_id, + parent_run_id=agent_id, + metadata=None, + ) + + # The chain should be ignored: no span record created + assert sm.get_record(chain_id) is None + assert sm.is_ignored(chain_id) + + # Only one span created (the agent) + assert tracer.start_span.call_count == 1 + + # 3) Internal chain ends (no-op for ignored) + handler.on_chain_end(outputs={}, run_id=chain_id) + + # 4) Agent ends + handler.on_chain_end(outputs={}, run_id=agent_id) + assert sm.get_record(agent_id) is None + + def test_tool_parents_through_ignored_chain(self, monkeypatch): + """A tool whose parent_run_id points to an ignored chain should + resolve to the agent span as its effective parent.""" + monkeypatch.setattr( + "opentelemetry.instrumentation.langchain.callback_handler.should_record_tool_content", + lambda policy: False, + ) + handler, sm, tracer = _make_handler_and_manager() + + agent_id = uuid4() + chain_id = uuid4() + tool_id = uuid4() + + handler.on_chain_start( + serialized={"name": "my_agent"}, + inputs={}, + run_id=agent_id, + metadata=_agent_metadata("my_agent"), + ) + + # Internal chain is ignored + handler.on_chain_start( + serialized={"name": "RunnableSequence"}, + inputs={}, + run_id=chain_id, + parent_run_id=agent_id, + metadata=None, + ) + assert sm.is_ignored(chain_id) + + # Tool claims chain_id as parent → should resolve to agent_id + handler.on_tool_start( + serialized={"name": "search"}, + input_str="query", + run_id=tool_id, + parent_run_id=chain_id, + ) + + tool_record = sm.get_record(tool_id) + assert 
tool_record is not None + # The tool's raw parent_run_id is the chain, but start_span resolves + # through the ignored run. Verify the tracer was called with the + # agent's context. + tool_start_call = tracer.start_span.call_args_list[1] # second call + ctx = tool_start_call[1].get("context") + assert ctx is not None # parent context was set (the agent's) + + handler.on_tool_end(output="ok", run_id=tool_id) + handler.on_chain_end(outputs={}, run_id=chain_id) + handler.on_chain_end(outputs={}, run_id=agent_id) + + def test_llm_token_accumulation_through_ignored_chain(self): + """Token usage from an LLM whose parent is an ignored chain should + still accumulate on the agent.""" + handler, sm, tracer = _make_handler_and_manager() + + agent_id = uuid4() + chain_id = uuid4() + + handler.on_chain_start( + serialized={"name": "my_agent"}, + inputs={}, + run_id=agent_id, + metadata=_agent_metadata("my_agent"), + ) + + # Ignored chain + handler.on_chain_start( + serialized={"name": "RunnableSequence"}, + inputs={}, + run_id=chain_id, + parent_run_id=agent_id, + metadata=None, + ) + + agent_record = sm.get_record(agent_id) + + # LLM call with chain_id as parent → accumulate_llm_usage_to_agent + # should walk through the ignored chain to the agent. + sm.accumulate_llm_usage_to_agent( + chain_id, input_tokens=75, output_tokens=30 + ) + + attrs = _get_set_attributes(agent_record.span) + assert attrs[GenAI.GEN_AI_USAGE_INPUT_TOKENS] == 75 + assert attrs[GenAI.GEN_AI_USAGE_OUTPUT_TOKENS] == 30 + + handler.on_chain_end(outputs={}, run_id=chain_id) + handler.on_chain_end(outputs={}, run_id=agent_id) + + +# ------------------------------------------------------------------ +# Scenario 4: Tool with error +# ------------------------------------------------------------------ + + +class TestToolWithError: + """Agent → tool starts → tool errors → agent ends with error. + + Verifies: + - Tool span has error status. + - Agent span has error status. 
+ """ + + def test_tool_error_sets_error_status(self, monkeypatch): + monkeypatch.setattr( + "opentelemetry.instrumentation.langchain.callback_handler.should_record_tool_content", + lambda policy: False, + ) + handler, sm, tracer = _make_handler_and_manager() + + agent_id = uuid4() + tool_id = uuid4() + error = ValueError("tool exploded") + + handler.on_chain_start( + serialized={"name": "my_agent"}, + inputs={}, + run_id=agent_id, + metadata=_agent_metadata("my_agent"), + ) + + handler.on_tool_start( + serialized={"name": "risky_tool"}, + input_str="danger", + run_id=tool_id, + parent_run_id=agent_id, + ) + + tool_record = sm.get_record(tool_id) + tool_span = tool_record.span + + # Tool errors + handler.on_tool_error(error=error, run_id=tool_id) + + # Tool span should be ended with error attributes + tool_span.set_attribute.assert_any_call( + error_attributes.ERROR_TYPE, "ValueError" + ) + tool_span.set_status.assert_called_once() + status_call = tool_span.set_status.call_args[0][0] + assert status_call.status_code == StatusCode.ERROR + assert "tool exploded" in str(status_call.description) + tool_span.end.assert_called_once() + + # Tool record should be removed + assert sm.get_record(tool_id) is None + + def test_agent_error_after_tool_error(self, monkeypatch): + monkeypatch.setattr( + "opentelemetry.instrumentation.langchain.callback_handler.should_record_tool_content", + lambda policy: False, + ) + handler, sm, tracer = _make_handler_and_manager() + + agent_id = uuid4() + tool_id = uuid4() + tool_error = RuntimeError("timeout") + agent_error = RuntimeError("agent failed due to tool error") + + handler.on_chain_start( + serialized={"name": "my_agent"}, + inputs={}, + run_id=agent_id, + metadata=_agent_metadata("my_agent"), + ) + + handler.on_tool_start( + serialized={"name": "slow_tool"}, + input_str="data", + run_id=tool_id, + parent_run_id=agent_id, + ) + + agent_record = sm.get_record(agent_id) + agent_span = agent_record.span + + # Tool errors + 
handler.on_tool_error(error=tool_error, run_id=tool_id) + assert sm.get_record(tool_id) is None + + # Agent errors + handler.on_chain_error(error=agent_error, run_id=agent_id) + + # Agent span should have error status + agent_span.set_attribute.assert_any_call( + error_attributes.ERROR_TYPE, "RuntimeError" + ) + agent_span.set_status.assert_called_once() + agent_status = agent_span.set_status.call_args[0][0] + assert agent_status.status_code == StatusCode.ERROR + assert "agent failed" in str(agent_status.description) + agent_span.end.assert_called_once() + + assert sm.get_record(agent_id) is None + + def test_error_does_not_leak_across_runs(self, monkeypatch): + """A tool error in one run should not affect a sibling tool.""" + monkeypatch.setattr( + "opentelemetry.instrumentation.langchain.callback_handler.should_record_tool_content", + lambda policy: False, + ) + handler, sm, tracer = _make_handler_and_manager() + + agent_id = uuid4() + tool_ok_id = uuid4() + tool_err_id = uuid4() + + handler.on_chain_start( + serialized={"name": "my_agent"}, + inputs={}, + run_id=agent_id, + metadata=_agent_metadata("my_agent"), + ) + + # Two tools start + handler.on_tool_start( + serialized={"name": "good_tool"}, + input_str="", + run_id=tool_ok_id, + parent_run_id=agent_id, + ) + handler.on_tool_start( + serialized={"name": "bad_tool"}, + input_str="", + run_id=tool_err_id, + parent_run_id=agent_id, + ) + + ok_record = sm.get_record(tool_ok_id) + ok_span = ok_record.span + + # One tool errors + handler.on_tool_error(error=ValueError("boom"), run_id=tool_err_id) + assert sm.get_record(tool_err_id) is None + + # Other tool completes normally + handler.on_tool_end(output="all good", run_id=tool_ok_id) + # The OK tool's span should NOT have error status set + for call in ok_span.set_status.call_args_list: + assert call[0][0].status_code != StatusCode.ERROR + + handler.on_chain_end(outputs={}, run_id=agent_id) + + +# ------------------------------------------------------------------ 
+# Scenario 5: Span manager cleanup +# ------------------------------------------------------------------ + + +class TestSpanManagerCleanup: + """Verify that all internal state is cleaned up after a full lifecycle.""" + + def test_no_leftover_state(self, monkeypatch): + monkeypatch.setattr( + "opentelemetry.instrumentation.langchain.callback_handler.should_record_tool_content", + lambda policy: False, + ) + handler, sm, tracer = _make_handler_and_manager() + + agent_id = uuid4() + chain_id = uuid4() + tool_id = uuid4() + + handler.on_chain_start( + serialized={"name": "agent"}, + inputs={}, + run_id=agent_id, + metadata=_agent_metadata("agent"), + ) + handler.on_chain_start( + serialized={"name": "inner"}, + inputs={}, + run_id=chain_id, + parent_run_id=agent_id, + ) + handler.on_tool_start( + serialized={"name": "tool"}, + input_str="", + run_id=tool_id, + parent_run_id=chain_id, + ) + + handler.on_tool_end(output="ok", run_id=tool_id) + handler.on_chain_end(outputs={}, run_id=chain_id) + handler.on_chain_end(outputs={}, run_id=agent_id) + + # All span records should be cleared + assert len(sm._spans) == 0 diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_event_emitter.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_event_emitter.py new file mode 100644 index 0000000000..301cafe272 --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_event_emitter.py @@ -0,0 +1,115 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from opentelemetry.instrumentation.langchain.event_emitter import EventEmitter + + +def _make_policy(*, should_emit_events: bool, record_content: bool): + policy = mock.MagicMock() + policy.should_emit_events = should_emit_events + policy.record_content = record_content + return policy + + +def _make_emitter(): + emitter = EventEmitter() + emitter._logger = mock.MagicMock() + return emitter + + +def test_emits_tool_call_event_with_content(monkeypatch): + policy = _make_policy(should_emit_events=True, record_content=True) + monkeypatch.setattr( + "opentelemetry.instrumentation.langchain.event_emitter.get_content_policy", + lambda: policy, + ) + emitter = _make_emitter() + + emitter.emit_tool_call_event( + mock.MagicMock(), + "calculator", + '{"x": 1}', + "call_123", + ) + + record = emitter._logger.emit.call_args.args[0] + assert record.event_name == "gen_ai.tool.call" + assert record.body == { + "name": "calculator", + "id": "call_123", + "arguments": '{"x": 1}', + } + + +def test_redacts_tool_result_when_content_recording_disabled(monkeypatch): + policy = _make_policy(should_emit_events=True, record_content=False) + monkeypatch.setattr( + "opentelemetry.instrumentation.langchain.event_emitter.get_content_policy", + lambda: policy, + ) + emitter = _make_emitter() + + emitter.emit_tool_result_event( + mock.MagicMock(), + "calculator", + '{"result": 2}', + ) + + record = emitter._logger.emit.call_args.args[0] + assert record.event_name == "gen_ai.tool.result" + assert record.body == { + "name": "calculator", + "result": "[redacted]", + } + + +def test_skips_agent_event_when_disabled(monkeypatch): + policy = _make_policy(should_emit_events=False, record_content=True) + monkeypatch.setattr( + "opentelemetry.instrumentation.langchain.event_emitter.get_content_policy", + lambda: policy, + ) + emitter = _make_emitter() + + 
emitter.emit_agent_start_event( + mock.MagicMock(), + "planner", + '[{"content": "hi"}]', + ) + + emitter._logger.emit.assert_not_called() + + +def test_emits_retriever_result_event(monkeypatch): + policy = _make_policy(should_emit_events=True, record_content=True) + monkeypatch.setattr( + "opentelemetry.instrumentation.langchain.event_emitter.get_content_policy", + lambda: policy, + ) + emitter = _make_emitter() + + emitter.emit_retriever_result_event( + mock.MagicMock(), + "vector_store", + '[{"metadata": {"source": "a.txt"}}]', + ) + + record = emitter._logger.emit.call_args.args[0] + assert record.event_name == "gen_ai.retriever.result" + assert record.body == { + "name": "vector_store", + "documents": '[{"metadata": {"source": "a.txt"}}]', + } diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_llm_call.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_llm_call.py index 2fdd6a3acf..c7ab52bf6c 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_llm_call.py +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_llm_call.py @@ -1,3 +1,5 @@ +import os +from pathlib import Path from typing import Optional import pytest @@ -157,36 +159,37 @@ def test_chat_openai_gpt_3_5_turbo_model_llm_call_with_error( # span_exporter, start_instrumentation, us_amazon_nova_lite_v1_0 are coming from fixtures defined in conftest.py @pytest.mark.vcr() def test_us_amazon_nova_lite_v1_0_bedrock_llm_call( - span_exporter, start_instrumentation, us_amazon_nova_lite_v1_0 + span_exporter, start_instrumentation, us_amazon_nova_lite_v1_0, vcr ): messages = [ SystemMessage(content="You are a helpful assistant!"), HumanMessage(content="What is the capital of France?"), ] - result = us_amazon_nova_lite_v1_0.invoke(messages) + with vcr.use_cassette( + "test_us_amazon_nova_lite_v1_0_bedrock_llm_call.yaml", + match_on=["method", "scheme", "host", "port", "query"], + ): + result = 
us_amazon_nova_lite_v1_0.invoke(messages) assert result.content.find("The capital of France is Paris") != -1 # verify spans spans = span_exporter.get_finished_spans() - print(f"spans: {spans}") - for span in spans: - print(f"span: {span}") - print(f"span attributes: {span.attributes}") - # TODO: fix the code and ensure the assertions are correct assert_bedrock_completion_attributes(spans[0], result) # span_exporter, start_instrumentation, gemini are coming from fixtures defined in conftest.py @pytest.mark.vcr() -def test_gemini(span_exporter, start_instrumentation, gemini): +def test_gemini(span_exporter, start_instrumentation, gemini, request, vcr): + _skip_if_cassette_missing_and_no_real_key(request) messages = [ SystemMessage(content="You are a helpful assistant!"), HumanMessage(content="What is the capital of France?"), ] - result = gemini.invoke(messages) + with vcr.use_cassette(f"{request.node.name}.yaml"): + result = gemini.invoke(messages) assert result.content.find("The capital of France is **Paris**") != -1 @@ -195,6 +198,21 @@ def test_gemini(span_exporter, start_instrumentation, gemini): assert len(spans) == 0 # No spans should be created for gemini as of now +def _skip_if_cassette_missing_and_no_real_key(request): + cassette_path = ( + Path(__file__).parent / "cassettes" / f"{request.node.name}.yaml" + ) + if not cassette_path.exists() and gemini_api_key_is_placeholder(): + pytest.skip( + f"Cassette {cassette_path.name} is missing. " + "Set a real GOOGLE_API_KEY-compatible credential to record it." 
+ ) + + +def gemini_api_key_is_placeholder(): + return os.getenv("GOOGLE_API_KEY", "test_key") == "test_key" + + def assert_openai_completion_attributes( span: ReadableSpan, response: Optional, verify_content: bool = True ): @@ -321,7 +339,7 @@ def assert_bedrock_completion_attributes( == "us.amazon.nova-lite-v1:0" ) - assert span.attributes["gen_ai.provider.name"] == "amazon_bedrock" + assert span.attributes["gen_ai.provider.name"] == "aws.bedrock" assert span.attributes[gen_ai_attributes.GEN_AI_REQUEST_MAX_TOKENS] == 100 assert span.attributes[gen_ai_attributes.GEN_AI_REQUEST_TEMPERATURE] == 0.1 diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_operation_mapping.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_operation_mapping.py new file mode 100644 index 0000000000..63ee5b6fac --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_operation_mapping.py @@ -0,0 +1,336 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Unit tests for the LangChain ``operation_mapping`` helpers.

Covers :func:`classify_chain_run` (agent / workflow / suppression
classification), :func:`should_ignore_chain`, :func:`resolve_agent_name`,
and the ``OperationName`` string constants.
"""

from uuid import uuid4

from opentelemetry.instrumentation.langchain.operation_mapping import (
    OperationName,
    classify_chain_run,
    resolve_agent_name,
    should_ignore_chain,
)

# ---------------------------------------------------------------------------
# classify_chain_run
# ---------------------------------------------------------------------------


class TestClassifyChainRunAgentDetection:
    """Agent signals → invoke_agent."""

    def test_otel_agent_span_true(self):
        # Explicit opt-in marker is sufficient on its own.
        result = classify_chain_run(
            serialized={},
            metadata={"otel_agent_span": True},
            kwargs={},
        )
        assert result == OperationName.INVOKE_AGENT

    def test_metadata_agent_name(self):
        result = classify_chain_run(
            serialized={},
            metadata={"agent_name": "my-agent"},
            kwargs={},
        )
        assert result == OperationName.INVOKE_AGENT

    def test_metadata_agent_type(self):
        result = classify_chain_run(
            serialized={},
            metadata={"agent_type": "react"},
            kwargs={},
        )
        assert result == OperationName.INVOKE_AGENT

    def test_langgraph_agent_node(self):
        # A named LangGraph node (other than "__start__") counts as an agent.
        result = classify_chain_run(
            serialized={},
            metadata={"langgraph_node": "researcher"},
            kwargs={},
        )
        assert result == OperationName.INVOKE_AGENT

    def test_agent_signals_override_workflow(self):
        """Agent signals take priority over workflow heuristics."""
        result = classify_chain_run(
            serialized={"name": "LangGraph"},
            metadata={"otel_agent_span": True},
            kwargs={},
            parent_run_id=None,
        )
        assert result == OperationName.INVOKE_AGENT


class TestClassifyChainRunWorkflowDetection:
    """Workflow signals → invoke_workflow."""

    def test_top_level_langgraph_by_name(self):
        result = classify_chain_run(
            serialized={"name": "LangGraph"},
            metadata={},
            kwargs={},
            parent_run_id=None,
        )
        assert result == OperationName.INVOKE_WORKFLOW

    def test_top_level_langgraph_by_graph_id(self):
        # The graph id prefix is an alternative LangGraph fingerprint.
        result = classify_chain_run(
            serialized={"name": "other", "graph": {"id": "LangGraph-abc"}},
            metadata={},
            kwargs={},
            parent_run_id=None,
        )
        assert result == OperationName.INVOKE_WORKFLOW

    def test_otel_workflow_span_true(self):
        result = classify_chain_run(
            serialized={},
            metadata={"otel_workflow_span": True},
            kwargs={},
            parent_run_id=None,
        )
        assert result == OperationName.INVOKE_WORKFLOW

    def test_not_workflow_when_has_parent(self):
        """LangGraph name alone is not enough when there is a parent run."""
        result = classify_chain_run(
            serialized={"name": "LangGraph"},
            metadata={},
            kwargs={},
            parent_run_id=uuid4(),
        )
        assert result is None


class TestClassifyChainRunSuppression:
    """Chains that should be suppressed (return None)."""

    def test_start_node_suppressed(self):
        # "__start__" is LangGraph plumbing, not a user-visible operation.
        result = classify_chain_run(
            serialized={},
            metadata={"langgraph_node": "__start__"},
            kwargs={},
        )
        assert result is None

    def test_middleware_prefix_suppressed(self):
        result = classify_chain_run(
            serialized={"name": "Middleware.auth"},
            metadata={"langgraph_node": "Middleware.auth"},
            kwargs={"name": "Middleware.auth"},
        )
        assert result is None

    def test_otel_trace_false(self):
        result = classify_chain_run(
            serialized={},
            metadata={"otel_trace": False},
            kwargs={},
        )
        assert result is None

    def test_otel_agent_span_false_no_other_signals(self):
        result = classify_chain_run(
            serialized={},
            metadata={"otel_agent_span": False},
            kwargs={},
        )
        assert result is None

    def test_unclassified_generic_chain(self):
        # Plain runnables with a parent are neither agent nor workflow.
        result = classify_chain_run(
            serialized={"name": "RunnableSequence"},
            metadata={},
            kwargs={},
            parent_run_id=uuid4(),
        )
        assert result is None


# ---------------------------------------------------------------------------
# should_ignore_chain
# ---------------------------------------------------------------------------


class TestShouldIgnoreChain:
    """Suppression logic for known noise chains."""

    def test_ignores_start_node(self):
        assert should_ignore_chain(
            metadata={"langgraph_node": "__start__"},
            agent_name=None,
            parent_run_id=None,
            kwargs={},
        )

    def test_ignores_middleware_agent_name(self):
        assert should_ignore_chain(
            metadata={},
            agent_name="Middleware.something",
            parent_run_id=None,
            kwargs={},
        )

    def test_ignores_middleware_in_kwargs_name(self):
        assert should_ignore_chain(
            metadata={},
            agent_name=None,
            parent_run_id=None,
            kwargs={"name": "Middleware.guard"},
        )

    def test_ignores_otel_trace_false(self):
        assert should_ignore_chain(
            metadata={"otel_trace": False},
            agent_name=None,
            parent_run_id=None,
            kwargs={},
        )

    def test_ignores_otel_agent_span_false_no_signals(self):
        assert should_ignore_chain(
            metadata={"otel_agent_span": False},
            agent_name=None,
            parent_run_id=None,
            kwargs={},
        )

    def test_does_not_ignore_otel_agent_span_false_with_agent_name(self):
        """otel_agent_span=False is overridden when agent_name is present."""
        assert not should_ignore_chain(
            metadata={"otel_agent_span": False, "agent_name": "planner"},
            agent_name="planner",
            parent_run_id=None,
            kwargs={},
        )

    def test_does_not_ignore_normal_agent_node(self):
        assert not should_ignore_chain(
            metadata={"langgraph_node": "researcher"},
            agent_name="researcher",
            parent_run_id=uuid4(),
            kwargs={},
        )


# ---------------------------------------------------------------------------
# resolve_agent_name
# ---------------------------------------------------------------------------


class TestResolveAgentName:
    """Best-effort agent name resolution from callback arguments."""

    def test_from_metadata_agent_name(self):
        assert (
            resolve_agent_name(
                serialized={},
                metadata={"agent_name": "planner"},
                kwargs={},
            )
            == "planner"
        )

    def test_from_kwargs_name(self):
        assert (
            resolve_agent_name(
                serialized={},
                metadata={},
                kwargs={"name": "tool-caller"},
            )
            == "tool-caller"
        )

    def test_from_serialized_name(self):
        assert (
            resolve_agent_name(
                serialized={"name": "MyAgent"},
                metadata={},
                kwargs={},
            )
            == "MyAgent"
        )

    def test_from_langgraph_node(self):
        assert (
            resolve_agent_name(
                serialized={},
                metadata={"langgraph_node": "researcher"},
                kwargs={},
            )
            == "researcher"
        )

    def test_langgraph_start_node_excluded(self):
        # "__start__" is never a usable agent name.
        assert (
            resolve_agent_name(
                serialized={},
                metadata={"langgraph_node": "__start__"},
                kwargs={},
            )
            is None
        )

    def test_returns_none_when_nothing_available(self):
        assert (
            resolve_agent_name(serialized={}, metadata={}, kwargs={}) is None
        )

    def test_returns_none_with_none_metadata(self):
        # metadata may legitimately be None in LangChain callbacks.
        assert (
            resolve_agent_name(serialized={}, metadata=None, kwargs={}) is None
        )

    def test_priority_metadata_over_kwargs(self):
        """metadata agent_name has higher priority than kwargs name."""
        assert (
            resolve_agent_name(
                serialized={"name": "serialized"},
                metadata={"agent_name": "meta"},
                kwargs={"name": "kw"},
            )
            == "meta"
        )

    def test_priority_kwargs_over_serialized(self):
        assert (
            resolve_agent_name(
                serialized={"name": "serialized"},
                metadata={},
                kwargs={"name": "kw"},
            )
            == "kw"
        )


# ---------------------------------------------------------------------------
# OperationName constants
# ---------------------------------------------------------------------------


class TestOperationNameConstants:
    # Pin the semconv-facing string values so renames are caught.
    def test_chat(self):
        assert OperationName.CHAT == "chat"

    def test_text_completion(self):
        assert OperationName.TEXT_COMPLETION == "text_completion"

    def test_invoke_agent(self):
        assert OperationName.INVOKE_AGENT == "invoke_agent"

    def test_execute_tool(self):
        assert OperationName.EXECUTE_TOOL == "execute_tool"

    def test_invoke_workflow(self):
        assert OperationName.INVOKE_WORKFLOW == "invoke_workflow"
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for on_retriever_start / on_retriever_end / on_retriever_error callbacks."""

from __future__ import annotations

import json
from unittest import mock
from uuid import uuid4

import pytest

from opentelemetry.instrumentation.langchain.callback_handler import (
    OpenTelemetryLangChainCallbackHandler,
)
from opentelemetry.instrumentation.langchain.semconv_attributes import (
    OP_EXECUTE_TOOL,
)
from opentelemetry.instrumentation.langchain.span_manager import (
    SpanRecord,
    _SpanManager,
)
from opentelemetry.semconv._incubating.attributes import (
    gen_ai_attributes as GenAI,
)
from opentelemetry.trace import SpanKind
from opentelemetry.trace.status import StatusCode

# ---------------------------------------------------------------------------
# Helpers & fixtures
# ---------------------------------------------------------------------------


def _make_mock_span():
    """Create a mock span with the interface used by the callback handler."""
    span = mock.MagicMock()
    span.is_recording.return_value = True
    span.set_attribute = mock.MagicMock()
    span.set_status = mock.MagicMock()
    return span


def _make_span_record(
    run_id, span=None, operation=OP_EXECUTE_TOOL, attributes=None
):
    # Thin factory around SpanRecord so tests only spell out what they need.
    return SpanRecord(
        run_id=str(run_id),
        span=span or _make_mock_span(),
        operation=operation,
        attributes=attributes or {},
    )


@pytest.fixture
def mock_span_manager():
    # A spec'd mock so unknown attribute access on _SpanManager fails loudly.
    mgr = mock.MagicMock(spec=_SpanManager)
    mgr.is_ignored.return_value = False
    mgr.resolve_parent_id.side_effect = lambda parent_run_id: (
        str(parent_run_id) if parent_run_id is not None else None
    )
    mgr.start_span.return_value = _make_span_record(uuid4())
    return mgr


@pytest.fixture
def handler(mock_span_manager):
    telemetry_handler = mock.MagicMock()
    return OpenTelemetryLangChainCallbackHandler(
        telemetry_handler=telemetry_handler,
        span_manager=mock_span_manager,
    )


# ---------------------------------------------------------------------------
# on_retriever_start
# ---------------------------------------------------------------------------


class TestOnRetrieverStart:
    """on_retriever_start creates execute_tool spans for retrievers."""

    def test_creates_execute_tool_span_with_retriever_type(
        self, handler, mock_span_manager
    ):
        run_id = uuid4()
        handler.on_retriever_start(
            serialized={"name": "vector_store"},
            query="What is OpenTelemetry?",
            run_id=run_id,
        )

        mock_span_manager.start_span.assert_called_once()
        call_kwargs = mock_span_manager.start_span.call_args.kwargs
        assert call_kwargs["operation"] == OP_EXECUTE_TOOL
        assert call_kwargs["kind"] == SpanKind.INTERNAL
        assert call_kwargs["name"] == f"{OP_EXECUTE_TOOL} vector_store"
        attrs = call_kwargs["attributes"]
        assert attrs[GenAI.GEN_AI_OPERATION_NAME] == OP_EXECUTE_TOOL
        assert attrs[GenAI.GEN_AI_TOOL_NAME] == "vector_store"
        assert attrs[GenAI.GEN_AI_TOOL_TYPE] == "retriever"

    def test_sets_query_when_content_recording_enabled(
        self, handler, mock_span_manager
    ):
        run_id = uuid4()
        with mock.patch(
            "opentelemetry.instrumentation.langchain.callback_handler.should_record_retriever_content",
            return_value=True,
        ):
            handler.on_retriever_start(
                serialized={"name": "retriever"},
                query="semantic search query",
                run_id=run_id,
            )

        call_kwargs = mock_span_manager.start_span.call_args.kwargs
        assert (
            call_kwargs["attributes"]["gen_ai.retrieval.query.text"]
            == "semantic search query"
        )

    def test_redacts_query_when_content_recording_disabled(
        self, handler, mock_span_manager
    ):
        run_id = uuid4()
        with mock.patch(
            "opentelemetry.instrumentation.langchain.callback_handler.should_record_retriever_content",
            return_value=False,
        ):
            handler.on_retriever_start(
                serialized={"name": "retriever"},
                query="secret query",
                run_id=run_id,
            )

        # The attribute must be absent — not empty — when recording is off.
        call_kwargs = mock_span_manager.start_span.call_args.kwargs
        assert "gen_ai.retrieval.query.text" not in call_kwargs["attributes"]

    def test_inherits_provider_from_parent_span(
        self, handler, mock_span_manager
    ):
        parent_run_id = uuid4()
        run_id = uuid4()
        parent_record = _make_span_record(
            parent_run_id,
            attributes={GenAI.GEN_AI_PROVIDER_NAME: "openai"},
        )
        mock_span_manager.get_record.return_value = parent_record

        handler.on_retriever_start(
            serialized={"name": "retriever"},
            query="test query",
            run_id=run_id,
            parent_run_id=parent_run_id,
        )

        call_kwargs = mock_span_manager.start_span.call_args.kwargs
        assert (
            call_kwargs["attributes"][GenAI.GEN_AI_PROVIDER_NAME] == "openai"
        )

    def test_defaults_tool_name_to_retriever(self, handler, mock_span_manager):
        run_id = uuid4()
        handler.on_retriever_start(
            serialized={},
            query="test",
            run_id=run_id,
        )

        call_kwargs = mock_span_manager.start_span.call_args.kwargs
        assert call_kwargs["attributes"][GenAI.GEN_AI_TOOL_NAME] == "retriever"
        assert call_kwargs["name"] == f"{OP_EXECUTE_TOOL} retriever"

    def test_returns_early_when_no_span_manager(self):
        telemetry_handler = mock.MagicMock()
        h = OpenTelemetryLangChainCallbackHandler(
            telemetry_handler=telemetry_handler,
            span_manager=None,
        )
        # Should not raise.
        h.on_retriever_start(
            serialized={"name": "retriever"},
            query="test",
            run_id=uuid4(),
        )


# ---------------------------------------------------------------------------
# on_retriever_end
# ---------------------------------------------------------------------------


class TestOnRetrieverEnd:
    """on_retriever_end sets retrieval documents and ends span."""

    def test_sets_documents_with_content_when_enabled(
        self, handler, mock_span_manager
    ):
        run_id = uuid4()
        span = _make_mock_span()
        record = _make_span_record(run_id, span=span)
        mock_span_manager.get_record.return_value = record

        docs = [
            mock.MagicMock(
                page_content="Document 1", metadata={"source": "a.txt"}
            ),
            mock.MagicMock(
                page_content="Document 2", metadata={"source": "b.txt"}
            ),
        ]

        formatted_json = json.dumps(
            [
                {
                    "page_content": "Document 1",
                    "metadata": {"source": "a.txt"},
                },
                {
                    "page_content": "Document 2",
                    "metadata": {"source": "b.txt"},
                },
            ]
        )

        with (
            mock.patch(
                "opentelemetry.instrumentation.langchain.callback_handler.should_record_retriever_content",
                return_value=True,
            ),
            mock.patch(
                "opentelemetry.instrumentation.langchain.callback_handler.format_documents",
                return_value=formatted_json,
            ) as mock_format,
        ):
            handler.on_retriever_end(
                documents=docs,
                run_id=run_id,
            )
            mock_format.assert_called_once_with(docs, record_content=True)

        span.set_attribute.assert_called_once_with(
            "gen_ai.retrieval.documents", formatted_json
        )

    def test_sets_documents_metadata_only_when_content_disabled(
        self, handler, mock_span_manager
    ):
        run_id = uuid4()
        span = _make_mock_span()
        record = _make_span_record(run_id, span=span)
        mock_span_manager.get_record.return_value = record

        docs = [
            mock.MagicMock(
                page_content="Secret content",
                metadata={"source": "a.txt"},
            ),
        ]

        # With recording disabled only metadata survives formatting.
        metadata_only_json = json.dumps([{"metadata": {"source": "a.txt"}}])

        with (
            mock.patch(
                "opentelemetry.instrumentation.langchain.callback_handler.should_record_retriever_content",
                return_value=False,
            ),
            mock.patch(
                "opentelemetry.instrumentation.langchain.callback_handler.format_documents",
                return_value=metadata_only_json,
            ) as mock_format,
        ):
            handler.on_retriever_end(
                documents=docs,
                run_id=run_id,
            )
            mock_format.assert_called_once_with(docs, record_content=False)

        span.set_attribute.assert_called_once_with(
            "gen_ai.retrieval.documents", metadata_only_json
        )

    def test_ends_span_with_ok_status(self, handler, mock_span_manager):
        run_id = uuid4()
        record = _make_span_record(run_id)
        mock_span_manager.get_record.return_value = record

        handler.on_retriever_end(
            documents=[],
            run_id=run_id,
        )

        mock_span_manager.end_span.assert_called_once_with(
            run_id=str(run_id), status=StatusCode.OK
        )

    def test_returns_early_when_no_record(self, handler, mock_span_manager):
        run_id = uuid4()
        mock_span_manager.get_record.return_value = None

        handler.on_retriever_end(
            documents=[],
            run_id=run_id,
        )

        mock_span_manager.end_span.assert_not_called()

    def test_returns_early_when_no_span_manager(self):
        telemetry_handler = mock.MagicMock()
        h = OpenTelemetryLangChainCallbackHandler(
            telemetry_handler=telemetry_handler,
            span_manager=None,
        )
        # Should not raise.
        h.on_retriever_end(
            documents=[],
            run_id=uuid4(),
        )


# ---------------------------------------------------------------------------
# on_retriever_error
# ---------------------------------------------------------------------------


class TestOnRetrieverError:
    """on_retriever_error ends the span with error status."""

    def test_ends_span_with_error(self, handler, mock_span_manager):
        run_id = uuid4()
        error = RuntimeError("retriever failed")

        handler.on_retriever_error(
            error=error,
            run_id=run_id,
        )

        mock_span_manager.end_span.assert_called_once_with(
            run_id=str(run_id), error=error
        )

    def test_returns_early_when_no_span_manager(self):
        telemetry_handler = mock.MagicMock()
        h = OpenTelemetryLangChainCallbackHandler(
            telemetry_handler=telemetry_handler,
            span_manager=None,
        )
        # Should not raise.
        h.on_retriever_error(
            error=RuntimeError("boom"),
            run_id=uuid4(),
        )
+ +"""Tests for span hierarchy and parent-child resolution in _SpanManager.""" + +from unittest import mock + +import pytest + +from opentelemetry.instrumentation.langchain.semconv_attributes import ( + OP_CHAT, + OP_INVOKE_AGENT, +) +from opentelemetry.instrumentation.langchain.span_manager import _SpanManager +from opentelemetry.semconv._incubating.attributes import ( + gen_ai_attributes as GenAI, +) +from opentelemetry.trace.status import StatusCode + + +def _make_mock_span(): + span = mock.MagicMock() + span.is_recording.return_value = True + return span + + +def _make_tracer(): + tracer = mock.MagicMock() + tracer.start_span.side_effect = lambda **kwargs: _make_mock_span() + return tracer + + +@pytest.fixture() +def tracer(): + return _make_tracer() + + +@pytest.fixture() +def mgr(tracer): + return _SpanManager(tracer) + + +# ------------------------------------------------------------------ +# resolve_parent_id +# ------------------------------------------------------------------ + + +class TestResolveParentId: + def test_returns_parent_when_parent_exists_in_spans(self, mgr): + """Parent run_id is returned directly when it is not ignored.""" + mgr.start_span( + run_id="parent-1", + name="agent", + operation=OP_INVOKE_AGENT, + ) + assert mgr.resolve_parent_id("parent-1") == "parent-1" + + def test_walks_through_ignored_runs(self, mgr): + """Children of ignored runs are re-parented to the visible ancestor.""" + mgr.start_span( + run_id="grandparent", + name="agent", + operation=OP_INVOKE_AGENT, + ) + # Middle run is ignored; its parent is the grandparent. 
+ mgr.ignore_run("ignored-middle", parent_run_id="grandparent") + + resolved = mgr.resolve_parent_id("ignored-middle") + assert resolved == "grandparent" + + def test_cycle_in_ignored_chain_returns_none(self, mgr): + """A cycle among ignored runs must not loop forever.""" + mgr.ignore_run("a", parent_run_id="b") + mgr.ignore_run("b", parent_run_id="a") + + assert mgr.resolve_parent_id("a") is None + + def test_returns_none_when_parent_not_found(self, mgr): + assert mgr.resolve_parent_id(None) is None + # Unknown non-ignored id is returned as-is (it is "visible"). + assert mgr.resolve_parent_id("nonexistent") == "nonexistent" + + +# ------------------------------------------------------------------ +# Agent stacks +# ------------------------------------------------------------------ + + +class TestAgentStacks: + def test_agent_stacks_track_per_thread(self, mgr): + """Each thread_key gets its own independent agent stack.""" + mgr.start_span( + run_id="agent-t1", + name="agent", + operation=OP_INVOKE_AGENT, + thread_key="thread-1", + ) + mgr.start_span( + run_id="agent-t2", + name="agent", + operation=OP_INVOKE_AGENT, + thread_key="thread-2", + ) + + assert mgr._agent_stack_by_thread["thread-1"] == ["agent-t1"] + assert mgr._agent_stack_by_thread["thread-2"] == ["agent-t2"] + + def test_start_span_adds_to_agent_stack(self, mgr): + """invoke_agent spans are pushed onto the per-thread agent stack.""" + mgr.start_span( + run_id="agent-a", + name="outer", + operation=OP_INVOKE_AGENT, + thread_key="t", + ) + mgr.start_span( + run_id="agent-b", + name="inner", + operation=OP_INVOKE_AGENT, + thread_key="t", + ) + + assert mgr._agent_stack_by_thread["t"] == ["agent-a", "agent-b"] + + def test_end_span_removes_from_agent_stack(self, mgr): + """Ending an invoke_agent span pops it from the agent stack.""" + mgr.start_span( + run_id="agent-x", + name="agent", + operation=OP_INVOKE_AGENT, + thread_key="tk", + ) + assert "tk" in mgr._agent_stack_by_thread + + 
mgr.end_span("agent-x") + + # Stack should be cleaned up entirely when empty. + assert "tk" not in mgr._agent_stack_by_thread + + +# ------------------------------------------------------------------ +# Goto routing +# ------------------------------------------------------------------ + + +class TestGotoRouting: + def test_push_pop_lifo(self, mgr): + """push/pop follows LIFO order.""" + mgr.push_goto_parent("t1", "parent-a") + mgr.push_goto_parent("t1", "parent-b") + + assert mgr.pop_goto_parent("t1") == "parent-b" + assert mgr.pop_goto_parent("t1") == "parent-a" + + def test_pop_empty_returns_none(self, mgr): + assert mgr.pop_goto_parent("nonexistent-thread") is None + + def test_cleanup_of_empty_stacks(self, mgr): + """Once the last goto parent is popped the thread key is removed.""" + mgr.push_goto_parent("t1", "p1") + mgr.pop_goto_parent("t1") + + assert "t1" not in mgr._goto_parent_stack + + +# ------------------------------------------------------------------ +# accumulate_usage_to_parent +# ------------------------------------------------------------------ + + +class TestAccumulateUsageToParent: + def test_accumulates_on_nearest_agent_parent(self, mgr): + """Token counts propagate to the nearest invoke_agent ancestor.""" + agent_rec = mgr.start_span( + run_id="agent", + name="agent", + operation=OP_INVOKE_AGENT, + ) + chat_rec = mgr.start_span( + run_id="chat", + name="chat gpt-4o", + operation=OP_CHAT, + parent_run_id="agent", + ) + + mgr.accumulate_usage_to_parent( + chat_rec, input_tokens=10, output_tokens=5 + ) + + agent_span = agent_rec.span + agent_span.set_attribute.assert_any_call( + GenAI.GEN_AI_USAGE_INPUT_TOKENS, 10 + ) + agent_span.set_attribute.assert_any_call( + GenAI.GEN_AI_USAGE_OUTPUT_TOKENS, 5 + ) + + def test_noop_when_no_agent_parent(self, mgr): + """No error when the parent chain has no invoke_agent span.""" + chat_rec = mgr.start_span( + run_id="chat-orphan", + name="chat", + operation=OP_CHAT, + ) + # Should not raise. 
+ mgr.accumulate_usage_to_parent( + chat_rec, input_tokens=1, output_tokens=2 + ) + + def test_handles_none_token_values(self, mgr): + """Both tokens None → early return, no side-effects.""" + agent_rec = mgr.start_span( + run_id="agent", + name="agent", + operation=OP_INVOKE_AGENT, + ) + chat_rec = mgr.start_span( + run_id="chat", + name="chat", + operation=OP_CHAT, + parent_run_id="agent", + ) + + # Reset call tracking after start_span's own set_attribute calls. + agent_rec.span.set_attribute.reset_mock() + + mgr.accumulate_usage_to_parent( + chat_rec, input_tokens=None, output_tokens=None + ) + + # No token attributes should have been set on the agent span. + agent_rec.span.set_attribute.assert_not_called() + + +# ------------------------------------------------------------------ +# start_span / end_span +# ------------------------------------------------------------------ + + +class TestStartEndSpan: + def test_creates_span_with_correct_parent_context(self, mgr, tracer): + """start_span passes the parent span's context to the tracer.""" + mgr.start_span( + run_id="parent", + name="parent-agent", + operation=OP_INVOKE_AGENT, + ) + + mgr.start_span( + run_id="child", + name="child-chat", + operation=OP_CHAT, + parent_run_id="parent", + ) + + # The second start_span call should pass a non-None context. 
+ calls = tracer.start_span.call_args_list + assert len(calls) == 2 + child_call_kwargs = calls[1][1] + assert child_call_kwargs["context"] is not None + + def test_end_span_with_error_sets_attributes(self, mgr): + """end_span records error type and ERROR status on the span.""" + rec = mgr.start_span( + run_id="err-run", + name="failing", + operation=OP_CHAT, + ) + span = rec.span + + mgr.end_span("err-run", error=ValueError("boom")) + + span.set_attribute.assert_any_call("error.type", "ValueError") + span.set_status.assert_called_once() + status_arg = span.set_status.call_args[0][0] + assert status_arg.status_code == StatusCode.ERROR + + def test_end_span_removes_record(self, mgr): + """After end_span the run_id is no longer tracked.""" + mgr.start_span( + run_id="temp", + name="temp", + operation=OP_CHAT, + ) + assert mgr.get_record("temp") is not None + + mgr.end_span("temp") + assert mgr.get_record("temp") is None diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_streaming_metrics.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_streaming_metrics.py new file mode 100644 index 0000000000..068afe8a35 --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_streaming_metrics.py @@ -0,0 +1,331 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for on_llm_new_token streaming timing metrics.""" + +from __future__ import annotations + +import uuid +from unittest.mock import MagicMock, patch + +import pytest +from langchain_core.outputs import LLMResult + +from opentelemetry.instrumentation.langchain.callback_handler import ( + OpenTelemetryLangChainCallbackHandler, +) +from opentelemetry.util.genai.types import LLMInvocation + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_handler(): + """Create a handler with mock telemetry and histogram instruments.""" + telemetry_handler = MagicMock() + # stop_llm / fail_llm must return an LLMInvocation so on_llm_end works + telemetry_handler.stop_llm.return_value = LLMInvocation() + telemetry_handler.fail_llm.return_value = LLMInvocation() + + handler = OpenTelemetryLangChainCallbackHandler( + telemetry_handler=telemetry_handler, + ) + + # Replace the real histograms with mocks so we can inspect calls. 
+ ttfc = MagicMock(name="ttfc_histogram") + tpoc = MagicMock(name="tpoc_histogram") + handler._ttfc_histogram = ttfc + handler._tpoc_histogram = tpoc + + return handler, ttfc, tpoc + + +def _register_llm_invocation( + handler, run_id, *, monotonic_start_s=100.0, **kwargs +): + """Register an LLMInvocation in the handler's invocation manager.""" + invocation = LLMInvocation( + monotonic_start_s=monotonic_start_s, + operation_name=kwargs.get("operation_name", "chat"), + request_model=kwargs.get("request_model", "gpt-4"), + provider=kwargs.get("provider", "openai"), + response_model_name=kwargs.get("response_model_name"), + server_address=kwargs.get("server_address", "api.openai.com"), + server_port=kwargs.get("server_port"), + ) + # Wire a mock span so on_llm_end doesn't crash + invocation.span = MagicMock() + invocation.span.is_recording.return_value = False + + handler._invocation_manager.add_invocation_state( + run_id=run_id, + parent_run_id=None, + invocation=invocation, + ) + return invocation + + +def _empty_llm_result(): + return LLMResult(generations=[]) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestFirstTokenRecordsTTFC: + """First token records time_to_first_chunk metric.""" + + def test_first_token_records_ttfc(self): + handler, ttfc, tpoc = _make_handler() + run_id = uuid.uuid4() + _register_llm_invocation(handler, run_id, monotonic_start_s=100.0) + + with patch( + "opentelemetry.instrumentation.langchain.callback_handler.timeit" + ) as mock_timeit: + mock_timeit.default_timer.return_value = 100.5 + handler.on_llm_new_token("Hello", run_id=run_id) + + ttfc.record.assert_called_once() + recorded_value = ttfc.record.call_args[0][0] + assert recorded_value == pytest.approx(0.5) + + # time_per_output_chunk must NOT be recorded for the first token + tpoc.record.assert_not_called() + + def 
test_ttfc_includes_metric_attributes(self): + handler, ttfc, _ = _make_handler() + run_id = uuid.uuid4() + _register_llm_invocation( + handler, + run_id, + monotonic_start_s=100.0, + operation_name="chat", + request_model="gpt-4", + provider="openai", + server_address="api.openai.com", + ) + + with patch( + "opentelemetry.instrumentation.langchain.callback_handler.timeit" + ) as mock_timeit: + mock_timeit.default_timer.return_value = 100.3 + handler.on_llm_new_token("Hi", run_id=run_id) + + attrs = ttfc.record.call_args[1]["attributes"] + assert attrs["gen_ai.operation.name"] == "chat" + assert attrs["gen_ai.request.model"] == "gpt-4" + assert attrs["gen_ai.provider.name"] == "openai" + assert attrs["server.address"] == "api.openai.com" + + +class TestSubsequentTokenRecordsTPOC: + """Subsequent tokens record time_per_output_chunk metric.""" + + def test_second_token_records_tpoc(self): + handler, ttfc, tpoc = _make_handler() + run_id = uuid.uuid4() + _register_llm_invocation(handler, run_id, monotonic_start_s=100.0) + + with patch( + "opentelemetry.instrumentation.langchain.callback_handler.timeit" + ) as mock_timeit: + # First token + mock_timeit.default_timer.return_value = 100.5 + handler.on_llm_new_token("Hello", run_id=run_id) + # Second token + mock_timeit.default_timer.return_value = 100.7 + handler.on_llm_new_token(" world", run_id=run_id) + + tpoc.record.assert_called_once() + recorded_value = tpoc.record.call_args[0][0] + assert recorded_value == pytest.approx(0.2) + + def test_third_token_also_records_tpoc(self): + handler, ttfc, tpoc = _make_handler() + run_id = uuid.uuid4() + _register_llm_invocation(handler, run_id, monotonic_start_s=100.0) + + with patch( + "opentelemetry.instrumentation.langchain.callback_handler.timeit" + ) as mock_timeit: + mock_timeit.default_timer.return_value = 100.5 + handler.on_llm_new_token("a", run_id=run_id) + + mock_timeit.default_timer.return_value = 100.7 + handler.on_llm_new_token("b", run_id=run_id) + + 
mock_timeit.default_timer.return_value = 101.0 + handler.on_llm_new_token("c", run_id=run_id) + + assert tpoc.record.call_count == 2 + # Second call (c-b): 101.0 - 100.7 = 0.3 + assert tpoc.record.call_args_list[1][0][0] == pytest.approx(0.3) + + +class TestNoOpWhenInvocationNotFound: + """No-op when run_id is not in the invocation manager.""" + + def test_unknown_run_id_does_nothing(self): + handler, ttfc, tpoc = _make_handler() + unknown_run_id = uuid.uuid4() + + # Should not raise + handler.on_llm_new_token("token", run_id=unknown_run_id) + + ttfc.record.assert_not_called() + tpoc.record.assert_not_called() + + +class TestNoOpWhenMonotonicStartIsNone: + """No-op when invocation.monotonic_start_s is None.""" + + def test_none_monotonic_start_does_nothing(self): + handler, ttfc, tpoc = _make_handler() + run_id = uuid.uuid4() + + # Register invocation with monotonic_start_s=None + invocation = LLMInvocation(monotonic_start_s=None) + invocation.span = MagicMock() + handler._invocation_manager.add_invocation_state( + run_id=run_id, + parent_run_id=None, + invocation=invocation, + ) + + handler.on_llm_new_token("token", run_id=run_id) + + ttfc.record.assert_not_called() + tpoc.record.assert_not_called() + # Streaming state should not be updated either + assert str(run_id) not in handler._streaming_state + + +class TestStreamingStateCleanupOnLLMEnd: + """Streaming state is cleaned up in on_llm_end.""" + + def test_on_llm_end_removes_streaming_state(self): + handler, ttfc, tpoc = _make_handler() + run_id = uuid.uuid4() + _register_llm_invocation(handler, run_id, monotonic_start_s=100.0) + + with patch( + "opentelemetry.instrumentation.langchain.callback_handler.timeit" + ) as mock_timeit: + mock_timeit.default_timer.return_value = 100.5 + handler.on_llm_new_token("Hello", run_id=run_id) + + # Streaming state should exist now + assert str(run_id) in handler._streaming_state + + handler.on_llm_end(_empty_llm_result(), run_id=run_id) + + assert str(run_id) not in 
handler._streaming_state + + +class TestStreamingStateCleanupOnLLMError: + """Streaming state is cleaned up in on_llm_error.""" + + def test_on_llm_error_removes_streaming_state(self): + handler, ttfc, tpoc = _make_handler() + run_id = uuid.uuid4() + _register_llm_invocation(handler, run_id, monotonic_start_s=100.0) + + with patch( + "opentelemetry.instrumentation.langchain.callback_handler.timeit" + ) as mock_timeit: + mock_timeit.default_timer.return_value = 100.5 + handler.on_llm_new_token("Hello", run_id=run_id) + + assert str(run_id) in handler._streaming_state + + handler.on_llm_error(RuntimeError("boom"), run_id=run_id) + + assert str(run_id) not in handler._streaming_state + + +class TestMultipleStreamingSequences: + """Multiple streaming sequences (different run_ids) don't interfere.""" + + def test_independent_run_ids(self): + handler, ttfc, tpoc = _make_handler() + run_a = uuid.uuid4() + run_b = uuid.uuid4() + _register_llm_invocation(handler, run_a, monotonic_start_s=100.0) + _register_llm_invocation(handler, run_b, monotonic_start_s=200.0) + + with patch( + "opentelemetry.instrumentation.langchain.callback_handler.timeit" + ) as mock_timeit: + # First token for run_a at t=100.5 + mock_timeit.default_timer.return_value = 100.5 + handler.on_llm_new_token("A1", run_id=run_a) + + # First token for run_b at t=200.3 + mock_timeit.default_timer.return_value = 200.3 + handler.on_llm_new_token("B1", run_id=run_b) + + # Second token for run_a at t=100.8 + mock_timeit.default_timer.return_value = 100.8 + handler.on_llm_new_token("A2", run_id=run_a) + + # Second token for run_b at t=200.6 + mock_timeit.default_timer.return_value = 200.6 + handler.on_llm_new_token("B2", run_id=run_b) + + # TTFC calls: run_a (0.5), run_b (0.3) + assert ttfc.record.call_count == 2 + ttfc_values = [call[0][0] for call in ttfc.record.call_args_list] + assert pytest.approx(0.5) in ttfc_values + assert pytest.approx(0.3) in ttfc_values + + # TPOC calls: run_a (0.3), run_b (0.3) + assert 
tpoc.record.call_count == 2 + tpoc_values = [call[0][0] for call in tpoc.record.call_args_list] + assert pytest.approx(0.3) in tpoc_values + + def test_ending_one_stream_does_not_affect_another(self): + handler, ttfc, tpoc = _make_handler() + run_a = uuid.uuid4() + run_b = uuid.uuid4() + _register_llm_invocation(handler, run_a, monotonic_start_s=100.0) + _register_llm_invocation(handler, run_b, monotonic_start_s=200.0) + + with patch( + "opentelemetry.instrumentation.langchain.callback_handler.timeit" + ) as mock_timeit: + mock_timeit.default_timer.return_value = 100.5 + handler.on_llm_new_token("A1", run_id=run_a) + + mock_timeit.default_timer.return_value = 200.3 + handler.on_llm_new_token("B1", run_id=run_b) + + # End run_a + handler.on_llm_end(_empty_llm_result(), run_id=run_a) + assert str(run_a) not in handler._streaming_state + # run_b should still be present + assert str(run_b) in handler._streaming_state + + with patch( + "opentelemetry.instrumentation.langchain.callback_handler.timeit" + ) as mock_timeit: + # run_b second token should still work + mock_timeit.default_timer.return_value = 200.6 + handler.on_llm_new_token("B2", run_id=run_b) + + tpoc.record.assert_called_once() + assert tpoc.record.call_args[0][0] == pytest.approx(0.3) diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_tool_callbacks.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_tool_callbacks.py new file mode 100644 index 0000000000..bff6f51ab2 --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_tool_callbacks.py @@ -0,0 +1,480 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from unittest import mock +from uuid import uuid4 + +from opentelemetry.instrumentation.langchain.callback_handler import ( + OpenTelemetryLangChainCallbackHandler, +) +from opentelemetry.instrumentation.langchain.semconv_attributes import ( + OP_EXECUTE_TOOL, +) +from opentelemetry.instrumentation.langchain.span_manager import ( + SpanRecord, + _SpanManager, +) +from opentelemetry.semconv._incubating.attributes import ( + gen_ai_attributes as GenAI, +) +from opentelemetry.trace.status import StatusCode +from opentelemetry.util.genai.handler import TelemetryHandler + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_handler(span_manager=None): + """Create a callback handler with mocked dependencies.""" + telemetry_handler = mock.MagicMock(spec=TelemetryHandler) + handler = OpenTelemetryLangChainCallbackHandler( + telemetry_handler=telemetry_handler, + span_manager=span_manager, + ) + return handler + + +def _make_span_manager(): + """Create a mock _SpanManager.""" + sm = mock.MagicMock(spec=_SpanManager) + sm.resolve_parent_id.side_effect = lambda parent_run_id: ( + str(parent_run_id) if parent_run_id is not None else None + ) + return sm + + +def _make_span_record(run_id, attributes=None): + """Create a SpanRecord with a mock span.""" + span = mock.MagicMock() + return SpanRecord( + run_id=str(run_id), + span=span, + operation=OP_EXECUTE_TOOL, + attributes=attributes or {}, + ) + + +def 
_enable_content_recording(monkeypatch): + """Patch content_recording to enable tool content capture.""" + policy = mock.MagicMock() + policy.record_content = True + policy.should_emit_events = False + monkeypatch.setattr( + "opentelemetry.instrumentation.langchain.callback_handler.get_content_policy", + lambda: policy, + ) + monkeypatch.setattr( + "opentelemetry.instrumentation.langchain.callback_handler.should_record_tool_content", + lambda policy: True, + ) + + +def _disable_content_recording(monkeypatch): + """Patch content_recording to disable tool content capture.""" + policy = mock.MagicMock() + policy.record_content = False + policy.should_emit_events = False + monkeypatch.setattr( + "opentelemetry.instrumentation.langchain.callback_handler.get_content_policy", + lambda: policy, + ) + monkeypatch.setattr( + "opentelemetry.instrumentation.langchain.callback_handler.should_record_tool_content", + lambda policy: False, + ) + + +# --------------------------------------------------------------------------- +# on_tool_start +# --------------------------------------------------------------------------- + + +class TestOnToolStart: + """Tests for the on_tool_start callback method.""" + + def test_creates_span_with_correct_attributes(self, monkeypatch): + """Span is created with operation_name, tool_name, and tool_description.""" + _disable_content_recording(monkeypatch) + sm = _make_span_manager() + handler = _make_handler(span_manager=sm) + run_id = uuid4() + + handler.on_tool_start( + serialized={"name": "web_search", "description": "Search the web"}, + input_str="query text", + run_id=run_id, + ) + + sm.start_span.assert_called_once() + call_kwargs = sm.start_span.call_args[1] + + assert call_kwargs["name"] == f"{OP_EXECUTE_TOOL} web_search" + assert call_kwargs["operation"] == OP_EXECUTE_TOOL + assert call_kwargs["run_id"] == str(run_id) + + attrs = call_kwargs["attributes"] + assert attrs[GenAI.GEN_AI_OPERATION_NAME] == OP_EXECUTE_TOOL + assert 
attrs[GenAI.GEN_AI_TOOL_NAME] == "web_search" + assert attrs[GenAI.GEN_AI_TOOL_DESCRIPTION] == "Search the web" + + def test_sets_tool_call_id_from_inputs(self, monkeypatch): + """tool_call_id is set when present in the inputs dict.""" + _disable_content_recording(monkeypatch) + sm = _make_span_manager() + handler = _make_handler(span_manager=sm) + + handler.on_tool_start( + serialized={"name": "calc"}, + input_str="", + run_id=uuid4(), + inputs={"tool_call_id": "call_abc123", "x": 42}, + ) + + attrs = sm.start_span.call_args[1]["attributes"] + assert attrs[GenAI.GEN_AI_TOOL_CALL_ID] == "call_abc123" + + def test_sets_tool_call_id_from_metadata(self, monkeypatch): + """tool_call_id falls back to metadata when not in inputs.""" + _disable_content_recording(monkeypatch) + sm = _make_span_manager() + handler = _make_handler(span_manager=sm) + + handler.on_tool_start( + serialized={"name": "calc"}, + input_str="", + run_id=uuid4(), + metadata={"tool_call_id": "call_meta_456"}, + ) + + attrs = sm.start_span.call_args[1]["attributes"] + assert attrs[GenAI.GEN_AI_TOOL_CALL_ID] == "call_meta_456" + + def test_includes_tool_arguments_when_content_recording_enabled( + self, monkeypatch + ): + """Tool call arguments are recorded when content recording is on.""" + _enable_content_recording(monkeypatch) + sm = _make_span_manager() + handler = _make_handler(span_manager=sm) + + handler.on_tool_start( + serialized={"name": "calc"}, + input_str="fallback", + run_id=uuid4(), + inputs={"tool_call_id": "call_1", "x": 42, "op": "add"}, + ) + + attrs = sm.start_span.call_args[1]["attributes"] + arguments = json.loads(attrs[GenAI.GEN_AI_TOOL_CALL_ARGUMENTS]) + assert arguments == {"x": 42, "op": "add"} + # tool_call_id should be excluded from arguments + assert "tool_call_id" not in arguments + + def test_uses_input_str_as_arguments_fallback(self, monkeypatch): + """input_str is used as arguments when inputs has no useful data.""" + _enable_content_recording(monkeypatch) + sm = 
_make_span_manager() + handler = _make_handler(span_manager=sm) + + handler.on_tool_start( + serialized={"name": "calc"}, + input_str="raw query text", + run_id=uuid4(), + inputs={}, + ) + + attrs = sm.start_span.call_args[1]["attributes"] + assert attrs[GenAI.GEN_AI_TOOL_CALL_ARGUMENTS] == "raw query text" + + def test_redacts_tool_arguments_when_content_recording_disabled( + self, monkeypatch + ): + """Tool arguments are not recorded when content recording is off.""" + _disable_content_recording(monkeypatch) + sm = _make_span_manager() + handler = _make_handler(span_manager=sm) + + handler.on_tool_start( + serialized={"name": "calc"}, + input_str="secret input", + run_id=uuid4(), + inputs={"x": 42}, + ) + + attrs = sm.start_span.call_args[1]["attributes"] + assert GenAI.GEN_AI_TOOL_CALL_ARGUMENTS not in attrs + + def test_inherits_provider_from_parent_span(self, monkeypatch): + """Provider name is inherited from the parent span record.""" + _disable_content_recording(monkeypatch) + sm = _make_span_manager() + parent_run_id = uuid4() + + parent_record = _make_span_record( + parent_run_id, + attributes={GenAI.GEN_AI_PROVIDER_NAME: "openai"}, + ) + sm.get_record.return_value = parent_record + + handler = _make_handler(span_manager=sm) + + handler.on_tool_start( + serialized={"name": "tool"}, + input_str="", + run_id=uuid4(), + parent_run_id=parent_run_id, + ) + + attrs = sm.start_span.call_args[1]["attributes"] + assert attrs[GenAI.GEN_AI_PROVIDER_NAME] == "openai" + + def test_no_provider_when_parent_has_none(self, monkeypatch): + """No provider attribute when parent record has no provider.""" + _disable_content_recording(monkeypatch) + sm = _make_span_manager() + parent_run_id = uuid4() + + parent_record = _make_span_record(parent_run_id, attributes={}) + sm.get_record.return_value = parent_record + + handler = _make_handler(span_manager=sm) + + handler.on_tool_start( + serialized={"name": "tool"}, + input_str="", + run_id=uuid4(), + parent_run_id=parent_run_id, + 
) + + attrs = sm.start_span.call_args[1]["attributes"] + assert GenAI.GEN_AI_PROVIDER_NAME not in attrs + + def test_noop_when_span_manager_is_none(self): + """No exception or span creation when span_manager is None.""" + handler = _make_handler(span_manager=None) + + # Should not raise + handler.on_tool_start( + serialized={"name": "tool"}, + input_str="query", + run_id=uuid4(), + ) + + def test_sets_tool_type_from_serialized(self, monkeypatch): + """Tool type is set from serialized dict.""" + _disable_content_recording(monkeypatch) + sm = _make_span_manager() + handler = _make_handler(span_manager=sm) + + handler.on_tool_start( + serialized={"name": "search", "type": "function"}, + input_str="", + run_id=uuid4(), + ) + + attrs = sm.start_span.call_args[1]["attributes"] + assert attrs[GenAI.GEN_AI_TOOL_TYPE] == "function" + + def test_resolves_tool_name_from_metadata(self, monkeypatch): + """Tool name falls back to metadata when not in serialized.""" + _disable_content_recording(monkeypatch) + sm = _make_span_manager() + handler = _make_handler(span_manager=sm) + + handler.on_tool_start( + serialized={}, + input_str="", + run_id=uuid4(), + metadata={"tool_name": "my_custom_tool"}, + ) + + attrs = sm.start_span.call_args[1]["attributes"] + assert attrs[GenAI.GEN_AI_TOOL_NAME] == "my_custom_tool" + + def test_resolves_tool_name_from_kwargs(self, monkeypatch): + """Tool name falls back to kwargs when not in serialized or metadata.""" + _disable_content_recording(monkeypatch) + sm = _make_span_manager() + handler = _make_handler(span_manager=sm) + + handler.on_tool_start( + serialized={}, + input_str="", + run_id=uuid4(), + name="kwargs_tool", + ) + + attrs = sm.start_span.call_args[1]["attributes"] + assert attrs[GenAI.GEN_AI_TOOL_NAME] == "kwargs_tool" + + def test_defaults_tool_name_to_unknown(self, monkeypatch): + """Tool name defaults to 'unknown_tool' when no source provides it.""" + _disable_content_recording(monkeypatch) + sm = _make_span_manager() + handler = 
_make_handler(span_manager=sm) + + handler.on_tool_start( + serialized={}, + input_str="", + run_id=uuid4(), + ) + + attrs = sm.start_span.call_args[1]["attributes"] + assert attrs[GenAI.GEN_AI_TOOL_NAME] == "unknown_tool" + + +# --------------------------------------------------------------------------- +# on_tool_end +# --------------------------------------------------------------------------- + + +class TestOnToolEnd: + """Tests for the on_tool_end callback method.""" + + def test_sets_tool_result_when_content_recording_enabled( + self, monkeypatch + ): + """Tool result is set as span attribute when content recording is on.""" + _enable_content_recording(monkeypatch) + sm = _make_span_manager() + run_id = uuid4() + record = _make_span_record(run_id) + sm.get_record.return_value = record + + handler = _make_handler(span_manager=sm) + handler.on_tool_end(output={"answer": 42}, run_id=run_id) + + record.span.set_attribute.assert_called_once() + key, value = record.span.set_attribute.call_args.args + assert key == GenAI.GEN_AI_TOOL_CALL_RESULT + assert json.loads(value) == {"answer": 42} + + def test_redacts_tool_result_when_content_recording_disabled( + self, monkeypatch + ): + """Tool result is not recorded when content recording is off.""" + _disable_content_recording(monkeypatch) + sm = _make_span_manager() + run_id = uuid4() + record = _make_span_record(run_id) + sm.get_record.return_value = record + + handler = _make_handler(span_manager=sm) + handler.on_tool_end(output={"secret": "data"}, run_id=run_id) + + record.span.set_attribute.assert_not_called() + + def test_ends_span_with_ok_status(self, monkeypatch): + """Span is ended with OK status.""" + _disable_content_recording(monkeypatch) + sm = _make_span_manager() + run_id = uuid4() + record = _make_span_record(run_id) + sm.get_record.return_value = record + + handler = _make_handler(span_manager=sm) + handler.on_tool_end(output="result", run_id=run_id) + + sm.end_span.assert_called_once_with( + 
run_id=str(run_id), status=StatusCode.OK + ) + + def test_noop_when_record_not_found(self, monkeypatch): + """No-op when span record is not found for the run_id.""" + _disable_content_recording(monkeypatch) + sm = _make_span_manager() + sm.get_record.return_value = None + + handler = _make_handler(span_manager=sm) + handler.on_tool_end(output="result", run_id=uuid4()) + + sm.end_span.assert_not_called() + + def test_noop_when_span_manager_is_none(self): + """No exception when span_manager is None.""" + handler = _make_handler(span_manager=None) + handler.on_tool_end(output="result", run_id=uuid4()) + + def test_handles_non_json_serializable_output(self, monkeypatch): + """Non-JSON-serializable output falls back to str().""" + _enable_content_recording(monkeypatch) + sm = _make_span_manager() + run_id = uuid4() + record = _make_span_record(run_id) + sm.get_record.return_value = record + + class Custom: + def __str__(self): + return "custom_output" + + handler = _make_handler(span_manager=sm) + handler.on_tool_end(output=Custom(), run_id=run_id) + + record.span.set_attribute.assert_called_once() + call_args = record.span.set_attribute.call_args + assert call_args[0][0] == GenAI.GEN_AI_TOOL_CALL_RESULT + + def test_skips_result_when_output_is_none(self, monkeypatch): + """No result attribute when output is None, even with content recording on.""" + _enable_content_recording(monkeypatch) + sm = _make_span_manager() + run_id = uuid4() + record = _make_span_record(run_id) + sm.get_record.return_value = record + + handler = _make_handler(span_manager=sm) + handler.on_tool_end(output=None, run_id=run_id) + + record.span.set_attribute.assert_not_called() + sm.end_span.assert_called_once_with( + run_id=str(run_id), status=StatusCode.OK + ) + + +# --------------------------------------------------------------------------- +# on_tool_error +# --------------------------------------------------------------------------- + + +class TestOnToolError: + """Tests for the on_tool_error 
callback method.""" + + def test_ends_span_with_error(self): + """Span is ended with the error passed through.""" + sm = _make_span_manager() + handler = _make_handler(span_manager=sm) + run_id = uuid4() + error = ValueError("tool failed") + + handler.on_tool_error(error=error, run_id=run_id) + + sm.end_span.assert_called_once_with(run_id=str(run_id), error=error) + + def test_noop_when_span_manager_is_none(self): + """No exception when span_manager is None.""" + handler = _make_handler(span_manager=None) + handler.on_tool_error(error=RuntimeError("boom"), run_id=uuid4()) + + def test_noop_when_record_not_found(self): + """end_span is still called (it handles missing records internally).""" + sm = _make_span_manager() + handler = _make_handler(span_manager=sm) + run_id = uuid4() + error = RuntimeError("oops") + + handler.on_tool_error(error=error, run_id=run_id) + + sm.end_span.assert_called_once_with(run_id=str(run_id), error=error) diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_usage_propagation.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_usage_propagation.py new file mode 100644 index 0000000000..99b7a5ade3 --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_usage_propagation.py @@ -0,0 +1,237 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for token usage accumulation from LLM spans to parent agent spans.""" + +from unittest.mock import MagicMock +from uuid import uuid4 + +from opentelemetry.instrumentation.langchain.semconv_attributes import ( + OP_CHAT, + OP_INVOKE_AGENT, +) +from opentelemetry.instrumentation.langchain.span_manager import ( + SpanRecord, + _SpanManager, +) +from opentelemetry.semconv._incubating.attributes import ( + gen_ai_attributes as GenAI, +) + +INPUT_TOKENS = GenAI.GEN_AI_USAGE_INPUT_TOKENS +OUTPUT_TOKENS = GenAI.GEN_AI_USAGE_OUTPUT_TOKENS + + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + + +def _make_mock_span(): + span = MagicMock() + span.set_attribute = MagicMock() + return span + + +def _make_tracer(): + """Return a mock Tracer whose start_span returns fresh mock spans.""" + tracer = MagicMock() + tracer.start_span = MagicMock(side_effect=lambda **kw: _make_mock_span()) + return tracer + + +def _make_manager(): + tracer = _make_tracer() + return _SpanManager(tracer), tracer + + +def _register_record(mgr, run_id, operation, parent_run_id=None): + """Register a SpanRecord directly in the manager for test isolation.""" + rid = str(run_id) + prid = str(parent_run_id) if parent_run_id is not None else None + span = _make_mock_span() + record = SpanRecord( + run_id=rid, + span=span, + operation=operation, + parent_run_id=prid, + ) + mgr._spans[rid] = record + return record + + +# ------------------------------------------------------------------ +# accumulate_usage_to_parent +# ------------------------------------------------------------------ + + +class TestAccumulateUsageToParent: + """Tests for _SpanManager.accumulate_usage_to_parent.""" + + def test_accumulates_tokens_on_nearest_agent_parent(self): + mgr, _ = _make_manager() + agent_id = str(uuid4()) + llm_id = str(uuid4()) + + agent_rec = _register_record(mgr, agent_id, OP_INVOKE_AGENT) + llm_rec = 
_register_record( + mgr, llm_id, OP_CHAT, parent_run_id=agent_id + ) + + mgr.accumulate_usage_to_parent( + llm_rec, input_tokens=10, output_tokens=20 + ) + + agent_rec.span.set_attribute.assert_any_call(INPUT_TOKENS, 10) + agent_rec.span.set_attribute.assert_any_call(OUTPUT_TOKENS, 20) + + def test_accumulates_across_multiple_llm_calls(self): + mgr, _ = _make_manager() + agent_id = str(uuid4()) + llm1_id = str(uuid4()) + llm2_id = str(uuid4()) + + agent_rec = _register_record(mgr, agent_id, OP_INVOKE_AGENT) + llm1_rec = _register_record( + mgr, llm1_id, OP_CHAT, parent_run_id=agent_id + ) + llm2_rec = _register_record( + mgr, llm2_id, OP_CHAT, parent_run_id=agent_id + ) + + mgr.accumulate_usage_to_parent( + llm1_rec, input_tokens=10, output_tokens=5 + ) + mgr.accumulate_usage_to_parent( + llm2_rec, input_tokens=20, output_tokens=15 + ) + + # After two calls the values should be additive. + agent_rec.span.set_attribute.assert_any_call(INPUT_TOKENS, 30) + agent_rec.span.set_attribute.assert_any_call(OUTPUT_TOKENS, 20) + + def test_skips_non_agent_parents(self): + """Walk up through a non-agent (chat) intermediate to the agent.""" + mgr, _ = _make_manager() + agent_id = str(uuid4()) + chain_id = str(uuid4()) + llm_id = str(uuid4()) + + agent_rec = _register_record(mgr, agent_id, OP_INVOKE_AGENT) + _register_record(mgr, chain_id, OP_CHAT, parent_run_id=agent_id) + llm_rec = _register_record( + mgr, llm_id, OP_CHAT, parent_run_id=chain_id + ) + + mgr.accumulate_usage_to_parent( + llm_rec, input_tokens=7, output_tokens=3 + ) + + agent_rec.span.set_attribute.assert_any_call(INPUT_TOKENS, 7) + agent_rec.span.set_attribute.assert_any_call(OUTPUT_TOKENS, 3) + + def test_noop_when_both_tokens_none(self): + mgr, _ = _make_manager() + agent_id = str(uuid4()) + llm_id = str(uuid4()) + + agent_rec = _register_record(mgr, agent_id, OP_INVOKE_AGENT) + llm_rec = _register_record( + mgr, llm_id, OP_CHAT, parent_run_id=agent_id + ) + + mgr.accumulate_usage_to_parent( + llm_rec, 
input_tokens=None, output_tokens=None + ) + + agent_rec.span.set_attribute.assert_not_called() + + def test_handles_only_input_tokens(self): + mgr, _ = _make_manager() + agent_id = str(uuid4()) + llm_id = str(uuid4()) + + agent_rec = _register_record(mgr, agent_id, OP_INVOKE_AGENT) + llm_rec = _register_record( + mgr, llm_id, OP_CHAT, parent_run_id=agent_id + ) + + mgr.accumulate_usage_to_parent( + llm_rec, input_tokens=42, output_tokens=None + ) + + agent_rec.span.set_attribute.assert_called_once_with(INPUT_TOKENS, 42) + + def test_handles_only_output_tokens(self): + mgr, _ = _make_manager() + agent_id = str(uuid4()) + llm_id = str(uuid4()) + + agent_rec = _register_record(mgr, agent_id, OP_INVOKE_AGENT) + llm_rec = _register_record( + mgr, llm_id, OP_CHAT, parent_run_id=agent_id + ) + + mgr.accumulate_usage_to_parent( + llm_rec, input_tokens=None, output_tokens=99 + ) + + agent_rec.span.set_attribute.assert_called_once_with(OUTPUT_TOKENS, 99) + + +# ------------------------------------------------------------------ +# accumulate_llm_usage_to_agent +# ------------------------------------------------------------------ + + +class TestAccumulateLlmUsageToAgent: + """Tests for _SpanManager.accumulate_llm_usage_to_agent.""" + + def test_resolves_through_ignored_runs(self): + mgr, _ = _make_manager() + agent_id = str(uuid4()) + ignored_id = str(uuid4()) + + agent_rec = _register_record(mgr, agent_id, OP_INVOKE_AGENT) + mgr.ignore_run(ignored_id, parent_run_id=agent_id) + + mgr.accumulate_llm_usage_to_agent( + parent_run_id=ignored_id, input_tokens=15, output_tokens=25 + ) + + agent_rec.span.set_attribute.assert_any_call(INPUT_TOKENS, 15) + agent_rec.span.set_attribute.assert_any_call(OUTPUT_TOKENS, 25) + + def test_noop_when_parent_run_id_is_none(self): + mgr, _ = _make_manager() + agent_id = str(uuid4()) + agent_rec = _register_record(mgr, agent_id, OP_INVOKE_AGENT) + + mgr.accumulate_llm_usage_to_agent( + parent_run_id=None, input_tokens=10, output_tokens=20 + ) + + 
agent_rec.span.set_attribute.assert_not_called() + + def test_noop_when_no_agent_in_chain(self): + mgr, _ = _make_manager() + chat_id = str(uuid4()) + + chat_rec = _register_record(mgr, chat_id, OP_CHAT) + + mgr.accumulate_llm_usage_to_agent( + parent_run_id=chat_id, input_tokens=10, output_tokens=20 + ) + + chat_rec.span.set_attribute.assert_not_called() diff --git a/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_utils.py b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_utils.py new file mode 100644 index 0000000000..5a166306dc --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-langchain/tests/test_utils.py @@ -0,0 +1,323 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import pytest

from opentelemetry.instrumentation.langchain.utils import (
    infer_provider_name,
    infer_server_address,
    infer_server_port,
)

# ---------------------------------------------------------------------------
# infer_provider_name
# ---------------------------------------------------------------------------


class TestInferProviderNameFromMetadata:
    """infer_provider_name() honours the metadata ``ls_provider`` hint."""

    @pytest.mark.parametrize(
        "provider, want",
        [
            ("openai", "openai"),
            ("anthropic", "anthropic"),
            ("cohere", "cohere"),
            ("ollama", "ollama"),
        ],
    )
    def test_direct_mapping(self, provider, want):
        result = infer_provider_name({}, {"ls_provider": provider}, {})
        assert result == want

    @pytest.mark.parametrize("provider", ["azure", "azure_openai"])
    def test_azure_variants(self, provider):
        result = infer_provider_name({}, {"ls_provider": provider}, {})
        assert result == "azure.ai.openai"

    def test_github_maps_to_azure(self):
        result = infer_provider_name({}, {"ls_provider": "github"}, {})
        assert result == "azure.ai.openai"

    @pytest.mark.parametrize(
        "provider", ["amazon_bedrock", "bedrock", "aws_bedrock"]
    )
    def test_bedrock_variants(self, provider):
        result = infer_provider_name({}, {"ls_provider": provider}, {})
        assert result == "aws.bedrock"

    def test_google(self):
        result = infer_provider_name({}, {"ls_provider": "google"}, {})
        assert result == "gcp.gen_ai"


class TestInferProviderNameFromBaseUrl:
    """Provider resolution falls back to ``base_url`` in invocation params."""

    @pytest.mark.parametrize(
        "url, want",
        [
            ("https://my-resource.openai.azure.com/v1", "azure.ai.openai"),
            ("https://api.openai.com/v1", "openai"),
            (
                "https://bedrock-runtime.us-east-1.amazonaws.com",
                "aws.bedrock",
            ),
            ("https://api.anthropic.com/v1", "anthropic"),
            (
                "https://us-central1-aiplatform.googleapis.com",
                "gcp.gen_ai",
            ),
        ],
    )
    def test_url_patterns(self, url, want):
        assert infer_provider_name({}, {}, {"base_url": url}) == want

    def test_azure_keyword_in_url(self):
        params = {"base_url": "https://custom-azure-endpoint.example.com/v1"}
        assert infer_provider_name({}, {}, params) == "azure.ai.openai"

    def test_ollama_keyword_in_url(self):
        params = {"base_url": "https://my-ollama-server.local:11434/api"}
        assert infer_provider_name({}, {}, params) == "ollama"

    def test_amazonaws_in_url(self):
        params = {
            "base_url": "https://runtime.sagemaker.us-west-2.amazonaws.com"
        }
        assert infer_provider_name({}, {}, params) == "aws.bedrock"

    def test_openai_com_in_url(self):
        params = {"base_url": "https://api.openai.com/v2/chat"}
        assert infer_provider_name({}, {}, params) == "openai"


# Shared class-name -> provider expectations, used by both lookup paths below.
_CLASS_NAME_CASES = [
    ("ChatOpenAI", "openai"),
    ("ChatBedrock", "aws.bedrock"),
    ("ChatAnthropic", "anthropic"),
    ("ChatGoogleGenerativeAI", "gcp.gen_ai"),
]


class TestInferProviderNameFromSerializedClassName:
    """Provider resolution via the serialized ``name`` / ``id`` fields."""

    @pytest.mark.parametrize("class_name, want", _CLASS_NAME_CASES)
    def test_class_name(self, class_name, want):
        assert infer_provider_name({"name": class_name}, {}, {}) == want

    @pytest.mark.parametrize("class_name, want", _CLASS_NAME_CASES)
    def test_class_name_via_id(self, class_name, want):
        serialized = {"id": ["langchain_openai", "chat_models", class_name]}
        assert infer_provider_name(serialized, {}, {}) == want


class TestInferProviderNameFromSerializedKwargs:
    """Provider resolution via kwargs embedded in the serialized dict."""

    def test_azure_endpoint_kwarg(self):
        serialized = {
            "kwargs": {"azure_endpoint": "https://my-model.openai.azure.com/"}
        }
        assert infer_provider_name(serialized, {}, {}) == "azure.ai.openai"


class TestInferProviderNameReturnsNone:
    """Without a recognizable signal the function yields None."""

    def test_empty_inputs(self):
        assert infer_provider_name({}, {}, {}) is None

    def test_none_inputs(self):
        assert infer_provider_name({}, None, None) is None

    def test_unrecognized_metadata(self):
        meta = {"ls_provider": "some_unknown_provider"}
        assert infer_provider_name({}, meta, {}) is None

    def test_unrecognized_url(self):
        params = {"base_url": "https://custom-llm.example.com/v1"}
        assert infer_provider_name({}, {}, params) is None

    def test_unrecognized_class_name(self):
        assert infer_provider_name({"name": "ChatCustomLLM"}, {}, {}) is None


class TestInferProviderNamePriority:
    """metadata beats invocation_params, which beats serialized."""

    def test_metadata_over_invocation_params(self):
        got = infer_provider_name(
            {},
            {"ls_provider": "anthropic"},
            {"base_url": "https://api.openai.com/v1"},
        )
        assert got == "anthropic"

    def test_metadata_over_serialized(self):
        got = infer_provider_name(
            {"name": "ChatOpenAI"}, {"ls_provider": "anthropic"}, {}
        )
        assert got == "anthropic"

    def test_invocation_params_over_serialized(self):
        got = infer_provider_name(
            {"name": "ChatOpenAI"},
            {},
            {"base_url": "https://api.anthropic.com/v1"},
        )
        assert got == "anthropic"


# ---------------------------------------------------------------------------
# infer_server_address
# ---------------------------------------------------------------------------


class TestInferServerAddress:
    """infer_server_address() pulls the hostname out of known URL sources."""

    def test_from_invocation_params_base_url(self):
        params = {"base_url": "https://api.openai.com/v1"}
        assert infer_server_address({}, params) == "api.openai.com"

    def test_from_serialized_openai_api_base(self):
        serialized = {
            "kwargs": {"openai_api_base": "https://my-model.openai.azure.com/"}
        }
        got = infer_server_address(serialized, {})
        assert got == "my-model.openai.azure.com"

    def test_from_serialized_azure_endpoint(self):
        serialized = {
            "kwargs": {
                "azure_endpoint": "https://my-resource.openai.azure.com/"
            }
        }
        got = infer_server_address(serialized, {})
        assert got == "my-resource.openai.azure.com"

    def test_returns_none_when_no_url(self):
        assert infer_server_address({}, {}) is None

    def test_returns_none_for_empty_inputs(self):
        assert infer_server_address({}, None) is None

    def test_returns_none_for_none_serialized_kwargs(self):
        assert infer_server_address({"kwargs": {}}, {}) is None

    def test_strips_port_from_hostname(self):
        params = {"base_url": "http://localhost:11434/v1"}
        assert infer_server_address({}, params) == "localhost"

    def test_handles_url_with_path(self):
        params = {"base_url": "https://api.openai.com/v1/chat/completions"}
        assert infer_server_address({}, params) == "api.openai.com"

    def test_handles_malformed_url(self):
        # Must not raise; None or a best-effort string are both acceptable.
        got = infer_server_address({}, {"base_url": "not-a-valid-url"})
        assert got is None or isinstance(got, str)

    def test_handles_empty_string_url(self):
        got = infer_server_address({}, {"base_url": ""})
        assert got is None or isinstance(got, str)

    def test_invocation_params_base_url_takes_priority(self):
        serialized = {
            "kwargs": {"openai_api_base": "https://fallback.example.com/v1"}
        }
        params = {"base_url": "https://primary.example.com/v1"}
        assert infer_server_address(serialized, params) == "primary.example.com"


# ---------------------------------------------------------------------------
# infer_server_port
# ---------------------------------------------------------------------------


class TestInferServerPort:
    """infer_server_port() extracts an explicit port, otherwise None."""

    def test_explicit_port(self):
        params = {"base_url": "http://localhost:11434/v1"}
        assert infer_server_port({}, params) == 11434

    def test_no_explicit_port_returns_none(self):
        params = {"base_url": "https://api.openai.com/v1"}
        assert infer_server_port({}, params) is None

    def test_standard_http_port_returned_when_explicit(self):
        # urlparse reports a port whenever one is spelled out, even 80/443.
        params = {"base_url": "http://api.example.com:80/v1"}
        assert infer_server_port({}, params) == 80

    def test_standard_https_port_returned_when_explicit(self):
        params = {"base_url": "https://api.example.com:443/v1"}
        assert infer_server_port({}, params) == 443

    def test_custom_port(self):
        params = {"base_url": "https://api.example.com:8443/v1"}
        assert infer_server_port({}, params) == 8443

    def test_returns_none_when_no_url(self):
        assert infer_server_port({}, {}) is None

    def test_returns_none_for_none_inputs(self):
        assert infer_server_port({}, None) is None

    def test_port_from_serialized_openai_api_base(self):
        serialized = {
            "kwargs": {"openai_api_base": "http://localhost:8080/v1"}
        }
        assert infer_server_port(serialized, {}) == 8080

    def test_port_from_serialized_azure_endpoint(self):
        serialized = {
            "kwargs": {
                "azure_endpoint": "https://my-resource.openai.azure.com:9090/"
            }
        }
        assert infer_server_port(serialized, {}) == 9090
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from opentelemetry.instrumentation.langchain.utils import (
    extract_propagation_context,
    extract_trace_headers,
    propagated_context,
)
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import get_current_span

# Well-formed W3C traceparent pieces shared by the tests below.
_TRACE_ID = "0af7651916cd43dd8448eb211c80319c"
_SPAN_ID = "b7ad6b7169203331"
_TRACEPARENT = f"00-{_TRACE_ID}-{_SPAN_ID}-01"
_TRACESTATE = "congo=t61rcWkgMzE"


# ---------------------------------------------------------------------------
# extract_trace_headers
# ---------------------------------------------------------------------------


class TestExtractTraceHeaders:
    """extract_trace_headers() finds W3C headers top-level or nested."""

    def test_top_level_traceparent(self):
        got = extract_trace_headers({"traceparent": _TRACEPARENT})
        assert got == {"traceparent": _TRACEPARENT}

    def test_top_level_traceparent_and_tracestate(self):
        got = extract_trace_headers(
            {"traceparent": _TRACEPARENT, "tracestate": _TRACESTATE}
        )
        assert got == {
            "traceparent": _TRACEPARENT,
            "tracestate": _TRACESTATE,
        }

    def test_nested_headers_key(self):
        got = extract_trace_headers(
            {"headers": {"traceparent": _TRACEPARENT}}
        )
        assert got == {"traceparent": _TRACEPARENT}

    def test_nested_metadata_key(self):
        got = extract_trace_headers(
            {
                "metadata": {
                    "traceparent": _TRACEPARENT,
                    "tracestate": _TRACESTATE,
                }
            }
        )
        assert got == {
            "traceparent": _TRACEPARENT,
            "tracestate": _TRACESTATE,
        }

    def test_nested_request_headers_key(self):
        got = extract_trace_headers(
            {"request_headers": {"traceparent": _TRACEPARENT}}
        )
        assert got == {"traceparent": _TRACEPARENT}

    def test_empty_container_returns_none(self):
        assert extract_trace_headers({}) is None

    def test_no_trace_headers_returns_none(self):
        container = {"foo": "bar", "headers": {"content-type": "text/plain"}}
        assert extract_trace_headers(container) is None

    def test_non_dict_container_returns_none(self):
        # Anything that is not a mapping is ignored outright.
        for bad in (None, "string", 42, ["traceparent", _TRACEPARENT]):
            assert extract_trace_headers(bad) is None

    def test_empty_string_traceparent_ignored(self):
        assert extract_trace_headers({"traceparent": ""}) is None

    def test_top_level_takes_precedence_over_nested(self):
        other = "00-11111111111111111111111111111111-2222222222222222-01"
        got = extract_trace_headers(
            {
                "traceparent": _TRACEPARENT,
                "headers": {"traceparent": other},
            }
        )
        assert got["traceparent"] == _TRACEPARENT


# ---------------------------------------------------------------------------
# propagated_context
# ---------------------------------------------------------------------------


class TestPropagatedContext:
    """Behaviour of the propagated_context() context manager."""

    def test_noop_when_headers_is_none(self):
        with propagated_context(None):
            assert not get_current_span().get_span_context().is_valid

    def test_noop_when_headers_is_empty(self):
        with propagated_context({}):
            assert not get_current_span().get_span_context().is_valid

    def test_attaches_and_detaches_valid_traceparent(self):
        provider = TracerProvider()
        tracer = provider.get_tracer("test")

        with tracer.start_as_current_span("outer"):
            outer_ctx = get_current_span().get_span_context()

            with propagated_context({"traceparent": _TRACEPARENT}):
                inner_ctx = get_current_span().get_span_context()
                # The injected trace id must be live inside the manager.
                assert format(inner_ctx.trace_id, "032x") == _TRACE_ID

            # On exit the previous (outer) span context is restored.
            restored = get_current_span().get_span_context()
            assert restored.trace_id == outer_ctx.trace_id

        provider.shutdown()

    def test_invalid_traceparent_does_not_crash(self):
        with propagated_context({"traceparent": "not-a-valid-traceparent"}):
            # Execution proceeds; the span context may simply be invalid.
            assert get_current_span() is not None

    def test_malformed_traceparent_does_not_crash(self):
        with propagated_context({"traceparent": "00-short-bad-01"}):
            assert get_current_span() is not None


# ---------------------------------------------------------------------------
# extract_propagation_context
# ---------------------------------------------------------------------------


class TestExtractPropagationContext:
    """extract_propagation_context() checks metadata, then inputs, then kwargs."""

    def test_finds_headers_in_metadata(self):
        got = extract_propagation_context(
            {"traceparent": _TRACEPARENT}, {}, {}
        )
        assert got == {"traceparent": _TRACEPARENT}

    def test_falls_back_to_inputs(self):
        got = extract_propagation_context(
            {}, {"traceparent": _TRACEPARENT}, {}
        )
        assert got == {"traceparent": _TRACEPARENT}

    def test_falls_back_to_kwargs(self):
        got = extract_propagation_context(
            {}, {}, {"traceparent": _TRACEPARENT}
        )
        assert got == {"traceparent": _TRACEPARENT}

    def test_returns_none_when_no_source_has_headers(self):
        assert extract_propagation_context({}, {}, {}) is None

    def test_metadata_takes_precedence_over_inputs(self):
        other = "00-11111111111111111111111111111111-2222222222222222-01"
        got = extract_propagation_context(
            {"traceparent": _TRACEPARENT}, {"traceparent": other}, {}
        )
        assert got["traceparent"] == _TRACEPARENT

    def test_none_sources_are_skipped(self):
        got = extract_propagation_context(
            None, None, {"traceparent": _TRACEPARENT}
        )
        assert got == {"traceparent": _TRACEPARENT}

    def test_non_dict_inputs_skipped(self):
        got = extract_propagation_context(
            None, "not a dict", {"traceparent": _TRACEPARENT}
        )
        assert got == {"traceparent": _TRACEPARENT}