Skip to content

Commit 8b85e6d

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: GenAI Client(evals) - auto-infer metric/candidate and validate inputs for generate_loss_clusters
PiperOrigin-RevId: 894079615
1 parent 8710b02 commit 8b85e6d

File tree

7 files changed

+1440
-64
lines changed

7 files changed

+1440
-64
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
16+
17+
from tests.unit.vertexai.genai.replays import pytest_helper
18+
from vertexai import types
19+
import pytest
20+
21+
22+
def test_gen_loss_clusters(client):
    """Verifies generate_loss_clusters() calls the API and returns a LossClusters object."""
    analysis_config = types.LossAnalysisConfig(
        metric="multi_turn_task_success_v1",
        candidate="travel-agent",
    )
    loss_clusters = client.evals.generate_loss_clusters(
        eval_result=types.EvaluationResult(),
        config=analysis_config,
    )

    assert type(loss_clusters).__name__ == "LossClusters"
    assert len(loss_clusters.results) == 1

    result = loss_clusters.results[0]
    assert result.config.metric == "multi_turn_task_success_v1"
    assert result.config.candidate == "travel-agent"
    assert len(result.clusters) == 2

    first, second = result.clusters
    assert first.cluster_id == "cluster-1"
    assert first.taxonomy_entry.l1_category == "Tool Calling"
    assert first.taxonomy_entry.l2_category == "Missing Tool Invocation"
    assert first.item_count == 3
    assert second.cluster_id == "cluster-2"
    assert second.taxonomy_entry.l1_category == "Hallucination"
    assert second.item_count == 2
47+
48+
49+
# Required so the async replay test in this module can run under pytest-asyncio.
pytest_plugins = ("pytest_asyncio",)
50+
51+
52+
@pytest.mark.asyncio
async def test_gen_loss_clusters_async(client):
    """Verifies the async generate_loss_clusters() calls the API and returns LossClusters."""
    analysis_config = types.LossAnalysisConfig(
        metric="multi_turn_task_success_v1",
        candidate="travel-agent",
    )
    loss_clusters = await client.aio.evals.generate_loss_clusters(
        eval_result=types.EvaluationResult(),
        config=analysis_config,
    )

    assert type(loss_clusters).__name__ == "LossClusters"
    assert len(loss_clusters.results) == 1

    result = loss_clusters.results[0]
    assert result.config.metric == "multi_turn_task_success_v1"
    assert len(result.clusters) == 2
    assert result.clusters[0].cluster_id == "cluster-1"
    assert result.clusters[1].cluster_id == "cluster-2"
70+
71+
72+
# Wire this module into the replay test harness; the replayed method is
# evals.generate_loss_clusters.
pytestmark = pytest_helper.setup(
    file=__file__,
    globals_for_file=globals(),
    test_method="evals.generate_loss_clusters",
)

tests/unit/vertexai/genai/test_evals.py

