Skip to content

Commit 7b7ddda

Browse files
KPJoshi and copybara-github
authored and committed
feat: Add interface between optimization infra and LocalEvalService
details: * Enables the use of ADK evaluations via LocalEvalService for optimizing agents. * Provides flexibility in choosing eval sets and eval cases for training and validation. * Converts ADK eval results into a compact format useful for whitebox agent optimization. Co-authored-by: Keyur Joshi <keyurj@google.com> PiperOrigin-RevId: 875818012
1 parent 65d9a72 commit 7b7ddda

3 files changed

Lines changed: 754 additions & 1 deletion

File tree

Lines changed: 367 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,367 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import logging
18+
from typing import Any
19+
from typing import Literal
20+
from typing import Optional
21+
22+
from pydantic import BaseModel
23+
from pydantic import Field
24+
25+
from ..agents.llm_agent import Agent
26+
from ..evaluation.base_eval_service import EvaluateConfig
27+
from ..evaluation.base_eval_service import EvaluateRequest
28+
from ..evaluation.base_eval_service import InferenceConfig
29+
from ..evaluation.base_eval_service import InferenceRequest
30+
from ..evaluation.base_eval_service import InferenceResult
31+
from ..evaluation.eval_case import get_all_tool_calls_with_responses
32+
from ..evaluation.eval_case import IntermediateData
33+
from ..evaluation.eval_case import Invocation
34+
from ..evaluation.eval_case import InvocationEvents
35+
from ..evaluation.eval_config import EvalConfig
36+
from ..evaluation.eval_config import get_eval_metrics_from_config
37+
from ..evaluation.eval_metrics import EvalStatus
38+
from ..evaluation.eval_result import EvalCaseResult
39+
from ..evaluation.eval_sets_manager import EvalSetsManager
40+
from ..evaluation.local_eval_service import LocalEvalService
41+
from ..evaluation.simulation.user_simulator_provider import UserSimulatorProvider
42+
from ..utils.context_utils import Aclosing
43+
from .data_types import UnstructuredSamplingResult
44+
from .sampler import Sampler
45+
46+
logger = logging.getLogger("google_adk." + __name__)
47+
48+
49+
def _log_eval_summary(eval_results: list[EvalCaseResult]):
  """Logs how many eval cases passed, failed, or ended in another status.

  The "OTHER" count is only mentioned when at least one eval case is neither
  PASSED nor FAILED.

  Args:
    eval_results: The eval case results to summarize.
  """
  statuses = [result.final_eval_status for result in eval_results]
  passed = sum(status == EvalStatus.PASSED for status in statuses)
  failed = sum(status == EvalStatus.FAILED for status in statuses)
  # Everything that is neither PASSED nor FAILED is lumped together.
  other = len(statuses) - passed - failed

  other_suffix = f", {other} OTHER" if other else ""
  logger.info(
      f"Evaluation summary: {passed} PASSED, {failed} FAILED" + other_suffix
  )
64+
65+
66+
def extract_tool_call_data(
    intermediate_data: IntermediateData | InvocationEvents,
) -> list[dict[str, Any]]:
  """Summarizes every tool call in the intermediate data as a plain dict.

  Args:
    intermediate_data: The intermediate data of an invocation.

  Returns:
    One dict per tool call with the keys "name", "args" and "response".
    "response" is None when the tool call has no (truthy) paired response.
  """
  return [
      {
          "name": call.name,
          "args": call.args,
          "response": None if not response else response.response,
      }
      for call, response in get_all_tool_calls_with_responses(
          intermediate_data
      )
  ]
