Skip to content
Open
84 changes: 84 additions & 0 deletions src/google/adk/evaluation/agent_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from .eval_metrics import PrebuiltMetrics
from .eval_result import EvalCaseResult
from .eval_set import EvalSet
from .eval_set_results_manager import EvalSetResultsManager
from .eval_sets_manager import EvalSetsManager
from .evaluator import EvalStatus
from .in_memory_eval_sets_manager import InMemoryEvalSetsManager
Expand Down Expand Up @@ -113,6 +114,8 @@ async def evaluate_eval_set(
num_runs: int = NUM_RUNS,
agent_name: Optional[str] = None,
print_detailed_results: bool = True,
app_name: Optional[str] = None,
eval_set_results_manager: Optional[EvalSetResultsManager] = None,
):
"""Evaluates an agent using the given EvalSet.

Expand All @@ -128,6 +131,10 @@ async def evaluate_eval_set(
assessed.
agent_name: The name of the agent, if trying to evaluate something other
than root agent. If left empty or none, then root agent is evaluated.
app_name: The application name used by eval set results manager while
persisting eval set results.
eval_set_results_manager: Optional manager used to persist the eval set
evaluation result as `*.evalset_result.json`.
print_detailed_results: Whether to print detailed results for each metric
evaluation.
"""
Expand Down Expand Up @@ -162,6 +169,13 @@ async def evaluate_eval_set(
num_runs=num_runs,
user_simulator_provider=user_simulator_provider,
)
AgentEvaluator._maybe_save_eval_set_result(
agent_module=agent_module,
app_name=app_name,
eval_set=eval_set,
eval_results_by_eval_id=eval_results_by_eval_id,
eval_set_results_manager=eval_set_results_manager,
)

# Step 2: Post-process the results!

