Skip to content

Commit 3d5a6a9

Browse files
authored
feat: Add judge evaluation support to agent graphs (#142)
1 parent 2189b81 commit 3d5a6a9

21 files changed

Lines changed: 338 additions & 146 deletions

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_runner_factory.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@ def create_agent(self, config: Any, tools: Optional[ToolRegistry] = None) -> Lan
3939
)
4040
return LangChainAgentRunner(agent)
4141

42-
def create_agent_graph(self, graph_def: Any, tools: ToolRegistry) -> Any:
42+
def create_agent_graph(
43+
self,
44+
graph_def: Any,
45+
tools: ToolRegistry,
46+
) -> Any:
4347
"""
4448
CAUTION:
4549
This feature is experimental and should NOT be considered ready for production use.

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""LangGraph agent graph runner for LaunchDarkly AI SDK."""
22

3+
import asyncio
34
import time
4-
from typing import Annotated, Any, Dict, List, Optional, Set, Tuple
5+
from contextvars import ContextVar
6+
from typing import Annotated, Any, Dict, List, Set, Tuple
57

68
from ldai import log
79
from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode
@@ -16,6 +18,9 @@
1618
)
1719
from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler
1820

21+
# Per-run accumulator of judge-evaluation tasks, keyed by node key. Stored in a ContextVar so each concurrent run() call sees only its own tasks.
22+
_run_eval_tasks: ContextVar[Dict[str, List[asyncio.Task]]] = ContextVar('_run_eval_tasks')
23+
1924

2025
def _make_handoff_tool(child_key: str, description: str) -> Any:
2126
"""
@@ -67,7 +72,11 @@ class LangGraphAgentGraphRunner(AgentGraphRunner):
6772
Requires ``langgraph`` to be installed.
6873
"""
6974

70-
def __init__(self, graph: AgentGraphDefinition, tools: ToolRegistry):
75+
def __init__(
76+
self,
77+
graph: AgentGraphDefinition,
78+
tools: ToolRegistry,
79+
):
7180
"""
7281
Initialize the runner.
7382
@@ -172,6 +181,26 @@ async def invoke(state: WorkflowState) -> dict:
172181
if node_instructions:
173182
msgs = [SystemMessage(content=node_instructions)] + msgs
174183
response = await bound_model.ainvoke(msgs)
184+
185+
node_obj = self._graph.get_node(nk)
186+
if node_obj is not None:
187+
input_text = '\r\n'.join(
188+
m.content if isinstance(m.content, str) else str(m.content)
189+
for m in msgs
190+
) if msgs else ''
191+
output_text = (
192+
response.content if hasattr(response, 'content') else str(response)
193+
)
194+
task = node_obj.get_config().evaluator.evaluate(input_text, output_text)
195+
run_tasks = _run_eval_tasks.get(None)
196+
if run_tasks is not None:
197+
run_tasks.setdefault(nk, []).append(task)
198+
else:
199+
log.warning(
200+
f"LangGraphAgentGraphRunner: eval task for node '{nk}' "
201+
"has no run context; judge results will not be tracked"
202+
)
203+
175204
return {'messages': [response]}
176205

177206
invoke.__name__ = nk
@@ -280,7 +309,9 @@ async def run(self, input: Any) -> AgentGraphResult:
280309
:param input: The string prompt to send to the agent graph
281310
:return: AgentGraphResult with the final output and metrics
282311
"""
283-
tracker = self._graph.create_tracker() if self._graph.create_tracker is not None else None
312+
pending_eval_tasks: Dict[str, List[asyncio.Task]] = {}
313+
token = _run_eval_tasks.set(pending_eval_tasks)
314+
tracker = self._graph.create_tracker()
284315
start_ns = time.perf_counter_ns()
285316

286317
try:
@@ -299,19 +330,18 @@ async def run(self, input: Any) -> AgentGraphResult:
299330
output = extract_last_message_content(messages)
300331

301332
# Flush per-node metrics to LD trackers
302-
handler.flush(self._graph)
333+
all_eval_results = await handler.flush(self._graph, pending_eval_tasks)
303334

304-
# Graph-level metrics
305-
if tracker:
306-
tracker.track_path(handler.path)
307-
tracker.track_duration(duration)
308-
tracker.track_invocation_success()
309-
tracker.track_total_tokens(sum_token_usage_from_messages(messages))
335+
tracker.track_path(handler.path)
336+
tracker.track_duration(duration)
337+
tracker.track_invocation_success()
338+
tracker.track_total_tokens(sum_token_usage_from_messages(messages))
310339

