Skip to content

Commit 1f877ef

Browse files
committed
Test triage agent
Signed-off-by: Nikola Forró <nforro@redhat.com>
1 parent 1a2352c commit 1f877ef

9 files changed

Lines changed: 346 additions & 14 deletions

File tree

beeai/Containerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ RUN pip3 install --no-cache-dir \
2828
beeai-framework[mcp,duckduckgo]==0.1.31 \
2929
openinference-instrumentation-beeai \
3030
arize-phoenix-otel \
31+
deepeval \
3132
&& cd /usr/local/lib/python3.13/site-packages \
3233
&& patch -p2 -i /tmp/beeai-gemini.patch \
3334
&& patch -p2 -i /tmp/beeai-gemini-malformed-function-call.patch \

beeai/agents/backport_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def prompt(self) -> str:
144144
6. {{ backport_git_steps }}
145145
"""
146146

147-
async def run_with_schema(self, input: TInputSchema) -> TOutputSchema:
147+
async def run_with_schema(self, input: TInputSchema, capture_raw_response: bool = False) -> TOutputSchema:
148148
async with mcp_tools(
149149
os.getenv("MCP_GATEWAY_URL"),
150150
filter=lambda t: t
@@ -153,7 +153,7 @@ async def run_with_schema(self, input: TInputSchema) -> TOutputSchema:
153153
tools = self._tools.copy()
154154
try:
155155
self._tools.extend(gateway_tools)
156-
return await self._run_with_schema(input)
156+
return await self._run_with_schema(input, capture_raw_response=capture_raw_response)
157157
finally:
158158
self._tools = tools
159159
# disassociate removed tools from requirements

beeai/agents/base_agent.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515

1616
class BaseAgent(RequirementAgent, ABC):
17+
last_raw_response: RequirementAgentRunOutput | None = None
18+
1719
@property
1820
@abstractmethod
1921
def input_schema(self) -> type[TInputSchema]: ...
@@ -32,7 +34,9 @@ def _render_prompt(self, input: TInputSchema) -> str:
3234
)
3335
return template.render(input)
3436

35-
async def _run_with_schema(self, input: TInputSchema) -> TOutputSchema:
37+
async def _run_with_schema(
38+
self, input: TInputSchema, capture_raw_response: bool = False
39+
) -> TOutputSchema:
3640
max_retries_per_step = int(os.getenv("BEEAI_MAX_RETRIES_PER_STEP", 5))
3741
total_max_retries = int(os.getenv("BEEAI_TOTAL_MAX_RETRIES", 10))
3842
max_iterations = int(os.getenv("BEEAI_MAX_ITERATIONS", 100))
@@ -46,10 +50,14 @@ async def _run_with_schema(self, input: TInputSchema) -> TOutputSchema:
4650
max_iterations=max_iterations,
4751
),
4852
)
53+
if capture_raw_response:
54+
self.last_raw_response = response
4955
return self.output_schema.model_validate_json(response.result.text)
5056

51-
async def run_with_schema(self, input: TInputSchema) -> TOutputSchema:
52-
return await self._run_with_schema(input)
57+
async def run_with_schema(
58+
self, input: TInputSchema, capture_raw_response: bool = False
59+
) -> TOutputSchema:
60+
return await self._run_with_schema(input, capture_raw_response)
5361

5462

5563
if os.getenv("LITELLM_DEBUG"):
@@ -58,4 +66,5 @@ async def run_with_schema(self, input: TInputSchema) -> TOutputSchema:
5866
import beeai_framework.adapters.litellm.chat
5967
import beeai_framework.adapters.litellm.embedding
6068
from beeai_framework.adapters.litellm.utils import litellm_debug
69+
6170
litellm_debug(True)

beeai/agents/rebase_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def prompt(self) -> str:
173173
- Any validation issues found with rpmlint
174174
"""
175175

