Skip to content

Commit 64d8060

Browse files
jsonbailey and claude committed
fix: collect and await judge eval tasks instead of fire-and-forget
Replace the asyncio.create_task fire-and-forget pattern with proper task collection and awaiting in both the OpenAI and LangGraph runners, ensuring judge results are tracked reliably. Use a ContextVar in the LangGraph runner to isolate pending eval task state across concurrent run() calls.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent ab4a77c commit 64d8060

2 files changed

Lines changed: 21 additions & 27 deletions

File tree

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import time
5+
from contextvars import ContextVar
56
from typing import Annotated, Any, Dict, List, Set, Tuple
67

78
from ldai import log
@@ -17,6 +18,9 @@
1718
)
1819
from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler
1920

21+
# Per-run eval task accumulator, isolated per concurrent run() call via ContextVar.
22+
_run_eval_tasks: ContextVar[Dict[str, List[asyncio.Task]]] = ContextVar('_run_eval_tasks')
23+
2024

2125
def _make_handoff_tool(child_key: str, description: str) -> Any:
2226
"""
@@ -84,7 +88,6 @@ def __init__(
8488
self._compiled: Any = None
8589
self._fn_name_to_config_key: Dict[str, str] = {}
8690
self._node_keys: Set[str] = set()
87-
self._pending_eval_tasks: Dict[str, List[asyncio.Task]] = {}
8891

8992
def _ensure_compiled(self) -> None:
9093
"""Build and cache the compiled graph if not already done."""
@@ -189,7 +192,7 @@ async def invoke(state: WorkflowState) -> dict:
189192
response.content if hasattr(response, 'content') else str(response)
190193
)
191194
task = node_obj.get_config().evaluator.evaluate(input_text, output_text)
192-
self._pending_eval_tasks.setdefault(nk, []).append(task)
195+
_run_eval_tasks.get({}).setdefault(nk, []).append(task)
193196

194197
return {'messages': [response]}
195198

@@ -299,7 +302,8 @@ async def run(self, input: Any) -> AgentGraphResult:
299302
:param input: The string prompt to send to the agent graph
300303
:return: AgentGraphResult with the final output and metrics
301304
"""
302-
self._pending_eval_tasks = {}
305+
pending_eval_tasks: Dict[str, List[asyncio.Task]] = {}
306+
token = _run_eval_tasks.set(pending_eval_tasks)
303307
tracker = self._graph.create_tracker() if self._graph.create_tracker is not None else None
304308
start_ns = time.perf_counter_ns()
305309

@@ -319,7 +323,7 @@ async def run(self, input: Any) -> AgentGraphResult:
319323
output = extract_last_message_content(messages)
320324

321325
# Flush per-node metrics to LD trackers
322-
await handler.flush(self._graph, self._pending_eval_tasks)
326+
await handler.flush(self._graph, pending_eval_tasks)
323327

324328
# Graph-level metrics
325329
if tracker:
@@ -351,3 +355,5 @@ async def run(self, input: Any) -> AgentGraphResult:
351355
raw=None,
352356
metrics=LDAIMetrics(success=False),
353357
)
358+
finally:
359+
_run_eval_tasks.reset(token)

packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_graph_runner.py

Lines changed: 11 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
import re
32
import time
43
from typing import Any, Dict, List, Optional
@@ -30,6 +29,7 @@ def __init__(self, last_handoff_ns: int, last_node_key: str, input_str: str = ''
3029
self.last_handoff_ns = last_handoff_ns
3130
self.last_node_key = last_node_key
3231
self.input_str = input_str
32+
self.pending_eval_tasks: List[tuple] = []
3333

3434

3535
class OpenAIAgentGraphRunner(AgentGraphRunner):
@@ -91,6 +91,11 @@ async def run(self, input: Any) -> AgentGraphResult:
9191
root_agent = self._build_agents(path, state, tracker)
9292
result = await Runner.run(root_agent, input_str)
9393
self._flush_final_segment(state, result, input_str)
94+
for node_tracker, eval_task in state.pending_eval_tasks:
95+
eval_results = await eval_task
96+
for r in eval_results:
97+
if r.success:
98+
node_tracker.track_judge_result(r)
9499
self._track_tool_calls(result)
95100

96101
duration = (time.perf_counter_ns() - start_ns) // 1_000_000
@@ -106,7 +111,7 @@ async def run(self, input: Any) -> AgentGraphResult:
106111
return AgentGraphResult(
107112
output=str(result.final_output),
108113
raw=result,
109-
metrics=LDAIMetrics(success=True),
114+
metrics=LDAIMetrics(success=True, usage=token_usage),
110115
)
111116
except Exception as exc:
112117
if isinstance(exc, ImportError):
@@ -270,19 +275,10 @@ def _handle_handoff(
270275
config_tracker.track_duration(int(duration_ms))
271276
config_tracker.track_success()
272277

273-
# Fire judge evaluation for the src node (fire-and-forget)
274278
src_node = self._graph.get_node(src)
275279
if src_node is not None:
276-
evaluator = src_node.get_config().evaluator
277-
# Use empty string as output since we don't have the node's final output here
278-
eval_task = evaluator.evaluate(input_str, '')
279-
280-
async def _track(trk, et):
281-
results = await et
282-
for r in results:
283-
if r.success:
284-
trk.track_judge_result(r)
285-
asyncio.create_task(_track(config_tracker, eval_task))
280+
eval_task = src_node.get_config().evaluator.evaluate(input_str, '')
281+
state.pending_eval_tasks.append((config_tracker, eval_task))
286282

287283
def _flush_final_segment(
288284
self,
@@ -313,19 +309,11 @@ def _flush_final_segment(
313309
config_tracker.track_duration(int(duration_ms))
314310
config_tracker.track_success()
315311

316-
# Fire judge evaluation for the final node (fire-and-forget)
317312
final_node = self._graph.get_node(state.last_node_key)
318313
if final_node is not None:
319-
evaluator = final_node.get_config().evaluator
320314
output_str = str(result.final_output) if result is not None else ''
321-
eval_task = evaluator.evaluate(input_str, output_str)
322-
323-
async def _track(trk, et):
324-
results = await et
325-
for r in results:
326-
if r.success:
327-
trk.track_judge_result(r)
328-
asyncio.create_task(_track(config_tracker, eval_task))
315+
eval_task = final_node.get_config().evaluator.evaluate(input_str, output_str)
316+
state.pending_eval_tasks.append((config_tracker, eval_task))
329317

330318
def _track_tool_calls(self, result: Any) -> None:
331319
"""Track all tool calls from the run result, attributed to the node that called them."""

0 commit comments

Comments
 (0)