Skip to content

Commit a769522

Browse files
authored
Remove run_id from GenAI Utils (#228)
* lint updates, test fixes
* remove auto generation of agent_id
* resolve rebase conflicts
* restore context clarification and tests
* clean up dead code in llamaindex, align agent_id references
1 parent 4630fc0 commit a769522

28 files changed

Lines changed: 308 additions & 414 deletions

File tree

instrumentation-genai/opentelemetry-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/callback_handler.py

Lines changed: 82 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from __future__ import annotations
88

99
import json
10-
from typing import Any, Optional, List
10+
from typing import Any, Optional, List, Dict
1111
from uuid import UUID
1212

1313
from langchain_core.callbacks import BaseCallbackHandler
@@ -296,13 +296,48 @@ def _extract_tool_details(
296296
}
297297

298298

299+
def _agent_span_id(agent: AgentInvocation) -> Optional[str]:
300+
"""Return the agent's span ID as a hex string, if available."""
301+
if agent.span_id is not None:
302+
return f"{agent.span_id:016x}"
303+
return None
304+
305+
306+
class _InvocationManager:
307+
"""Local store mapping LangChain run_ids to GenAI invocation objects.
308+
309+
Keyed by LangChain's own run_id UUID (from callback API), not the
310+
GenAI run_id field. Tracks parent-child relationships so the parent
311+
chain can be walked without a central registry.
312+
"""
313+
314+
def __init__(self) -> None:
315+
self._invocations: Dict[UUID, Any] = {}
316+
self._parents: Dict[UUID, Optional[UUID]] = {}
317+
318+
def add(self, run_id: UUID, parent_run_id: Optional[UUID], invocation: Any) -> None:
319+
self._invocations[run_id] = invocation
320+
self._parents[run_id] = parent_run_id
321+
322+
def get(self, run_id: UUID) -> Any:
323+
return self._invocations.get(run_id)
324+
325+
def get_parent_id(self, run_id: UUID) -> Optional[UUID]:
326+
return self._parents.get(run_id)
327+
328+
def remove(self, run_id: UUID) -> None:
329+
self._invocations.pop(run_id, None)
330+
self._parents.pop(run_id, None)
331+
332+
299333
class LangchainCallbackHandler(BaseCallbackHandler):
300334
def __init__(
301335
self,
302336
telemetry_handler: Optional[TelemetryHandler] = None,
303337
) -> None:
304338
super().__init__()
305339
self._handler = telemetry_handler
340+
self._invocation_manager = _InvocationManager()
306341
# Tracks ContextVar state before we push an inferred conversation_id
307342
# so it can be restored when the root entity finishes.
308343
self._inferred_context_prev: dict[UUID, GenAIContext | None] = {}
@@ -321,12 +356,12 @@ def _find_nearest_agent(self, run_id: Optional[UUID]) -> Optional[AgentInvocatio
321356
visited = set()
322357
while current is not None and current not in visited:
323358
visited.add(current)
324-
entity = self._handler.get_entity(current)
359+
entity = self._invocation_manager.get(current)
325360
if isinstance(entity, AgentInvocation):
326361
return entity
327362
if entity is None:
328363
break
329-
current = getattr(entity, "parent_run_id", None)
364+
current = self._invocation_manager.get_parent_id(current)
330365
return None
331366

332367
def _start_agent_invocation(
@@ -344,12 +379,10 @@ def _start_agent_invocation(
344379
) -> AgentInvocation:
345380
agent = AgentInvocation(
346381
name=name,
347-
run_id=run_id,
348382
attributes=attrs,
349383
)
350384
agent.input_messages = command_input_messages or _make_input_message(inputs)
351385
agent.agent_name = _safe_str(agent_name) if agent_name else name
352-
agent.parent_run_id = parent_run_id
353386
agent.framework = "langchain"
354387
if conversation_id:
355388
agent.conversation_id = conversation_id
@@ -361,6 +394,7 @@ def _start_agent_invocation(
361394
if metadata.get("system"):
362395
agent.system = _safe_str(metadata["system"])
363396
self._handler.start_agent(agent)
397+
self._invocation_manager.add(run_id, parent_run_id, agent)
364398
return agent
365399

366400
def on_chain_start(
@@ -436,7 +470,7 @@ def on_chain_start(
436470
if is_resume:
437471
agent.attributes[GEN_AI_WORKFLOW_COMMAND] = "resume"
438472
else:
439-
wf = Workflow(name=name, run_id=run_id, attributes=attrs)
473+
wf = Workflow(name=name, attributes=attrs)
440474
wf.input_messages = command_input_messages or _make_input_message(
441475
inputs
442476
)
@@ -445,12 +479,16 @@ def on_chain_start(
445479
if is_resume:
446480
wf.attributes[GEN_AI_WORKFLOW_COMMAND] = "resume"
447481
self._handler.start_workflow(wf)
482+
self._invocation_manager.add(run_id, None, wf)
448483
return
449484
else:
450485
# Skip if parent entity no longer exists (e.g., LangGraph
451486
# replays the interrupted node during resume — the parent
452487
# workflow from the previous trace was already ended).
453-
if self._handler.get_entity(parent_run_id) is None:
488+
# TODO: _invocation_manager is per-handler instance; consider
489+
# making LangchainCallbackHandler a singleton so all
490+
# invocations share the same manager.
491+
if self._invocation_manager.get(parent_run_id) is None:
454492
return
455493

456494
context_agent = self._find_nearest_agent(parent_run_id)
@@ -477,7 +515,7 @@ def on_chain_start(
477515
return
478516
tool_info = _extract_tool_details(metadata)
479517
if tool_info is not None:
480-
existing = self._handler.get_entity(run_id)
518+
existing = self._invocation_manager.get(run_id)
481519
if isinstance(existing, ToolCall):
482520
tool = existing
483521
if context_agent is not None:
@@ -487,7 +525,7 @@ def on_chain_start(
487525
if not getattr(tool, "agent_name", None):
488526
tool.agent_name = _safe_str(agent_name_value)
489527
if not getattr(tool, "agent_id", None):
490-
tool.agent_id = str(context_agent.run_id)
528+
tool.agent_id = _agent_span_id(context_agent)
491529
else:
492530
# Filter out tool-specific metadata from attributes
493531
# since they're stored in dedicated fields
@@ -505,15 +543,14 @@ def on_chain_start(
505543
name=tool_info.get("name", name),
506544
id=tool_info.get("id"),
507545
arguments=arguments,
508-
run_id=run_id,
509-
parent_run_id=parent_run_id,
510546
attributes=tool_attrs,
511547
)
512548
tool.framework = "langchain"
513549
if context_agent is not None and context_agent_name is not None:
514550
tool.agent_name = context_agent_name
515-
tool.agent_id = str(context_agent.run_id)
551+
tool.agent_id = _agent_span_id(context_agent)
516552
self._handler.start_tool_call(tool)
553+
self._invocation_manager.add(run_id, parent_run_id, tool)
517554
if inputs is not None and getattr(tool, "arguments", None) is None:
518555
tool.arguments = inputs
519556
if getattr(tool, "arguments", None) is not None:
@@ -523,16 +560,15 @@ def on_chain_start(
523560
else:
524561
step = Step(
525562
name=name,
526-
run_id=run_id,
527-
parent_run_id=parent_run_id,
528563
step_type="chain",
529564
attributes=attrs,
530565
)
531566
if context_agent is not None:
532567
if context_agent_name is not None:
533568
step.agent_name = context_agent_name
534-
step.agent_id = str(context_agent.run_id)
569+
step.agent_id = _agent_span_id(context_agent)
535570
self._handler.start_step(step)
571+
self._invocation_manager.add(run_id, parent_run_id, step)
536572

537573
def on_chain_end(
538574
self,
@@ -542,7 +578,7 @@ def on_chain_end(
542578
parent_run_id: Optional[UUID] = None,
543579
**_kwargs: Any,
544580
) -> None:
545-
entity = self._handler.get_entity(run_id)
581+
entity = self._invocation_manager.get(run_id)
546582
if entity is None:
547583
return
548584

@@ -571,6 +607,7 @@ def on_chain_end(
571607
if serialized is not None:
572608
entity.attributes.setdefault("tool.response", serialized)
573609
self._handler.stop_tool_call(entity)
610+
self._invocation_manager.remove(run_id)
574611

575612
def on_chat_model_start(
576613
self,
@@ -656,8 +693,6 @@ def on_chat_model_start(
656693
request_model=request_model,
657694
input_messages=input_messages,
658695
attributes=attrs,
659-
run_id=run_id,
660-
parent_run_id=parent_run_id,
661696
)
662697
if provider:
663698
inv.provider = provider
@@ -666,8 +701,9 @@ def on_chat_model_start(
666701
if context_agent is not None:
667702
agent_name_value = context_agent.agent_name or context_agent.name
668703
inv.agent_name = _safe_str(agent_name_value)
669-
inv.agent_id = str(context_agent.run_id)
704+
inv.agent_id = _agent_span_id(context_agent)
670705
self._handler.start_llm(inv)
706+
self._invocation_manager.add(run_id, parent_run_id, inv)
671707

672708
def on_llm_start(
673709
self,
@@ -690,7 +726,7 @@ def on_llm_start(
690726
metadata=metadata,
691727
**extra,
692728
)
693-
inv = self._handler.get_entity(run_id)
729+
inv = self._invocation_manager.get(run_id)
694730
if isinstance(inv, LLMInvocation):
695731
inv.operation = "generate_text"
696732

@@ -702,7 +738,7 @@ def on_llm_end(
702738
parent_run_id: Optional[UUID] = None,
703739
**_kwargs: Any,
704740
) -> None:
705-
inv = self._handler.get_entity(run_id)
741+
inv = self._invocation_manager.get(run_id)
706742
if not isinstance(inv, LLMInvocation):
707743
return
708744
generations = getattr(response, "generations", [])
@@ -755,6 +791,7 @@ def on_llm_end(
755791
break
756792

757793
self._handler.stop_llm(inv)
794+
self._invocation_manager.remove(run_id)
758795

759796
def on_tool_start(
760797
self,
@@ -794,7 +831,7 @@ def on_tool_start(
794831
else:
795832
id_value = None
796833
arguments: Any = inputs if inputs is not None else input_str
797-
existing = self._handler.get_entity(run_id)
834+
existing = self._invocation_manager.get(run_id)
798835
if isinstance(existing, ToolCall):
799836
if arguments is not None:
800837
existing.arguments = arguments
@@ -807,29 +844,28 @@ def on_tool_start(
807844
):
808845
existing.agent_name = context_agent_name
809846
if not getattr(existing, "agent_id", None):
810-
existing.agent_id = str(context_agent.run_id)
847+
existing.agent_id = _agent_span_id(context_agent)
811848
if existing.framework is None:
812849
existing.framework = "langchain"
813850
return
814851
tool = ToolCall(
815852
name=name,
816853
id=id_value,
817854
arguments=arguments,
818-
run_id=run_id,
819-
parent_run_id=parent_run_id,
820855
attributes=attrs,
821856
)
822857
tool.framework = "langchain"
823858
if context_agent is not None and context_agent_name is not None:
824859
tool.agent_name = context_agent_name
825-
tool.agent_id = str(context_agent.run_id)
860+
tool.agent_id = _agent_span_id(context_agent)
826861
if arguments is not None:
827862
serialized_args = _serialize(arguments)
828863
if serialized_args is not None:
829864
tool.attributes.setdefault("tool.arguments", serialized_args)
830865
if inputs is None and input_str:
831866
tool.attributes.setdefault("tool.input_str", _safe_str(input_str))
832867
self._handler.start_tool_call(tool)
868+
self._invocation_manager.add(run_id, parent_run_id, tool)
833869

834870
def on_tool_end(
835871
self,
@@ -839,24 +875,36 @@ def on_tool_end(
839875
parent_run_id: Optional[UUID] = None,
840876
**_kwargs: Any,
841877
) -> None:
842-
tool = self._handler.get_entity(run_id)
878+
tool = self._invocation_manager.get(run_id)
843879
if not isinstance(tool, ToolCall):
844880
return
845881
serialized = _serialize(output)
846882
if serialized is not None:
847883
tool.attributes.setdefault("tool.response", serialized)
848884
self._handler.stop_tool_call(tool)
885+
self._invocation_manager.remove(run_id)
849886

850887
def _fail(self, run_id: UUID, error: BaseException) -> None:
    """Report ``error`` against the invocation tracked for ``run_id``.

    The invocation is looked up in the handler-local invocation manager.
    If nothing is tracked for ``run_id``, no failure is reported. Otherwise
    a ``GenAIError`` is built from the exception and passed to the telemetry
    handler's ``fail_*`` method matching the invocation's type, after which
    the entry is removed from the manager.
    """
    classification = _classify_error(error)
    manager = self._invocation_manager
    entity = manager.get(run_id)
    if entity is None:
        return

    genai_error = GenAIError(
        message=str(error),
        type=type(error),
        classification=classification,
    )

    # Checked in order; the first matching type decides which handler
    # method receives the failure.
    dispatch = (
        (Workflow, "fail_workflow"),
        (AgentInvocation, "fail_agent"),
        (Step, "fail_step"),
        (LLMInvocation, "fail_llm"),
        (ToolCall, "fail_tool_call"),
    )
    for entity_type, method_name in dispatch:
        if isinstance(entity, entity_type):
            getattr(self._handler, method_name)(entity, genai_error)
            break

    # The entry is dropped even when no type above matched.
    manager.remove(run_id)
860908

861909
def on_llm_error(
862910
self,

0 commit comments

Comments
 (0)