Skip to content

Commit 9760ae9

Browse files
authored
feat: agent runner improvements (#95)
* feat: use single directory per agent run * feat: add summary stats * fix: lint and format * feat: streaming
1 parent 0246f2c commit 9760ae9

2 files changed

Lines changed: 337 additions & 54 deletions

File tree

src/deepset_mcp/benchmark/runner/agent_benchmark_runner.py

Lines changed: 199 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import asyncio
22
import json
33
import logging
4+
import sys
5+
from collections.abc import Callable
46
from datetime import datetime
57
from pathlib import Path
68
from typing import Any
79

810
from haystack.dataclasses.chat_message import ChatMessage
11+
from haystack.dataclasses.streaming_chunk import StreamingChunk
912

1013
from deepset_mcp.api.client import AsyncDeepsetClient
1114
from deepset_mcp.benchmark.runner.agent_loader import load_agent
@@ -28,16 +31,22 @@ def __init__(
2831
self,
2932
agent_config: AgentConfig,
3033
benchmark_config: BenchmarkConfig,
34+
streaming: bool = False,
3135
):
3236
"""
3337
Initialize the benchmark runner.
3438
3539
Args:
3640
agent_config: Configuration for the agent to test.
3741
benchmark_config: Benchmark configuration.
42+
streaming: Whether to enable streaming output during agent execution.
3843
"""
3944
self.agent_config = agent_config
4045
self.benchmark_config = benchmark_config
46+
self.streaming = streaming
47+
48+
# Create a single timestamp for this benchmark run
49+
self.run_timestamp = datetime.now()
4150

4251
try:
4352
secret_key = self.benchmark_config.get_env_var("LANGFUSE_SECRET_KEY")
@@ -52,6 +61,52 @@ def __init__(
5261
self.agent = agent
5362
self.commit_hash = commit_hash
5463

64+
# Create the run ID once for all test cases
65+
self.run_id = (
66+
f"{self.agent_config.display_name}-{self.commit_hash}_{self.run_timestamp.strftime('%Y%m%d_%H%M%S')}"
67+
)
68+
69+
# TODO: streaming is WIP; wait until https://github.com/deepset-ai/haystack-core-integrations/issues/1947 is fixed
70+
def _create_streaming_callback(self, test_case_name: str) -> Callable[[StreamingChunk], Any]:
71+
"""
72+
Create a streaming callback function for a specific test case.
73+
74+
Args:
75+
test_case_name: Name of the test case for logging context
76+
77+
Returns:
78+
Callback function for streaming
79+
"""
80+
81+
async def streaming_callback(chunk: StreamingChunk) -> None:
82+
"""Handle streaming chunks from the agent."""
83+
if hasattr(chunk, "content") and chunk.content:
84+
# meta content_block type=tool_use
85+
# meta type (content_block_start)
86+
# meta delta type=input_json_delta
87+
# meta delta message_delta
88+
# meta delta stop_reason=tool_use
89+
# Print with test case context, using a subtle prefix
90+
content = chunk.content
91+
# Handle newlines by adding the prefix to each new line
92+
lines = content.split("\n")
93+
for i, line in enumerate(lines):
94+
if i == 0:
95+
print(f"{line}", end="")
96+
elif line.strip(): # Only print non-empty lines with prefix
97+
print(f"\n[{test_case_name}] {line}", end="")
98+
else:
99+
print() # Just print the newline for empty lines
100+
101+
# If the content ends with a newline, print it
102+
if content.endswith("\n"):
103+
print()
104+
105+
# Ensure output is flushed immediately
106+
sys.stdout.flush()
107+
108+
return streaming_callback
109+
55110
async def run_single_test(self, test_case_name: str) -> dict[str, Any]:
56111
"""
57112
Run the agent against a single test case.
@@ -93,7 +148,18 @@ async def run_single_test(self, test_case_name: str) -> dict[str, Any]:
93148
workspace=self.benchmark_config.deepset_workspace
94149
).validate(yaml_config=query_yaml_config)
95150

96-
agent_output = await self.agent.run_async(messages=[ChatMessage.from_user(test_config.prompt)])
151+
# Prepare streaming callback if streaming is enabled
152+
streaming_callback = None
153+
if self.streaming:
154+
streaming_callback = self._create_streaming_callback(test_case_name)
155+
print(f"\n🤖 [{test_case_name}] Agent starting...\n")
156+
157+
agent_output = await self.agent.run_async(
158+
messages=[ChatMessage.from_user(test_config.prompt)], streaming_callback=streaming_callback
159+
)
160+
161+
if self.streaming:
162+
print(f"\n\n✅ [{test_case_name}] Agent completed.\n")
97163

98164
post_agent_validation = None
99165
if query_name:
@@ -169,22 +235,31 @@ async def run_single_test_with_cleanup(self, test_case_name: str) -> dict[str, A
169235

170236
return result
171237

172-
def run_all_tests(self, test_case_path: Path) -> list[dict[str, Any]]:
238+
def run_all_tests(self, test_case_path: Path) -> tuple[list[dict[str, Any]], dict[str, Any]]:
173239
"""
174240
Run the agent against all available test cases.
175241
176242
Args:
177243
test_case_path: Directory containing test case files
178244
179245
Returns:
180-
List of results for each test case
246+
Tuple of (test results list, summary statistics dict)
181247
"""
182248
# Find all test case files
183249
test_paths = find_all_test_case_paths(test_case_path)
184250

185251
if not test_paths:
186252
logger.warning(f"No test cases found in {test_case_path}")
187-
return []
253+
empty_summary = {
254+
"total_prompt_tokens": 0,
255+
"total_completion_tokens": 0,
256+
"tests_completed": 0,
257+
"tests_failed": 0,
258+
"avg_tool_calls": 0.0,
259+
"pass_rate_percent": 0.0,
260+
"fail_rate_percent": 0.0,
261+
}
262+
return [], empty_summary
188263

189264
logger.info(f"Found {len(test_paths)} test cases to run")
190265

@@ -195,13 +270,16 @@ def run_all_tests(self, test_case_path: Path) -> list[dict[str, Any]]:
195270
result = asyncio.run(self.run_single_test_with_cleanup(test_name))
196271
results.append(result)
197272

198-
return results
273+
# Create run summary CSV and get summary data
274+
summary_data = self._create_run_summary_csv(results)
275+
276+
return results, summary_data
199277

200278
async def run_all_tests_async(
201279
self,
202280
test_case_path: Path,
203281
concurrency: int = 1, # Keep concurrency low to avoid resource conflicts
204-
) -> list[dict[str, Any]]:
282+
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
205283
"""
206284
Run all test cases asynchronously with controlled concurrency.
207285
@@ -210,14 +288,23 @@ async def run_all_tests_async(
210288
concurrency: Number of concurrent test runs (default: 1 for safety)
211289
212290
Returns:
213-
List of results for each test case
291+
Tuple of (test results list, summary statistics dict)
214292
"""
215293
# Find all test case files
216294
test_paths = find_all_test_case_paths(test_case_path)
217295

218296
if not test_paths:
219297
logger.warning(f"No test cases found in {test_case_path}")
220-
return []
298+
empty_summary = {
299+
"total_prompt_tokens": 0,
300+
"total_completion_tokens": 0,
301+
"tests_completed": 0,
302+
"tests_failed": 0,
303+
"avg_tool_calls": 0.0,
304+
"pass_rate_percent": 0.0,
305+
"fail_rate_percent": 0.0,
306+
}
307+
return [], empty_summary
221308

222309
logger.info(f"Found {len(test_paths)} test cases to run with concurrency={concurrency}")
223310

@@ -244,7 +331,10 @@ async def run_with_semaphore(test_name: str) -> dict[str, Any]:
244331
else:
245332
processed_results.append(result) # type: ignore
246333

247-
return processed_results
334+
# Create run summary CSV and get summary data
335+
summary_data = self._create_run_summary_csv(processed_results)
336+
337+
return processed_results, summary_data
248338

249339
def _format_results(
250340
self,
@@ -255,19 +345,21 @@ def _format_results(
255345
post_yaml: str | None = None,
256346
) -> dict[str, Any]:
257347
"""Format the agent output and metadata for saving to file."""
258-
timestamp = datetime.now()
259-
260348
return {
261349
"metadata": {
262350
"commit_hash": self.commit_hash,
263351
"agent_display_name": self.agent_config.display_name,
264352
"test_case_name": test_config.name,
265-
"timestamp": timestamp.isoformat(),
266-
"run_id": f"{self.agent_config.display_name}-{self.commit_hash}_{timestamp.strftime('%Y%m%d_%H%M%S')}",
353+
"timestamp": self.run_timestamp.isoformat(),
354+
"run_id": self.run_id,
267355
},
268356
"validation": {
269-
"pre_validation": "PASS" if is_pre_agent_valid else "FAIL",
270-
"post_validation": "PASS" if is_post_agent_valid else "FAIL",
357+
"pre_validation": "PASS"
358+
if is_pre_agent_valid is True
359+
else ("FAIL" if is_pre_agent_valid is False else None),
360+
"post_validation": "PASS"
361+
if is_post_agent_valid is True
362+
else ("FAIL" if is_post_agent_valid is False else None),
271363
},
272364
"messages": {
273365
"serialized": [message.to_dict() for message in agent_output["messages"]],
@@ -276,6 +368,86 @@ def _format_results(
276368
"pipeline_yaml": post_yaml,
277369
}
278370

371+
def _create_run_summary_csv(self, results: list[dict[str, Any]]) -> dict[str, Any]:
372+
"""
373+
Create a summary CSV file for the entire benchmark run.
374+
375+
Args:
376+
results: List of test results from the benchmark run
377+
378+
Returns:
379+
Dictionary containing the summary statistics
380+
"""
381+
# Initialize counters
382+
total_prompt_tokens = 0
383+
total_completion_tokens = 0
384+
tests_completed = 0
385+
tests_failed = 0
386+
total_tool_calls = 0
387+
tests_with_validation = 0
388+
validation_passes = 0
389+
390+
for result in results:
391+
if result["status"] == "success":
392+
tests_completed += 1
393+
processed_data = result["processed_data"]
394+
395+
# Sum token counts
396+
stats = processed_data["messages"]["stats"]
397+
total_prompt_tokens += stats["total_prompt_tokens"]
398+
total_completion_tokens += stats["total_completion_tokens"]
399+
total_tool_calls += stats["total_tool_calls"]
400+
401+
# Check validation results (exclude cases where pre or post validation is None)
402+
validation = processed_data["validation"]
403+
pre_val = validation["pre_validation"]
404+
post_val = validation["post_validation"]
405+
406+
# Only count validation if both pre and post validation exist
407+
if pre_val is not None and post_val is not None:
408+
tests_with_validation += 1
409+
410+
# Expected pattern: pre_validation should FAIL, post_validation should PASS
411+
# This indicates the agent successfully fixed the broken pipeline
412+
if pre_val == "FAIL" and post_val == "PASS":
413+
validation_passes += 1
414+
else:
415+
tests_failed += 1
416+
417+
# Calculate averages and rates
418+
avg_tool_calls = total_tool_calls / tests_completed if tests_completed > 0 else 0
419+
pass_rate = (validation_passes / tests_with_validation * 100) if tests_with_validation > 0 else 0
420+
fail_rate = 100 - pass_rate if tests_with_validation > 0 else 0
421+
422+
# Create summary dict
423+
summary_data = {
424+
"total_prompt_tokens": total_prompt_tokens,
425+
"total_completion_tokens": total_completion_tokens,
426+
"tests_completed": tests_completed,
427+
"tests_failed": tests_failed,
428+
"avg_tool_calls": round(avg_tool_calls, 2),
429+
"pass_rate_percent": round(pass_rate, 2),
430+
"fail_rate_percent": round(fail_rate, 2),
431+
}
432+
433+
# Create CSV content
434+
csv_data = [
435+
"total_prompt_tokens,total_completion_tokens,tests_completed,tests_failed,avg_tool_calls,pass_rate_percent,fail_rate_percent",
436+
f"{total_prompt_tokens},{total_completion_tokens},{tests_completed},{tests_failed},{avg_tool_calls:.2f},{pass_rate:.2f},{fail_rate:.2f}",
437+
]
438+
439+
# Save to main run directory
440+
run_dir = self.benchmark_config.output_dir / "agent_runs" / self.run_id
441+
run_dir.mkdir(exist_ok=True, parents=True)
442+
summary_file = run_dir / "run_summary.csv"
443+
444+
with open(summary_file, "w", encoding="utf-8") as f:
445+
f.write("\n".join(csv_data))
446+
447+
logger.info(f"Run summary saved to: {summary_file}")
448+
449+
return summary_data
450+
279451
@staticmethod
280452
def _extract_assistant_message_stats(messages: list[ChatMessage]) -> dict[str, str | int]:
281453
"""
@@ -344,15 +516,17 @@ def _save_run_results(processed_data: dict[str, Any], test_case_name: str, outpu
344516

345517
# Save test_results.csv
346518
csv_file = test_case_dir / "test_results.csv"
519+
pre_validation = processed_data["validation"]["pre_validation"] or "N/A"
520+
post_validation = processed_data["validation"]["post_validation"] or "N/A"
347521
csv_data = [
348522
"commit,test_case,agent,prompt_tokens,completion_tokens,tool_calls,model,pre_validation,post_validation",
349523
f"{metadata['commit_hash']},{test_case_name},{metadata['agent_display_name']},"
350524
f"{processed_data['messages']['stats']['total_prompt_tokens']},"
351525
f"{processed_data['messages']['stats']['total_completion_tokens']},"
352526
f"{processed_data['messages']['stats']['total_tool_calls']},"
353527
f"{processed_data['messages']['stats']['model']},"
354-
f"{processed_data['validation']['pre_validation']},"
355-
f"{processed_data['validation']['post_validation']}",
528+
f"{pre_validation},"
529+
f"{post_validation}",
356530
]
357531

358532
with open(csv_file, "w", encoding="utf-8") as f:
@@ -372,7 +546,8 @@ def run_agent_benchmark(
372546
benchmark_config: BenchmarkConfig,
373547
test_case_name: str | None = None,
374548
concurrency: int = 1,
375-
) -> list[dict[str, Any]]:
549+
streaming: bool = False,
550+
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
376551
"""
377552
Convenience function to run agent benchmarks.
378553
@@ -381,6 +556,7 @@ def run_agent_benchmark(
381556
benchmark_config: Benchmark configuration.
382557
test_case_name: Specific test case to run (if None, runs all)
383558
concurrency: Number of concurrent test runs
559+
streaming: If True, run in streaming mode
384560
385561
Returns:
386562
List of test results
@@ -389,12 +565,16 @@ def run_agent_benchmark(
389565
runner = AgentBenchmarkRunner(
390566
agent_config=agent_config,
391567
benchmark_config=benchmark_config,
568+
streaming=streaming,
392569
)
393570

394571
if test_case_name:
395572
# Run single test case
396573
result = asyncio.run(runner.run_single_test_with_cleanup(test_case_name))
397-
return [result]
574+
results = [result]
575+
# Create run summary CSV for single test case
576+
summary_data = runner._create_run_summary_csv(results)
577+
return results, summary_data
398578
else:
399579
# Run all test cases
400580
if concurrency == 1:

0 commit comments

Comments
 (0)