Expand Down Expand Up @@ -200,6 +214,8 @@ async def evaluate(
agent_name: Optional[str] = None,
initial_session_file: Optional[str] = None,
print_detailed_results: bool = True,
app_name: Optional[str] = None,
eval_set_results_manager: Optional[EvalSetResultsManager] = None,
):
"""Evaluates an Agent given eval data.

Expand All @@ -214,6 +230,10 @@ async def evaluate(
num_runs: Number of times all entries in the eval dataset should be
assessed.
agent_name: The name of the agent.
app_name: The application name used by eval set results manager while
persisting eval set results.
eval_set_results_manager: Optional manager used to persist the eval set
evaluation result as `*.evalset_result.json`.
initial_session_file: File that contains initial session state that is
needed by all the evals in the eval dataset.
print_detailed_results: Whether to print detailed results for each metric
Expand Down Expand Up @@ -244,6 +264,8 @@ async def evaluate(
eval_config=eval_config,
num_runs=num_runs,
agent_name=agent_name,
app_name=app_name,
eval_set_results_manager=eval_set_results_manager,
print_detailed_results=print_detailed_results,
)

Expand Down Expand Up @@ -644,6 +666,68 @@ def _get_eval_metric_results_with_invocation(
)
return eval_metric_results

@staticmethod
def _resolve_app_name(
    agent_module: str, app_name: Optional[str] = None
) -> str:
  """Picks the app name under which eval set results are stored.

  A non-empty `app_name` always wins. Otherwise the name is derived from the
  dotted `agent_module` path, ignoring empty segments: the last segment is
  used, unless the path has more than one segment and ends in `agent`, in
  which case the segment before it is used (e.g. `pkg.my_app.agent` ->
  `my_app`).

  Args:
    agent_module: Dotted module path of the agent being evaluated.
    app_name: Explicit app name; used as is when truthy.

  Returns:
    The resolved app name. If `agent_module` contains no non-empty segment,
    `agent_module` itself is returned unchanged.
  """
  if not app_name:
    # Drop empty segments produced by leading, trailing or doubled dots.
    segments = list(filter(None, agent_module.split(".")))
    if len(segments) >= 2 and segments[-1] == "agent":
      app_name = segments[-2]
    elif segments:
      app_name = segments[-1]
    else:
      app_name = agent_module
  return app_name

@staticmethod
def _flatten_eval_results_by_eval_case_order(
    eval_set: EvalSet,
    eval_results_by_eval_id: dict[str, list[EvalCaseResult]],
) -> list[EvalCaseResult]:
  """Returns eval results flattened in eval case order.

  Results are emitted first for the eval cases listed in
  `eval_set.eval_cases`, in that order, followed by the results of any eval
  id that is not listed in the eval set, in the iteration order of
  `eval_results_by_eval_id`. Each eval id contributes its results exactly
  once, even when the eval set lists the same eval id more than once.

  Args:
    eval_set: The eval set whose `eval_cases` define the primary ordering.
    eval_results_by_eval_id: Eval case results grouped by eval id.

  Returns:
    A new flat list holding every result from `eval_results_by_eval_id`.
  """
  flattened_results: list[EvalCaseResult] = []
  seen_eval_ids: set[str] = set()
  for eval_case in eval_set.eval_cases:
    eval_id = eval_case.eval_id
    # A repeated eval id must not append the same results a second time.
    if eval_id in seen_eval_ids:
      continue
    seen_eval_ids.add(eval_id)
    flattened_results.extend(eval_results_by_eval_id.get(eval_id, ()))

  # Keep results whose eval id is not part of the eval set instead of
  # silently dropping them.
  for eval_id, eval_results in eval_results_by_eval_id.items():
    if eval_id not in seen_eval_ids:
      flattened_results.extend(eval_results)

  return flattened_results

@staticmethod
def _maybe_save_eval_set_result(
    agent_module: str,
    app_name: Optional[str],
    eval_set: EvalSet,
    eval_results_by_eval_id: dict[str, list[EvalCaseResult]],
    eval_set_results_manager: Optional[EvalSetResultsManager],
) -> None:
  """Persists the eval set's results when a results manager is supplied.

  Does nothing when `eval_set_results_manager` is None. Otherwise the app
  name is resolved via `_resolve_app_name`, the per-eval-id results are
  flattened via `_flatten_eval_results_by_eval_case_order`, and both are
  handed to the manager's `save_eval_set_result` together with the eval
  set's id.

  Args:
    agent_module: Dotted module path of the evaluated agent; used to derive
      the app name when `app_name` is not given.
    app_name: Optional explicit app name.
    eval_set: The eval set that was evaluated.
    eval_results_by_eval_id: Eval case results grouped by eval id.
    eval_set_results_manager: Manager that stores the results, or None to
      skip saving.
  """
  if eval_set_results_manager is not None:
    storage_app_name = AgentEvaluator._resolve_app_name(
        agent_module=agent_module, app_name=app_name
    )
    ordered_case_results = (
        AgentEvaluator._flatten_eval_results_by_eval_case_order(
            eval_set=eval_set,
            eval_results_by_eval_id=eval_results_by_eval_id,
        )
    )
    eval_set_results_manager.save_eval_set_result(
        app_name=storage_app_name,
        eval_set_id=eval_set.eval_set_id,
        eval_case_results=ordered_case_results,
    )

@staticmethod
def _process_metrics_and_get_failures(
eval_metric_results: dict[str, list[_EvalMetricResultWithInvocation]],
Expand Down
25 changes: 25 additions & 0 deletions tests/integration/test_with_test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from google.adk.evaluation.agent_evaluator import AgentEvaluator
from google.adk.evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
import pytest


Expand All @@ -35,3 +36,27 @@ async def test_with_folder_of_test_files_long_running():
),
num_runs=4,
)


@pytest.mark.asyncio
async def test_with_single_test_file_saves_eval_set_result(
    tmp_path,
):
  """Checks that evaluating with a results manager writes a result file.

  No `app_name` is passed, so the result is expected under the
  `home_automation_agent` directory inside `tmp_path`.
  """
  results_manager = LocalEvalSetResultsManager(agents_dir=str(tmp_path))
  await AgentEvaluator.evaluate(
      agent_module="tests.integration.fixture.home_automation_agent",
      eval_dataset_file_path_or_dir=(
          "tests/integration/fixture/home_automation_agent/simple_test.test.json"
      ),
      eval_set_results_manager=results_manager,
  )

  eval_history_dir = (
      tmp_path / "home_automation_agent" / ".adk" / "eval_history"
  )
  saved_result_files = list(eval_history_dir.glob("*.evalset_result.json"))
  assert saved_result_files
Loading