79+
80+
81+
def extract_single_invocation_info(
    invocation: Invocation,
) -> dict[str, Any]:
  """Extracts the user prompt, agent response and tool calls of an invocation.

  Only text parts that are not marked as thoughts contribute to the prompt and
  the response; their texts are concatenated in order.

  Args:
    invocation: The invocation to summarize.

  Returns:
    A dict with the keys "user_prompt" and "agent_response" (each possibly an
    empty string) and, if the invocation has intermediate data, "tool_calls".
  """

  def _visible_text(content) -> str:
    # A content object may be missing or carry `parts=None`; treat both as
    # "no text" instead of failing when iterating over the parts.
    if not content or not content.parts:
      return ""
    return "".join(
        part.text for part in content.parts if part.text and not part.thought
    )

  result = {
      "user_prompt": _visible_text(invocation.user_content),
      "agent_response": _visible_text(invocation.final_response),
  }
  if invocation.intermediate_data:
    result["tool_calls"] = extract_tool_call_data(invocation.intermediate_data)
  return result
99+
100+
101+
class LocalEvalSamplerConfig(BaseModel):
  """Contains configuration options required by the LocalEvalSampler."""

  # NOTE: `required=True` is not a supported `Field` argument. Fields declared
  # without a default value are required by pydantic, so nothing else is
  # needed to make the first three fields mandatory.
  eval_config: EvalConfig = Field(
      description="The configuration for the evaluation.",
  )

  app_name: str = Field(
      description="The app name to use for evaluation.",
  )

  train_eval_set: str = Field(
      description="The name of the eval set to use for optimization.",
  )

  train_eval_case_ids: Optional[list[str]] = Field(
      default=None,
      description=(
          "The ids of the eval cases to use for optimization. If not provided,"
          " all eval cases in the train_eval_set will be used."
      ),
  )

  validation_eval_set: Optional[str] = Field(
      default=None,
      description=(
          "The name of the eval set to use for validating the optimized agent."
          " If not provided, the train_eval_set will also be used for"
          " validation."
      ),
  )

  validation_eval_case_ids: Optional[list[str]] = Field(
      default=None,
      description=(
          "The ids of the eval cases to use for validating the optimized agent."
          " If not provided, all eval cases in the validation_eval_set will be"
          " used. If validation_eval_set is also not provided, all train eval"
          " cases will be used."
      ),
  )
