diff --git a/src/google/adk/cli/cli_eval.py b/src/google/adk/cli/cli_eval.py index 33c1693208..2d7145c5c9 100644 --- a/src/google/adk/cli/cli_eval.py +++ b/src/google/adk/cli/cli_eval.py @@ -25,6 +25,7 @@ from google.genai import types as genai_types from ..agents.llm_agent import Agent +from ..apps.app import App from ..evaluation.base_eval_service import BaseEvalService from ..evaluation.base_eval_service import EvaluateConfig from ..evaluation.base_eval_service import EvaluateRequest @@ -93,6 +94,17 @@ def get_root_agent(agent_module_file_path: str) -> Agent: return root_agent +def get_root_agent_and_app( + agent_module_file_path: str, +) -> tuple[Agent, Optional[App]]: + """Returns root agent and App (if defined) given the agent module.""" + agent_module = _get_agent_module(agent_module_file_path) + agent_obj = agent_module.agent + app = agent_obj if isinstance(agent_obj, App) else None + root_agent = agent_obj.root_agent + return root_agent, app + + def try_get_reset_func(agent_module_file_path: str) -> Any: """Returns reset function for the agent, if present, given the agent module.""" agent_module = _get_agent_module(agent_module_file_path) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 07ccc15892..6b1868b849 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -825,6 +825,7 @@ def cli_eval( from .cli_eval import _collect_inferences from .cli_eval import get_default_metric_info from .cli_eval import get_root_agent + from .cli_eval import get_root_agent_and_app from .cli_eval import parse_and_get_evals_to_run from .cli_eval import pretty_print_eval_result except ModuleNotFoundError as mnf: @@ -834,7 +835,7 @@ def cli_eval( print(f"Using evaluation criteria: {eval_config}") eval_metrics = get_eval_metrics_from_config(eval_config) - root_agent = get_root_agent(agent_module_file_path) + root_agent, app = get_root_agent_and_app(agent_module_file_path) app_name = os.path.basename(agent_module_file_path) agents_dir = os.path.dirname(agent_module_file_path) eval_sets_manager = None @@ -936,6 +937,7 @@ def cli_eval( eval_service = LocalEvalService( root_agent=root_agent, + app=app, eval_sets_manager=eval_sets_manager, eval_set_results_manager=eval_set_results_manager, user_simulator_provider=user_simulator_provider, @@ -1859,7 +1861,7 @@ def cli_api_server( default=False, help=( "Optional. Deploy ADK Web UI if set. (default: deploy ADK API server" - " only). WARNING: The web UI is for development and testing only — do" + " only). WARNING: The web UI is for development and testing only — do" " not use in production." ), ) @@ -1910,7 +1912,7 @@ def cli_api_server( default=False, help="Optional. Whether to enable A2A endpoint.", ) -# Kept as raw str (not parsed to list) — interpolated directly into Dockerfile CMD. +# Kept as raw str (not parsed to list) — interpolated directly into Dockerfile CMD. @click.option( "--trigger_sources", type=str, @@ -2389,7 +2391,7 @@ def cli_deploy_agent_engine( default=False, help=( "Optional. Deploy ADK Web UI if set. (default: deploy ADK API server" - " only). WARNING: The web UI is for development and testing only — do" + " only). WARNING: The web UI is for development and testing only — do" " not use in production." ), ) @@ -2433,7 +2435,7 @@ def cli_deploy_agent_engine( " version in the dev environment)" ), ) -# Kept as raw str (not parsed to list) — interpolated directly into Dockerfile CMD. +# Kept as raw str (not parsed to list) — interpolated directly into Dockerfile CMD. @click.option( "--trigger_sources", type=str, diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py index f8fb6795aa..351d760862 100644 --- a/src/google/adk/evaluation/evaluation_generator.py +++ b/src/google/adk/evaluation/evaluation_generator.py @@ -26,6 +26,7 @@ from pydantic import BaseModel from ..agents.llm_agent import Agent +from ..apps.app import App from ..artifacts.base_artifact_service import BaseArtifactService from ..artifacts.in_memory_artifact_service import InMemoryArtifactService from ..events.event import Event @@ -197,6 +198,7 @@ async def _generate_inferences_from_root_agent( session_service: Optional[BaseSessionService] = None, artifact_service: Optional[BaseArtifactService] = None, memory_service: Optional[BaseMemoryService] = None, + app: Optional[App] = None, ) -> list[Invocation]: """Scrapes the root agent in coordination with the user simulator.""" @@ -235,13 +237,15 @@ async def _generate_inferences_from_root_agent( ensure_retry_options_plugin = EnsureRetryOptionsPlugin( name="ensure_retry_options" ) + app_plugins = list(app.plugins) if app else [] async with Runner( app_name=app_name, agent=root_agent, artifact_service=artifact_service, session_service=session_service, memory_service=memory_service, - plugins=[request_intercepter_plugin, ensure_retry_options_plugin], + plugins=app_plugins + + [request_intercepter_plugin, ensure_retry_options_plugin], ) as runner: events = [] while True: diff --git a/src/google/adk/evaluation/local_eval_service.py b/src/google/adk/evaluation/local_eval_service.py index 2426204ca0..382bed67a2 100644 --- a/src/google/adk/evaluation/local_eval_service.py +++ b/src/google/adk/evaluation/local_eval_service.py @@ -25,6 +25,7 @@ from typing_extensions import override from ..agents.base_agent import BaseAgent +from ..apps.app import App from ..artifacts.base_artifact_service import BaseArtifactService from ..artifacts.in_memory_artifact_service import InMemoryArtifactService from ..errors.not_found_error import NotFoundError @@ -116,6 +117,7 @@ def __init__( self, root_agent: BaseAgent, eval_sets_manager: EvalSetsManager, + app: Optional[App] = None, metric_evaluator_registry: Optional[MetricEvaluatorRegistry] = None, session_service: Optional[BaseSessionService] = None, artifact_service: Optional[BaseArtifactService] = None, @@ -125,6 +127,7 @@ def __init__( memory_service: Optional[BaseMemoryService] = None, ): self._root_agent = root_agent + self._app = app self._eval_sets_manager = eval_sets_manager metric_evaluator_registry = ( metric_evaluator_registry or DEFAULT_METRIC_EVALUATOR_REGISTRY @@ -182,6 +185,7 @@ async def run_inference(eval_case): eval_set_id=inference_request.eval_set_id, eval_case=eval_case, root_agent=self._root_agent, + app=self._app, ) inference_results = [run_inference(eval_case) for eval_case in eval_cases] @@ -470,6 +474,7 @@ async def _perform_inference_single_eval_item( eval_set_id: str, eval_case: EvalCase, root_agent: BaseAgent, + app: Optional[App] = None, ) -> InferenceResult: initial_session = eval_case.session_input session_id = self._session_id_supplier() @@ -491,6 +496,7 @@ async def _perform_inference_single_eval_item( session_service=self._session_service, artifact_service=self._artifact_service, memory_service=self._memory_service, + app=app, ) )