176-
async def run_with_schema(self, input: TInputSchema) -> TOutputSchema:
176+
async def run_with_schema(self, input: TInputSchema, capture_raw_response: bool = False) -> TOutputSchema:
177177
async with mcp_tools(
178178
os.getenv("MCP_GATEWAY_URL"),
179179
filter=lambda t: t
@@ -182,7 +182,7 @@ async def run_with_schema(self, input: TInputSchema) -> TOutputSchema:
182182
tools = self._tools.copy()
183183
try:
184184
self._tools.extend(gateway_tools)
185-
return await self._run_with_schema(input)
185+
return await self._run_with_schema(input, capture_raw_response=capture_raw_response)
186186
finally:
187187
self._tools = tools
188188
# disassociate removed tools from requirements

beeai/agents/tests/_utils.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import asyncio
5+
import os
6+
from collections.abc import Awaitable, Callable
7+
from pathlib import Path
8+
from typing import TypeVar
9+
10+
import pytest
11+
from deepeval import evaluate
12+
from deepeval.dataset import EvaluationDataset, Golden
13+
from deepeval.evaluate import DisplayConfig
14+
from deepeval.metrics import BaseMetric
15+
from deepeval.test_case import LLMTestCase
16+
from deepeval.test_run.test_run import TestRunResultDisplay
17+
from rich.console import Console, Group
18+
from rich.panel import Panel
19+
from rich.table import Table
20+
21+
from beeai_framework.agents import AnyAgent
22+
23+
ROOT_CACHE_DIR = f"/tmp/.cache"
24+
Path(ROOT_CACHE_DIR).mkdir(parents=True, exist_ok=True)
25+
26+
27+
T = TypeVar("T", bound=AnyAgent)
28+
29+
30+
async def create_dataset(
31+
*,
32+
name: str,
33+
agent_factory: Callable[[], T],
34+
agent_run: Callable[[T, LLMTestCase], Awaitable[None]],
35+
goldens: list[Golden],
36+
cache: bool | None = None,
37+
) -> EvaluationDataset:
38+
dataset = EvaluationDataset()
39+
40+
cache_dir = Path(f"{ROOT_CACHE_DIR}/{name}")
41+
if cache is None:
42+
cache = os.getenv("EVAL_CACHE_DATASET", "").lower() == "true"
43+
44+
if cache and cache_dir.exists():
45+
for file_path in cache_dir.glob("*.json"):
46+
dataset.add_test_cases_from_json_file(
47+
file_path=str(file_path.absolute().resolve()),
48+
input_key_name="input",
49+
actual_output_key_name="actual_output",
50+
expected_output_key_name="expected_output",
51+
context_key_name="context",
52+
tools_called_key_name="tools_called",
53+
expected_tools_key_name="expected_tools",
54+
retrieval_context_key_name="retrieval_context",
55+
)
56+
else:
57+
58+
async def process_golden(golden: Golden) -> LLMTestCase:
59+
agent = agent_factory()
60+
case = LLMTestCase(
61+
input=golden.input,
62+
expected_tools=golden.expected_tools,
63+
actual_output="",
64+
expected_output=golden.expected_output,
65+
comments=golden.comments,
66+
context=golden.context,
67+
tools_called=golden.tools_called,
68+
retrieval_context=golden.retrieval_context,
69+
additional_metadata=golden.additional_metadata,
70+
)
71+
await agent_run(agent, case)
72+
return case
73+
74+
for test_case in await asyncio.gather(*[process_golden(golden) for golden in goldens], return_exceptions=False):
75+
dataset.add_test_case(test_case)
76+
77+
if cache:
78+
dataset.save_as(file_type="json", directory=str(cache_dir.absolute()), include_test_cases=True)
79+
80+
for case in dataset.test_cases:
81+
case.name = f"{name} - {case.input[0:128].strip()}" # type: ignore
82+
83+
return dataset
84+
85+
86+
def evaluate_dataset(
87+
dataset: EvaluationDataset, metrics: list[BaseMetric], display_mode: TestRunResultDisplay | None = None
88+
) -> None:
89+
console = Console()
90+
console.print("[bold green]Evaluating dataset[/bold green]")
91+
92+
if display_mode is None:
93+
display_mode = TestRunResultDisplay(os.environ.get("EVAL_DISPLAY_MODE", "all"))
94+
95+
output = evaluate(
96+
test_cases=dataset.test_cases, # type: ignore
97+
metrics=metrics,
98+
display_config=DisplayConfig(
99+
show_indicator=False, print_results=False, verbose_mode=False, display_option=None
100+
),
101+
)
102+
103+
# Calculate pass/fail counts
104+
total = len(output.test_results)
105+
passed = sum(
106+
bool(test_result.metrics_data) and all(md.success for md in (test_result.metrics_data or []))
107+
for test_result in output.test_results
108+
)
109+
failed = total - passed
110+
111+
# Print summary table
112+
summary_table = Table(title="Test Results Summary", show_header=True, header_style="bold cyan")
113+
summary_table.add_column("Total", justify="right")
114+
summary_table.add_column("Passed", justify="right", style="green")
115+
summary_table.add_column("Failed", justify="right", style="red")
116+
summary_table.add_row(str(total), str(passed), str(failed))
117+
console.print(summary_table)
118+
119+
for test_result in output.test_results:
120+
if display_mode != TestRunResultDisplay.ALL and (
121+
(display_mode == TestRunResultDisplay.FAILING and test_result.success)
122+
or (display_mode == TestRunResultDisplay.PASSING and not test_result.success)
123+
):
124+
continue
125+
126+
# Info Table
127+
info_table = Table(show_header=False, box=None, pad_edge=False)
128+
info_table.add_row("Input", str(test_result.input))
129+
info_table.add_row("Expected Output", str(test_result.expected_output))
130+
info_table.add_row("Actual Output", str(test_result.actual_output))
131+
132+
# Metrics Table
133+
metrics_table = Table(title="Metrics", show_header=True, header_style="bold magenta")
134+
metrics_table.add_column("Metric")
135+
metrics_table.add_column("Success")
136+
metrics_table.add_column("Score")
137+
metrics_table.add_column("Threshold")
138+
metrics_table.add_column("Reason")
139+
metrics_table.add_column("Error")
140+
# metrics_table.add_column("Verbose Log")
141+
142+
for metric_data in test_result.metrics_data or []:
143+
metrics_table.add_row(
144+
str(metric_data.name),
145+
str(metric_data.success),
146+
str(metric_data.score),
147+
str(metric_data.threshold),
148+
str(metric_data.reason),
149+
str(metric_data.error) if metric_data.error else "",
150+
# str(metric_data.verbose_logs),
151+
)
152+
153+
# Print the panel with info and metrics table
154+
console.print(
155+
Panel(
156+
Group(info_table, metrics_table),
157+
title=f"[bold blue]{test_result.name}[/bold blue]",
158+
border_style="blue",
159+
)
160+
)
161+
162+
# Gather failed tests
163+
if failed:
164+
pytest.fail(f"{failed}/{total} tests failed. See the summary table above for more details.", pytrace=False)
165+
else:
166+
assert 1 == 1

beeai/agents/tests/model.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import os
5+
from typing import Any, TypeVar
6+
7+
from deepeval.key_handler import KEY_FILE_HANDLER, KeyValues
8+
from deepeval.models import DeepEvalBaseLLM
9+
from dotenv import load_dotenv
10+
from pydantic import BaseModel
11+
12+
from beeai_framework.backend import ChatModel, ChatModelParameters
13+
from beeai_framework.backend.constants import ProviderName
14+
from beeai_framework.backend.message import UserMessage
15+
from beeai_framework.middleware.trajectory import GlobalTrajectoryMiddleware
16+
from beeai_framework.utils import ModelLike
17+
18+
TSchema = TypeVar("TSchema", bound=BaseModel)
19+
20+
21+
load_dotenv()
22+
23+
24+
class DeepEvalLLM(DeepEvalBaseLLM):
25+
def __init__(self, model: ChatModel, *args: Any, **kwargs: Any) -> None:
26+
self._model = model
27+
super().__init__(model.model_id, *args, **kwargs)
28+
29+
def load_model(self, *args: Any, **kwargs: Any) -> None:
30+
return None
31+
32+
def generate(self, prompt: str, schema: BaseModel | None = None) -> str:
33+
raise NotImplementedError()
34+
35+
async def a_generate(self, prompt: str, schema: TSchema | None = None) -> str:
36+
input_msg = UserMessage(prompt)
37+
response = await self._model.create(
38+
messages=[input_msg],
39+
response_format=schema.model_json_schema(mode="serialization") if schema is not None else None,
40+
stream=False,
41+
temperature=0,
42+
).middleware(
43+
GlobalTrajectoryMiddleware(
44+
pretty=True, exclude_none=True, enabled=os.environ.get("EVAL_LOG_LLM_CALLS", "").lower() == "true"
45+
)
46+
)
47+
text = response.get_text_content()
48+
return schema.model_validate_json(text) if schema else text # type: ignore
49+
50+
def get_model_name(self) -> str:
51+
return f"{self._model.model_id} ({self._model.provider_id})"
52+
53+
@staticmethod
54+
def from_name(
55+
name: str | ProviderName | None = None, options: ModelLike[ChatModelParameters] | None = None, **kwargs: Any
56+
) -> "DeepEvalLLM":
57+
name = name or KEY_FILE_HANDLER.fetch_data(KeyValues.LOCAL_MODEL_NAME)
58+
model = ChatModel.from_name(name, options, **kwargs)
59+
return DeepEvalLLM(model)

0 commit comments

Comments
 (0)