Lines changed: 260 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,16 @@
2929
from google.cloud.aiplatform import initializer as aiplatform_initializer
3030
from vertexai import _genai
3131
from vertexai._genai import _evals_data_converters
32+
from vertexai._genai import _evals_utils
3233
from vertexai._genai import _evals_metric_handlers
3334
from vertexai._genai import _evals_visualization
3435
from vertexai._genai import _evals_metric_loaders
3536
from vertexai._genai import _gcs_utils
3637
from vertexai._genai import _observability_data_converter
38+
from vertexai._genai import _transformers
3739
from vertexai._genai import evals
3840
from vertexai._genai import types as vertexai_genai_types
41+
from vertexai._genai.types import common as common_types
3942
from google.genai import client
4043
from google.genai import errors as genai_errors
4144
from google.genai import types as genai_types
@@ -218,6 +221,262 @@ def test_get_api_client_with_none_location(
218221
mock_vertexai_client.assert_not_called()
219222

220223

224+
class TestTransformers:
    """Unit tests for transformers."""

    def test_t_inline_results(self):
        # Build the nested EvaluationResult from named pieces so each layer
        # of the structure is readable on its own.
        metric_result = common_types.EvalCaseMetricResult(
            score=0.0,
            explanation="Failed tool use",
        )
        case_results = [
            common_types.EvalCaseResult(
                eval_case_index=0,
                response_candidate_results=[
                    common_types.ResponseCandidateResult(
                        response_index=0,
                        metric_results={"tool_use_quality": metric_result},
                    )
                ],
            )
        ]
        dataset = [
            common_types.EvaluationDataset(
                eval_cases=[
                    common_types.EvalCase(
                        prompt=genai_types.Content(
                            parts=[genai_types.Part(text="test prompt")]
                        )
                    )
                ]
            )
        ]
        eval_result = common_types.EvaluationResult(
            eval_case_results=case_results,
            evaluation_dataset=dataset,
            metadata=common_types.EvaluationRunMetadata(
                candidate_names=["gemini-pro"]
            ),
        )

        payload = _transformers.t_inline_results([eval_result])

        assert len(payload) == 1
        entry = payload[0]
        assert entry["metric"] == "tool_use_quality"
        assert entry["request"]["prompt"]["text"] == "test prompt"
        assert len(entry["candidate_results"]) == 1
        assert entry["candidate_results"][0]["candidate"] == "gemini-pro"
        assert entry["candidate_results"][0]["score"] == 0.0
267+
268+
269+
class TestLossClusters:
    """Unit tests for LossClusters and postprocessing."""

    def test_postprocess_loss_clusters_response(self):
        # Two clusters with distinct taxonomy entries exercise ordering and
        # field propagation through the postprocessor.
        tool_cluster = common_types.LossCluster(
            cluster_id="cluster-1",
            taxonomy_entry=common_types.LossTaxonomyEntry(
                l1_category="Tool Calling",
                l2_category="Missing Tool Invocation",
                description="The agent failed to invoke a required tool.",
            ),
            item_count=3,
        )
        hallucination_cluster = common_types.LossCluster(
            cluster_id="cluster-2",
            taxonomy_entry=common_types.LossTaxonomyEntry(
                l1_category="Hallucination",
                l2_category="Hallucination of Action",
                description="Verbally confirmed action without tool.",
            ),
            item_count=2,
        )
        response = common_types.GenerateLossClustersResponse(
            analysis_time="2026-04-01T10:00:00Z",
            results=[
                common_types.LossAnalysisResult(
                    config=common_types.LossAnalysisConfig(
                        metric="multi_turn_task_success_v1",
                        candidate="travel-agent",
                    ),
                    analysis_time="2026-04-01T10:00:00Z",
                    clusters=[tool_cluster, hallucination_cluster],
                )
            ],
        )

        loss_clusters = _evals_utils._postprocess_loss_clusters_response(response)

        assert isinstance(loss_clusters, _evals_utils.LossClusters)
        assert loss_clusters.analysis_time == "2026-04-01T10:00:00Z"
        assert len(loss_clusters.results) == 1
        result = loss_clusters.results[0]
        assert result.config.metric == "multi_turn_task_success_v1"
        assert len(result.clusters) == 2
        assert result.clusters[0].cluster_id == "cluster-1"
        assert result.clusters[0].item_count == 3
        assert result.clusters[1].cluster_id == "cluster-2"

    def test_loss_clusters_show_no_clusters(self, capsys):
        empty_response = common_types.GenerateLossClustersResponse(results=[])
        loss_clusters = _evals_utils._postprocess_loss_clusters_response(
            empty_response
        )
        loss_clusters.show()
        assert "No loss clusters found" in capsys.readouterr().out

    def test_loss_clusters_show_with_clusters(self, capsys):
        response = common_types.GenerateLossClustersResponse(
            results=[
                common_types.LossAnalysisResult(
                    config=common_types.LossAnalysisConfig(
                        metric="test_metric",
                        candidate="test-candidate",
                    ),
                    clusters=[
                        common_types.LossCluster(
                            cluster_id="c1",
                            taxonomy_entry=common_types.LossTaxonomyEntry(
                                l1_category="Cat1",
                                l2_category="SubCat1",
                            ),
                            item_count=5,
                        ),
                    ],
                )
            ],
        )
        loss_clusters = _evals_utils._postprocess_loss_clusters_response(response)
        loss_clusters.show()
        printed = capsys.readouterr().out
        assert "Cat1" in printed
        assert "SubCat1" in printed
349+
350+
351+
def _make_eval_result(
    metrics=None,
    candidate_names=None,
):
    """Helper to create an EvaluationResult with the given metrics and candidates."""
    metrics = metrics or ["task_success_v1"]
    candidate_names = candidate_names or ["agent-1"]

    # One metric result per requested metric name, all attached to a single
    # response candidate of a single eval case.
    metric_results = {
        name: common_types.EvalCaseMetricResult(metric_name=name) for name in metrics
    }
    candidate_result = common_types.ResponseCandidateResult(
        response_index=0,
        metric_results=metric_results,
    )
    return common_types.EvaluationResult(
        eval_case_results=[
            common_types.EvalCaseResult(
                eval_case_index=0,
                response_candidate_results=[candidate_result],
            )
        ],
        metadata=common_types.EvaluationRunMetadata(
            candidate_names=candidate_names,
        ),
    )
381+
382+
383+
class TestResolveLossAnalysisConfig:
    """Unit tests for _resolve_loss_analysis_config."""

    def test_auto_infer_single_metric_and_candidate(self):
        single = _make_eval_result(
            metrics=["task_success_v1"], candidate_names=["agent-1"]
        )
        resolved = _evals_utils._resolve_loss_analysis_config(eval_result=single)
        assert resolved.metric == "task_success_v1"
        assert resolved.candidate == "agent-1"

    def test_explicit_metric_and_candidate(self):
        multi = _make_eval_result(
            metrics=["m1", "m2"], candidate_names=["c1", "c2"]
        )
        resolved = _evals_utils._resolve_loss_analysis_config(
            eval_result=multi, metric="m1", candidate="c2"
        )
        assert resolved.metric == "m1"
        assert resolved.candidate == "c2"

    def test_config_provides_metric_and_candidate(self):
        eval_result = _make_eval_result(metrics=["m1"], candidate_names=["c1"])
        config = common_types.LossAnalysisConfig(
            metric="m1", candidate="c1", predefined_taxonomy="my_taxonomy"
        )
        resolved = _evals_utils._resolve_loss_analysis_config(
            eval_result=eval_result, config=config
        )
        assert resolved.metric == "m1"
        assert resolved.candidate == "c1"
        assert resolved.predefined_taxonomy == "my_taxonomy"

    def test_explicit_args_override_config(self):
        eval_result = _make_eval_result(
            metrics=["m1", "m2"], candidate_names=["c1", "c2"]
        )
        base_config = common_types.LossAnalysisConfig(metric="m1", candidate="c1")
        resolved = _evals_utils._resolve_loss_analysis_config(
            eval_result=eval_result, config=base_config, metric="m2", candidate="c2"
        )
        assert resolved.metric == "m2"
        assert resolved.candidate == "c2"

    def test_error_multiple_metrics_no_explicit(self):
        ambiguous = _make_eval_result(
            metrics=["m1", "m2"], candidate_names=["c1"]
        )
        with pytest.raises(ValueError, match="multiple metrics"):
            _evals_utils._resolve_loss_analysis_config(eval_result=ambiguous)

    def test_error_multiple_candidates_no_explicit(self):
        ambiguous = _make_eval_result(
            metrics=["m1"], candidate_names=["c1", "c2"]
        )
        with pytest.raises(ValueError, match="multiple candidates"):
            _evals_utils._resolve_loss_analysis_config(eval_result=ambiguous)

    def test_error_invalid_metric(self):
        eval_result = _make_eval_result(metrics=["m1"], candidate_names=["c1"])
        with pytest.raises(ValueError, match="not found in eval_result"):
            _evals_utils._resolve_loss_analysis_config(
                eval_result=eval_result, metric="nonexistent"
            )

    def test_error_invalid_candidate(self):
        eval_result = _make_eval_result(metrics=["m1"], candidate_names=["c1"])
        with pytest.raises(ValueError, match="not found in eval_result"):
            _evals_utils._resolve_loss_analysis_config(
                eval_result=eval_result, candidate="nonexistent"
            )

    def test_no_candidates_defaults_to_candidate_1(self):
        # Replace the metadata with an empty one so no candidate names exist
        # at all, forcing the "candidate_1" fallback.
        eval_result = _make_eval_result(metrics=["m1"], candidate_names=[])
        eval_result = eval_result.model_copy(
            update={"metadata": common_types.EvaluationRunMetadata()}
        )
        resolved = _evals_utils._resolve_loss_analysis_config(
            eval_result=eval_result
        )
        assert resolved.metric == "m1"
        assert resolved.candidate == "candidate_1"

    def test_no_eval_case_results_raises(self):
        with pytest.raises(ValueError, match="no metric results"):
            _evals_utils._resolve_loss_analysis_config(
                eval_result=common_types.EvaluationResult()
            )
478+
479+
221480
class TestEvals:
222481
"""Unit tests for the GenAI client."""
223482

@@ -1754,7 +2013,7 @@ def test_run_inference_with_litellm_openai_request_format(
17542013
self,
17552014
mock_api_client_fixture,
17562015
):
1757-
"""Tests inference with LiteLLM where the row contains an chat completion request body."""
2016+
"""Tests inference with LiteLLM where the row contains a chat completion request body."""
17582017
with mock.patch(
17592018
"vertexai._genai._evals_common.litellm"
17602019
) as mock_litellm, mock.patch(

0 commit comments

Comments
 (0)