311340
return AgentGraphResult(
312341
output=output,
313342
raw=result,
314343
metrics=LDAIMetrics(success=True),
344+
evaluations=all_eval_results,
315345
)
316346

317347
except Exception as exc:
@@ -323,11 +353,12 @@ async def run(self, input: Any) -> AgentGraphResult:
323353
else:
324354
log.warning(f'LangGraphAgentGraphRunner run failed: {exc}')
325355
duration = (time.perf_counter_ns() - start_ns) // 1_000_000
326-
if tracker:
327-
tracker.track_duration(duration)
328-
tracker.track_invocation_failure()
356+
tracker.track_duration(duration)
357+
tracker.track_invocation_failure()
329358
return AgentGraphResult(
330359
output='',
331360
raw=None,
332361
metrics=LDAIMetrics(success=False),
333362
)
363+
finally:
364+
_run_eval_tasks.reset(token)

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from langchain_core.callbacks import BaseCallbackHandler
66
from langchain_core.outputs import ChatGeneration, LLMResult
77
from ldai.agent_graph import AgentGraphDefinition
8+
from ldai.providers.types import JudgeResult
89
from ldai.tracker import TokenUsage
910

1011
from ldai_langchain.langchain_helper import get_ai_usage_from_response
@@ -188,15 +189,22 @@ def on_tool_end(
188189
# Flush
189190
# ------------------------------------------------------------------
190191

191-
def flush(self, graph: AgentGraphDefinition) -> None:
192+
async def flush(
193+
self, graph: AgentGraphDefinition, eval_tasks=None
194+
) -> List[JudgeResult]:
192195
"""
193196
Emit all collected per-node metrics to the LaunchDarkly trackers.
194197
195198
Call this once after the graph run completes.
196199
197200
:param graph: The AgentGraphDefinition whose nodes hold the LD config trackers.
201+
:param eval_tasks: Optional dict mapping node key to a list of awaitables that
202+
return judge evaluation results. Multiple tasks arise when a node is visited
203+
more than once (e.g. in a graph with cycles).
204+
:return: All judge results collected across all nodes.
198205
"""
199206
node_trackers: Dict[str, Any] = {}
207+
all_eval_results: List[JudgeResult] = []
200208
for node_key in self._path:
201209
if node_key in node_trackers:
202210
continue
@@ -220,3 +228,15 @@ def flush(self, graph: AgentGraphDefinition) -> None:
220228

221229
for tool_key in self._node_tool_calls.get(node_key, []):
222230
config_tracker.track_tool_call(tool_key)
231+
232+
if not eval_tasks:
233+
continue
234+
235+
for eval_task in eval_tasks.get(node_key, []):
236+
results = await eval_task
237+
all_eval_results.extend(results)
238+
for r in results:
239+
if r.success:
240+
config_tracker.track_judge_result(r)
241+
242+
return all_eval_results

packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
77

88
from ldai import LDMessage
9+
from ldai.evaluator import Evaluator
910

1011
from ldai_langchain import (
1112
LangChainModelRunner,
@@ -530,6 +531,7 @@ def sync_tool(x: str = '') -> str:
530531
cfg = AIAgentConfig(
531532
key='n',
532533
enabled=True,
534+
evaluator=Evaluator.noop(),
533535
create_tracker=MagicMock(),
534536
model=ModelConfig(
535537
name='gpt-4',
@@ -553,6 +555,7 @@ async def async_tool(x: str = '') -> str:
553555
cfg = AIAgentConfig(
554556
key='n',
555557
enabled=True,
558+
evaluator=Evaluator.noop(),
556559
create_tracker=MagicMock(),
557560
model=ModelConfig(
558561
name='gpt-4',

packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from unittest.mock import AsyncMock, MagicMock, patch
55

66
from ldai.agent_graph import AgentGraphDefinition
7+
from ldai.evaluator import Evaluator
78
from ldai.models import AIAgentGraphConfig, AIAgentConfig, ModelConfig, ProviderConfig
89
from ldai.providers import AgentGraphResult, ToolRegistry
910
from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner
@@ -20,6 +21,7 @@ def _make_graph(enabled: bool = True) -> AgentGraphDefinition:
2021
model=ModelConfig(name='gpt-4'),
2122
provider=ProviderConfig(name='openai'),
2223
instructions='You are a helpful assistant.',
24+
evaluator=Evaluator.noop(),
2325
)
2426
graph_config = AIAgentGraphConfig(
2527
key='test-graph',

packages/ai-providers/server-ai-langchain/tests/test_langgraph_callback_handler.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from ldai.agent_graph import AgentGraphDefinition
1818
from ldai.models import AIAgentConfig, AIAgentGraphConfig, ModelConfig, ProviderConfig
1919
from ldai.tracker import AIGraphTracker, LDAIConfigTracker, TokenUsage
20+
from ldai.evaluator import Evaluator
2021
from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler
2122

2223

@@ -48,6 +49,7 @@ def _make_graph(mock_ld_client: MagicMock, node_key: str = 'root-agent', graph_k
4849
node_config = AIAgentConfig(
4950
key=node_key,
5051
enabled=True,
52+
evaluator=Evaluator.noop(),
5153
model=ModelConfig(name='gpt-4', parameters={}),
5254
provider=ProviderConfig(name='openai'),
5355
instructions='Be helpful.',
@@ -317,7 +319,8 @@ def test_on_tool_end_none_name_ignored():
317319
# flush() tests
318320
# ---------------------------------------------------------------------------
319321

320-
def test_flush_emits_token_events_to_ld_tracker():
322+
@pytest.mark.asyncio
323+
async def test_flush_emits_token_events_to_ld_tracker():
321324
"""flush() calls track_tokens on the node's config tracker."""
322325
mock_ld_client = MagicMock()
323326
graph = _make_graph(mock_ld_client, node_key='root-agent', graph_key='g1')
@@ -327,7 +330,7 @@ def test_flush_emits_token_events_to_ld_tracker():
327330
node_run_id = uuid4()
328331
handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent')
329332
handler.on_llm_end(_llm_result(15, 10, 5), run_id=uuid4(), parent_run_id=node_run_id)
330-
handler.flush(graph)
333+
await handler.flush(graph)
331334

332335
ev = _events(mock_ld_client)
333336
assert ev['$ld:ai:tokens:total'][0][1] == 15
@@ -336,7 +339,8 @@ def test_flush_emits_token_events_to_ld_tracker():
336339
assert ev['$ld:ai:generation:success'][0][1] == 1
337340

338341

339-
def test_flush_emits_duration():
342+
@pytest.mark.asyncio
343+
async def test_flush_emits_duration():
340344
"""flush() calls track_duration when duration was recorded."""
341345
mock_ld_client = MagicMock()
342346
graph = _make_graph(mock_ld_client)
@@ -346,13 +350,14 @@ def test_flush_emits_duration():
346350
run_id = uuid4()
347351
handler.on_chain_start({}, {}, run_id=run_id, name='root-agent')
348352
handler.on_chain_end({}, run_id=run_id)
349-
handler.flush(graph)
353+
await handler.flush(graph)
350354

351355
ev = _events(mock_ld_client)
352356
assert '$ld:ai:duration:total' in ev
353357

354358

355-
def test_flush_emits_tool_calls():
359+
@pytest.mark.asyncio
360+
async def test_flush_emits_tool_calls():
356361
"""flush() calls track_tool_call for each recorded tool invocation."""
357362
mock_ld_client = MagicMock()
358363
graph = _make_graph(mock_ld_client)
@@ -366,15 +371,16 @@ def test_flush_emits_tool_calls():
366371
tools_run_id = uuid4()
367372
handler.on_chain_start({}, {}, run_id=tools_run_id, name='root-agent__tools')
368373
handler.on_tool_end('r', run_id=uuid4(), parent_run_id=tools_run_id, name='fn_search')
369-
handler.flush(graph)
374+
await handler.flush(graph)
370375

371376
ev = _events(mock_ld_client)
372377
tool_events = ev.get('$ld:ai:tool_call', [])
373378
assert len(tool_events) == 1
374379
assert tool_events[0][0]['toolKey'] == 'search'
375380

376381

377-
def test_flush_includes_graph_key_in_node_events():
382+
@pytest.mark.asyncio
383+
async def test_flush_includes_graph_key_in_node_events():
378384
"""flush() passes graph_key to the node tracker so graphKey appears in events."""
379385
mock_ld_client = MagicMock()
380386
graph = _make_graph(mock_ld_client, graph_key='my-graph')
@@ -384,14 +390,15 @@ def test_flush_includes_graph_key_in_node_events():
384390
node_run_id = uuid4()
385391
handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent')
386392
handler.on_llm_end(_llm_result(5, 3, 2), run_id=uuid4(), parent_run_id=node_run_id)
387-
handler.flush(graph)
393+
await handler.flush(graph)
388394

389395
ev = _events(mock_ld_client)
390396
token_data = ev['$ld:ai:tokens:total'][0][0]
391397
assert token_data.get('graphKey') == 'my-graph'
392398

393399

394-
def test_flush_with_no_graph_key_on_node_tracker():
400+
@pytest.mark.asyncio
401+
async def test_flush_with_no_graph_key_on_node_tracker():
395402
"""When node tracker has no graph_key, events omit graphKey."""
396403
mock_ld_client = MagicMock()
397404
context = MagicMock()
@@ -408,6 +415,7 @@ def test_flush_with_no_graph_key_on_node_tracker():
408415
node_config = AIAgentConfig(
409416
key='root-agent',
410417
enabled=True,
418+
evaluator=Evaluator.noop(),
411419
model=ModelConfig(name='gpt-4', parameters={}),
412420
provider=ProviderConfig(name='openai'),
413421
instructions='Be helpful.',
@@ -425,36 +433,38 @@ def test_flush_with_no_graph_key_on_node_tracker():
425433
nodes=nodes,
426434
context=context,
427435
enabled=True,
428-
create_tracker=lambda: None,
436+
create_tracker=lambda: AIGraphTracker(mock_ld_client, 'v1', 'test-graph', 1, context),
429437
)
430438

431439
handler = LDMetricsCallbackHandler({'root-agent'}, {})
432440
node_run_id = uuid4()
433441
handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent')
434442
handler.on_llm_end(_llm_result(5, 3, 2), run_id=uuid4(), parent_run_id=node_run_id)
435-
handler.flush(graph)
443+
await handler.flush(graph)
436444

437445
ev = _events(mock_ld_client)
438446
token_data = ev['$ld:ai:tokens:total'][0][0]
439447
assert 'graphKey' not in token_data
440448

441449

442-
def test_flush_skips_nodes_not_in_path():
450+
@pytest.mark.asyncio
451+
async def test_flush_skips_nodes_not_in_path():
443452
"""flush() only emits events for nodes that were actually executed."""
444453
mock_ld_client = MagicMock()
445454
graph = _make_graph(mock_ld_client)
446455
tracker = graph.create_tracker()
447456

448457
# Handler with 'root-agent' in node_keys but never started
449458
handler = LDMetricsCallbackHandler({'root-agent'}, {})
450-
handler.flush(graph)
459+
await handler.flush(graph)
451460

452461
ev = _events(mock_ld_client)
453462
assert '$ld:ai:tokens:total' not in ev
454463
assert '$ld:ai:generation:success' not in ev
455464

456465

457-
def test_flush_skips_node_without_tracker():
466+
@pytest.mark.asyncio
467+
async def test_flush_skips_node_without_tracker():
458468
"""flush() silently skips nodes whose config has no tracker."""
459469
mock_ld_client = MagicMock()
460470
context = MagicMock()
@@ -463,6 +473,7 @@ def test_flush_skips_node_without_tracker():
463473
key='no-track',
464474
enabled=True,
465475
create_tracker=lambda: None,
476+
evaluator=Evaluator.noop(),
466477
model=ModelConfig(name='gpt-4', parameters={}),
467478
provider=ProviderConfig(name='openai'),
468479
instructions='',
@@ -483,7 +494,7 @@ def test_flush_skips_node_without_tracker():
483494
node_run_id = uuid4()
484495
handler.on_chain_start({}, {}, run_id=node_run_id, name='no-track')
485496
handler.on_llm_end(_llm_result(5, 3, 2), run_id=uuid4(), parent_run_id=node_run_id)
486-
handler.flush(graph) # should not raise
497+
await handler.flush(graph) # should not raise
487498

488499
mock_ld_client.track.assert_not_called()
489500

0 commit comments

Comments
 (0)