145+
146+
147+
class LocalEvalSampler(Sampler[UnstructuredSamplingResult]):
  """Evaluates candidate agents with the ADK's LocalEvalService.

  The train and validation eval sets, and the eval cases used from each, are
  resolved once at construction time from the given config.
  """

  def __init__(
      self,
      config: LocalEvalSamplerConfig,
      eval_sets_manager: EvalSetsManager,
  ):
    """Initializes the sampler and resolves the eval cases to use.

    Args:
      config: The sampler configuration.
      eval_sets_manager: Used to look up eval sets and eval cases.

    Raises:
      ValueError: If the eval case ids of an eval set have to be looked up and
        that eval set does not exist for the configured app.
    """
    self._config = config
    self._eval_sets_manager = eval_sets_manager

    self._train_eval_set = self._config.train_eval_set
    # Fall back to every eval case in the train set when none are listed.
    self._train_eval_case_ids = (
        self._config.train_eval_case_ids
        or self._get_eval_case_ids(self._train_eval_set)
    )

    self._validation_eval_set = (
        self._config.validation_eval_set or self._train_eval_set
    )
    if self._config.validation_eval_case_ids:
      self._validation_eval_case_ids = self._config.validation_eval_case_ids
    elif self._config.validation_eval_set:
      self._validation_eval_case_ids = self._get_eval_case_ids(
          self._validation_eval_set
      )
    else:
      # No validation set and no validation cases: validate on the train cases.
      self._validation_eval_case_ids = self._train_eval_case_ids

  def _get_selected_example_set_id(
      self, example_set: Literal[Sampler.TRAIN_SET, Sampler.VALIDATION_SET]
  ) -> str:
    """Returns the ID of the selected example set."""
    return {
        Sampler.TRAIN_SET: self._train_eval_set,
        Sampler.VALIDATION_SET: self._validation_eval_set,
    }[example_set]

  def _get_all_example_ids(
      self, example_set: Literal[Sampler.TRAIN_SET, Sampler.VALIDATION_SET]
  ) -> list[str]:
    """Returns the IDs of all examples in the selected example set."""
    return {
        Sampler.TRAIN_SET: self._train_eval_case_ids,
        Sampler.VALIDATION_SET: self._validation_eval_case_ids,
    }[example_set]

  def _get_eval_case_ids(self, eval_set_id: str) -> list[str]:
    """Returns the ids of eval cases in the given eval set.

    Args:
      eval_set_id: The id of the eval set to look up for the configured app.

    Returns:
      The eval ids of all eval cases in the eval set.

    Raises:
      ValueError: If the eval sets manager returns no eval set for this id.
    """
    eval_set = self._eval_sets_manager.get_eval_set(
        app_name=self._config.app_name,
        eval_set_id=eval_set_id,
    )
    if eval_set:
      return [eval_case.eval_id for eval_case in eval_set.eval_cases]
    else:
      raise ValueError(
          f"Eval set `{eval_set_id}` does not exist for app"
          f" `{self._config.app_name}`."
      )

  async def _evaluate_agent(
      self,
      agent: Agent,
      eval_set_id: str,
      eval_case_ids: list[str],
  ) -> list[EvalCaseResult]:
    """Evaluates the agent on the requested eval cases and returns the results.

    Args:
      agent: The agent to evaluate.
      eval_set_id: The id of the eval set to use for evaluation.
      eval_case_ids: The ids of the eval cases to use for evaluation.

    Returns:
      A list of EvalCaseResult, one per eval case.
    """
    # create the inference request
    inference_request = InferenceRequest(
        app_name=self._config.app_name,
        eval_set_id=eval_set_id,
        eval_case_ids=eval_case_ids,
        inference_config=InferenceConfig(),
    )

    # create the LocalEvalService
    user_simulator_provider = UserSimulatorProvider(
        self._config.eval_config.user_simulator_config
    )
    eval_service = LocalEvalService(
        root_agent=agent,
        eval_sets_manager=self._eval_sets_manager,
        user_simulator_provider=user_simulator_provider,
    )

    # inference/sampling: collect every result before starting evaluation
    async with Aclosing(
        eval_service.perform_inference(inference_request=inference_request)
    ) as agen:
      inference_results: list[InferenceResult] = [
          inference_result async for inference_result in agen
      ]

    # evaluation: score the collected inference results with the metrics
    # derived from the eval config
    eval_metrics = get_eval_metrics_from_config(self._config.eval_config)
    evaluate_request = EvaluateRequest(
        inference_results=inference_results,
        evaluate_config=EvaluateConfig(eval_metrics=eval_metrics),
    )
    async with Aclosing(
        eval_service.evaluate(evaluate_request=evaluate_request)
    ) as agen:
      eval_results: list[EvalCaseResult] = [
          eval_result async for eval_result in agen
      ]

    return eval_results

  def _extract_eval_data(
      self,
      eval_set_id: str,
      eval_results: list[EvalCaseResult],
  ) -> dict[str, dict[str, Any]]:
    """Extracts evaluation data from the eval results.

    Args:
      eval_set_id: The id of the eval set the results were produced for.
      eval_results: The eval case results to convert.

    Returns:
      A dict keyed by eval case id. Each value holds "invocations" (one entry
      per invocation with the actual invocation, the metric results and, if
      present, the expected invocation) and, if the eval case defines one,
      "conversation_scenario".
    """
    eval_data = {}
    for eval_result in eval_results:
      eval_result_dict = {}
      eval_case = self._eval_sets_manager.get_eval_case(
          app_name=self._config.app_name,
          eval_set_id=eval_set_id,
          eval_case_id=eval_result.eval_id,
      )
      if eval_case and eval_case.conversation_scenario:
        eval_result_dict["conversation_scenario"] = (
            eval_case.conversation_scenario
        )

      per_invocation_results = []
      for (
          per_invocation_result
      ) in eval_result.eval_metric_result_per_invocation:
        eval_metric_results = []
        for eval_metric_result in per_invocation_result.eval_metric_results:
          score = eval_metric_result.score
          eval_metric_results.append({
              "metric_name": eval_metric_result.metric_name,
              # A metric result may carry no score (e.g. when it was not
              # evaluated); round() would raise on None, so pass it through.
              # Two decimals are accurate enough otherwise.
              "score": round(score, 2) if score is not None else None,
              "eval_status": eval_metric_result.eval_status.name,
          })
        per_invocation_result_dict = {
            "actual_invocation": extract_single_invocation_info(
                per_invocation_result.actual_invocation
            ),
            "eval_metric_results": eval_metric_results,
        }
        if per_invocation_result.expected_invocation:
          per_invocation_result_dict["expected_invocation"] = (
              extract_single_invocation_info(
                  per_invocation_result.expected_invocation
              )
          )
        per_invocation_results.append(per_invocation_result_dict)
      eval_result_dict["invocations"] = per_invocation_results
      eval_data[eval_result.eval_id] = eval_result_dict

    return eval_data

  def get_train_example_ids(self) -> list[str]:
    """Returns the UIDs of examples to use for training the agent."""
    return self._train_eval_case_ids

  def get_validation_example_ids(self) -> list[str]:
    """Returns the UIDs of examples to use for validating the optimized agent."""
    return self._validation_eval_case_ids

  async def sample_and_score(
      self,
      candidate: Agent,
      example_set: Literal[
          Sampler.TRAIN_SET, Sampler.VALIDATION_SET
      ] = Sampler.VALIDATION_SET,
      batch: Optional[list[str]] = None,
      capture_full_eval_data: bool = False,
  ) -> UnstructuredSamplingResult:
    """Evaluates the candidate agent on the batch of examples using the ADK LocalEvalService.

    Args:
      candidate: The candidate agent to be evaluated.
      example_set: The set of examples to evaluate the candidate agent on.
        Possible values are "train" and "validation".
      batch: UIDs of examples to evaluate the candidate agent on. If not
        provided, all examples from the chosen set will be used.
      capture_full_eval_data: If false, it is enough to only calculate the
        scores for each example. If true, this method should also capture all
        other data required for optimizing the agent (e.g., outputs,
        trajectories, and tool calls).

    Returns:
      The evaluation results, containing the scores for each example and (if
      requested) other data required for optimization.
    """
    eval_set_id = self._get_selected_example_set_id(example_set)
    if batch is None:
      batch = self._get_all_example_ids(example_set)

    eval_results = await self._evaluate_agent(candidate, eval_set_id, batch)
    _log_eval_summary(eval_results)

    # Binary score per eval case: 1.0 only when the final status is PASSED.
    scores = {
        eval_result.eval_id: (
            1.0 if eval_result.final_eval_status == EvalStatus.PASSED else 0.0
        )
        for eval_result in eval_results
    }

    eval_data = (
        self._extract_eval_data(eval_set_id, eval_results)
        if capture_full_eval_data
        else None
    )

    return UnstructuredSamplingResult(scores=scores, data=eval_data)

src/google/adk/optimization/sampler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ class Sampler(ABC, Generic[SamplingResult]):
3232
to get evaluation results for the candidate agent on the batch of examples.
3333
"""
3434

35+
TRAIN_SET = "train"
36+
VALIDATION_SET = "validation"
37+
3538
@abstractmethod
3639
def get_train_example_ids(self) -> list[str]:
3740
"""Returns the UIDs of examples to use for training the agent."""
@@ -46,7 +49,7 @@ def get_validation_example_ids(self) -> list[str]:
4649
async def sample_and_score(
4750
self,
4851
candidate: Agent,
49-
example_set: Literal["train", "validation"] = "validation",
52+
example_set: Literal[TRAIN_SET, VALIDATION_SET] = VALIDATION_SET,
5053
batch: Optional[list[str]] = None,
5154
capture_full_eval_data: bool = False,
5255
) -> SamplingResult:

0 commit comments

Comments
 (0)